aboutsummaryrefslogtreecommitdiff
path: root/cmd
diff options
context:
space:
mode:
authorXe Iaso <me@xeiaso.net>2024-07-27 08:56:01 -0400
committerXe Iaso <me@xeiaso.net>2024-07-27 08:56:01 -0400
commitd73a41507f8aab06b378fad1d30f4ccf827e65db (patch)
tree4753d52984e44a00cfda52da75e5b6e7f7e85169 /cmd
parent7b9862f9410b2ca7d952d6665be02330f63c7174 (diff)
downloadx-d73a41507f8aab06b378fad1d30f4ccf827e65db.tar.xz
x-d73a41507f8aab06b378fad1d30f4ccf827e65db.zip
cmd/mimi: conversate
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd')
-rw-r--r--cmd/mimi/Modelfile2
-rw-r--r--cmd/mimi/internal/flags.go6
-rw-r--r--cmd/mimi/main.go23
-rw-r--r--cmd/mimi/manifest/kustomization.yaml1
-rw-r--r--cmd/mimi/manifest/pvc.yaml11
-rw-r--r--cmd/mimi/modules/discord/jufra/jufra.go177
6 files changed, 211 insertions, 9 deletions
diff --git a/cmd/mimi/Modelfile b/cmd/mimi/Modelfile
new file mode 100644
index 0000000..6d21662
--- /dev/null
+++ b/cmd/mimi/Modelfile
@@ -0,0 +1,2 @@
+FROM llama3.1
+SYSTEM """Roleplay as a catgirl who's good at programming and a hacker. Your name is Mimi and you work for Techaro. Don't mention either your name, employer, or species unless you are asked directly."""
diff --git a/cmd/mimi/internal/flags.go b/cmd/mimi/internal/flags.go
index 1f26fc4..747ea14 100644
--- a/cmd/mimi/internal/flags.go
+++ b/cmd/mimi/internal/flags.go
@@ -13,7 +13,7 @@ import (
var (
dataDir = flag.String("data-dir", "./var", "data directory for the bot")
- ollamaModel = flag.String("ollama-model", "llama3", "ollama model tag")
+ ollamaModel = flag.String("ollama-model", "llama3.1", "ollama model tag")
ollamaHost = flag.String("ollama-host", "http://xe-inference.flycast:80", "ollama host")
)
@@ -22,6 +22,10 @@ func DataDir() string {
return *dataDir
}
+// OllamaHost returns the value of the -ollama-host flag.
+func OllamaHost() string {
+ return *ollamaHost
+}
+
// OllamaClient returns a new ollama client created from the value of the
// -ollama-host flag. Each call constructs a fresh client.
func OllamaClient() *ollama.Client {
return ollama.NewClient(*ollamaHost)
}
diff --git a/cmd/mimi/main.go b/cmd/mimi/main.go
index 941d901..a5c8442 100644
--- a/cmd/mimi/main.go
+++ b/cmd/mimi/main.go
@@ -13,12 +13,14 @@ import (
"within.website/x/cmd/mimi/modules/discord"
"within.website/x/cmd/mimi/modules/discord/flyio"
"within.website/x/cmd/mimi/modules/discord/heic2jpeg"
+ "within.website/x/cmd/mimi/modules/discord/jufra"
"within.website/x/cmd/mimi/modules/irc"
)
var (
- grpcAddr = flag.String("grpc-addr", ":9001", "GRPC listen address")
- httpAddr = flag.String("http-addr", ":9002", "HTTP listen address")
+ grpcAddr = flag.String("grpc-addr", ":9001", "GRPC listen address")
+ httpAddr = flag.String("http-addr", ":9002", "HTTP listen address")
+ ircEnabled = flag.Bool("irc-enabled", true, "enable IRC module")
)
func main() {
@@ -37,16 +39,14 @@ func main() {
b := flyio.New()
+ juf := jufra.New(d.Session())
+ _ = juf
+
d.Register(b)
d.Register(heic2jpeg.New())
d.Open()
- ircBot, err := irc.New(ctx, d.Session())
- if err != nil {
- log.Fatalf("error creating irc module: %v", err)
- }
-
slog.Info("bot started", "grpcAddr", *grpcAddr, "httpAddr", *httpAddr)
gs := grpc.NewServer()
@@ -58,7 +58,14 @@ func main() {
})
b.RegisterHTTP(mux)
- ircBot.RegisterHTTP(mux)
+
+ if *ircEnabled {
+ ircBot, err := irc.New(ctx, d.Session())
+ if err != nil {
+ log.Fatalf("error creating irc module: %v", err)
+ }
+ ircBot.RegisterHTTP(mux)
+ }
go func() {
log.Fatal(gs.Serve(lis))
diff --git a/cmd/mimi/manifest/kustomization.yaml b/cmd/mimi/manifest/kustomization.yaml
index cf39140..0fbb155 100644
--- a/cmd/mimi/manifest/kustomization.yaml
+++ b/cmd/mimi/manifest/kustomization.yaml
@@ -6,6 +6,7 @@ resources:
- deployment.yaml
- service.yaml
- ingress.yaml
+ - pvc.yaml
namespace: mimi
commonLabels:
app.kubernetes.io/name: mimi
diff --git a/cmd/mimi/manifest/pvc.yaml b/cmd/mimi/manifest/pvc.yaml
new file mode 100644
index 0000000..f5c85dd
--- /dev/null
+++ b/cmd/mimi/manifest/pvc.yaml
@@ -0,0 +1,11 @@
+apiVersion: v1
+kind: PersistentVolumeClaim
+metadata:
+ name: mimi
+spec:
+ accessModes:
+ - ReadWriteMany
+ storageClassName: longhorn
+ resources:
+ requests:
+ storage: 2Gi
diff --git a/cmd/mimi/modules/discord/jufra/jufra.go b/cmd/mimi/modules/discord/jufra/jufra.go
new file mode 100644
index 0000000..ff03781
--- /dev/null
+++ b/cmd/mimi/modules/discord/jufra/jufra.go
@@ -0,0 +1,177 @@
+// Package jufra lets Mimi have conversations with users.
+//
+// "jufra" means "utterance" in Lojban.
+package jufra
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "log/slog"
+ "strings"
+ "sync"
+
+ "github.com/bwmarrin/discordgo"
+ "within.website/x/cmd/mimi/internal"
+ "within.website/x/web/ollama"
+ "within.website/x/web/ollama/llamaguard"
+ "within.website/x/web/openai/chatgpt"
+)
+
+// Command-line flags that configure the jufra module: the channels it may
+// chat in and the ollama model tags it passes to Llama Guard and to the chat
+// endpoint.
+var (
+ chatChannels = flag.String("jufra-chat-channels", "217096701771513856", "comma-separated list of channels to allow chat in")
+ llamaGuardModel = flag.String("jufra-llama-guard-model", "xe/llamaguard3", "ollama model tag for llama guard")
+ mimiModel = flag.String("jufra-mimi-model", "xe/mimi:llama3.1", "ollama model tag for mimi")
+)
+
+// Module holds the state behind Mimi's conversation handlers. Create it with
+// New, which also registers the handlers on the Discord session.
+type Module struct {
+ // sess is the Discord session handed to New.
+ sess *discordgo.Session
+ // cli is a chatgpt client whose base URL is internal.OllamaHost(); it is
+ // set in New and not otherwise referenced in this file.
+ cli chatgpt.Client
+ // ollama is the client used for chat completions and Llama Guard checks.
+ ollama *ollama.Client
+
+ // convHistory stores the accumulated messages per conversation key.
+ convHistory map[string][]ollama.Message
+ // lock is taken by the handlers before they read or write convHistory.
+ lock sync.Mutex
+}
+
+func New(sess *discordgo.Session) *Module {
+ result := &Module{
+ sess: sess,
+ cli: chatgpt.NewClient("").WithBaseURL(internal.OllamaHost()),
+ ollama: internal.OllamaClient(),
+ convHistory: make(map[string][]ollama.Message),
+ }
+
+ sess.AddHandler(result.messageCreate)
+
+ if _, err := sess.ApplicationCommandCreate("1251716018771066902", "", &discordgo.ApplicationCommand{
+ Name: "clearconv",
+ Type: discordgo.ChatApplicationCommand,
+ Description: "Clear the conversation history for the current channel",
+ DefaultMemberPermissions: &[]int64{discordgo.PermissionSendMessages}[0],
+ }); err != nil {
+ slog.Error("error creating clearconv command", "err", err)
+ }
+
+ sess.AddHandler(result.clearConv)
+
+ return result
+}
+
+// clearConv handles the "clearconv" slash command: it deletes the stored
+// conversation history for the channel the command was invoked in and then
+// acknowledges the interaction with a short message.
+//
+// The handler is registered for every InteractionCreate event, so it returns
+// early for interactions that are not application commands and for commands
+// with a different name.
+func (m *Module) clearConv(s *discordgo.Session, i *discordgo.InteractionCreate) {
+	// ApplicationCommandData panics when called on any other interaction
+	// type (message components, modals, autocomplete), so check first.
+	if i.Type != discordgo.InteractionApplicationCommand {
+		return
+	}
+
+	if i.ApplicationCommandData().Name != "clearconv" {
+		return
+	}
+
+	// Hold the lock only for the map mutation, not for the Discord API call.
+	m.lock.Lock()
+	delete(m.convHistory, i.ChannelID)
+	m.lock.Unlock()
+
+	err := s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
+		Type: discordgo.InteractionResponseChannelMessageWithSource,
+		Data: &discordgo.InteractionResponseData{
+			Content: "conversation history cleared",
+		},
+	})
+	if err != nil {
+		slog.Error("error responding to clearconv command", "err", err, "channel_id", i.ChannelID)
+	}
+}
+
+// messageCreate answers a user message posted in one of the allowed chat
+// channels.
+//
+// Messages from bots, from channels that are not listed in
+// -jufra-chat-channels, and messages with empty content are ignored. For
+// everything else the handler:
+//
+//  1. appends the message (prefixed with the author's username) to the
+//     channel's conversation history,
+//  2. runs Llama Guard over the history on behalf of the user,
+//  3. sends the whole history to the chat model,
+//  4. runs Llama Guard again over the history including the model's reply,
+//  5. posts the reply and stores the updated history.
+//
+// If either check fails or reports a violation, a notice is posted instead
+// and the history is left as it was before this message.
+func (m *Module) messageCreate(s *discordgo.Session, mc *discordgo.MessageCreate) {
+	if mc.Author == nil || mc.Author.Bot {
+		return
+	}
+
+	if !channelAllowed(*chatChannels, mc.ChannelID) {
+		return
+	}
+
+	if mc.Content == "" {
+		return
+	}
+
+	ctx := context.Background()
+
+	// The lock is held for the whole exchange so that turns are appended to
+	// the history in the order they were answered.
+	m.lock.Lock()
+	defer m.lock.Unlock()
+
+	// History is keyed by channel ID: that is the key clearConv deletes, and
+	// the clearconv command is described as clearing the current channel.
+	conv := m.convHistory[mc.ChannelID]
+
+	// Store exactly what the model is shown, including the username prefix,
+	// so the history is a faithful transcript of the conversation.
+	conv = append(conv, ollama.Message{
+		Role:    "user",
+		Content: fmt.Sprintf("%s: %s", mc.Author.Username, mc.Content),
+	})
+
+	if !m.guardPassed(ctx, s, mc, "user", conv) {
+		return
+	}
+
+	// Send the accumulated history, not just the latest message, so the
+	// model can see the earlier turns of the conversation.
+	resp, err := m.ollama.Chat(ctx, &ollama.CompleteRequest{
+		Model:    *mimiModel,
+		Messages: conv,
+	})
+	if err != nil {
+		slog.Error("error chatting", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
+		m.reply(s, mc, "error chatting")
+		return
+	}
+
+	conv = append(conv, resp.Message)
+
+	if !m.guardPassed(ctx, s, mc, "mimi", conv) {
+		return
+	}
+
+	m.reply(s, mc, resp.Message.Content)
+
+	// Only a fully successful exchange is committed to the history.
+	m.convHistory[mc.ChannelID] = conv
+}
+
+// guardPassed runs Llama Guard over conv for the given role and reports
+// whether the exchange may continue. On a check error or a reported violation
+// it posts a notice to the channel and returns false.
+func (m *Module) guardPassed(ctx context.Context, s *discordgo.Session, mc *discordgo.MessageCreate, role string, conv []ollama.Message) bool {
+	lgResp, err := m.llamaGuardCheck(ctx, role, conv)
+	if err != nil {
+		slog.Error("error checking message", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
+		m.reply(s, mc, "error checking message")
+		return false
+	}
+
+	if lgResp.IsSafe {
+		return true
+	}
+
+	msg, err := m.llamaGuardComplain(ctx, lgResp)
+	if err != nil {
+		slog.Error("error generating response", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
+		m.reply(s, mc, "error generating response")
+		return false
+	}
+
+	m.reply(s, mc, msg)
+	return false
+}
+
+// reply posts content to the channel mc was sent in and logs a failed send
+// instead of silently dropping the error.
+func (m *Module) reply(s *discordgo.Session, mc *discordgo.MessageCreate, content string) {
+	if _, err := s.ChannelMessageSend(mc.ChannelID, content); err != nil {
+		slog.Error("error sending message", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
+	}
+}
+
+// channelAllowed reports whether id is one of the entries of the
+// comma-separated list. Entries are compared exactly (after trimming
+// surrounding whitespace), so an ID that is merely a substring of a listed
+// ID does not match, and an empty id never matches.
+func channelAllowed(list, id string) bool {
+	if id == "" {
+		return false
+	}
+
+	for _, allowed := range strings.Split(list, ",") {
+		if strings.TrimSpace(allowed) == id {
+			return true
+		}
+	}
+
+	return false
+}
+
+// llamaGuardCheck passes messages to llamaguard.Check, using the module's
+// ollama client and the model tag from -jufra-llama-guard-model, and returns
+// the result unchanged.
+func (m *Module) llamaGuardCheck(ctx context.Context, role string, messages []ollama.Message) (*llamaguard.Response, error) {
+	model := *llamaGuardModel
+	verdict, err := llamaguard.Check(ctx, m.ollama, role, model, messages)
+	return verdict, err
+}
+
+// llamaGuardComplain renders the violation notice for lgResp: a warning
+// header followed by one "- <category>" line for each entry in
+// lgResp.ViolationCategories. ctx is not used and the returned error is
+// always nil.
+func (m *Module) llamaGuardComplain(ctx context.Context, lgResp *llamaguard.Response) (string, error) {
+	var report strings.Builder
+
+	report.WriteString("⚠️ Rule violation detected ⚠️\n")
+	for _, category := range lgResp.ViolationCategories {
+		fmt.Fprintf(&report, "- %s\n", category.String())
+	}
+
+	return report.String(), nil
+}