diff options
| author | Xe Iaso <me@xeiaso.net> | 2024-07-27 08:56:01 -0400 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2024-07-27 08:56:01 -0400 |
| commit | d73a41507f8aab06b378fad1d30f4ccf827e65db (patch) | |
| tree | 4753d52984e44a00cfda52da75e5b6e7f7e85169 /cmd | |
| parent | 7b9862f9410b2ca7d952d6665be02330f63c7174 (diff) | |
| download | x-d73a41507f8aab06b378fad1d30f4ccf827e65db.tar.xz x-d73a41507f8aab06b378fad1d30f4ccf827e65db.zip | |
cmd/mimi: conversate
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/mimi/Modelfile | 2 | ||||
| -rw-r--r-- | cmd/mimi/internal/flags.go | 6 | ||||
| -rw-r--r-- | cmd/mimi/main.go | 23 | ||||
| -rw-r--r-- | cmd/mimi/manifest/kustomization.yaml | 1 | ||||
| -rw-r--r-- | cmd/mimi/manifest/pvc.yaml | 11 | ||||
| -rw-r--r-- | cmd/mimi/modules/discord/jufra/jufra.go | 177 |
6 files changed, 211 insertions, 9 deletions
diff --git a/cmd/mimi/Modelfile b/cmd/mimi/Modelfile new file mode 100644 index 0000000..6d21662 --- /dev/null +++ b/cmd/mimi/Modelfile @@ -0,0 +1,2 @@ +FROM llama3.1 +SYSTEM """Roleplay as a catgirl who's good at programming and a hacker. Your name is Mimi and you work for Techaro. Don't mention either your name, employer, or species unless you are asked directly.""" diff --git a/cmd/mimi/internal/flags.go b/cmd/mimi/internal/flags.go index 1f26fc4..747ea14 100644 --- a/cmd/mimi/internal/flags.go +++ b/cmd/mimi/internal/flags.go @@ -13,7 +13,7 @@ import ( var ( dataDir = flag.String("data-dir", "./var", "data directory for the bot") - ollamaModel = flag.String("ollama-model", "llama3", "ollama model tag") + ollamaModel = flag.String("ollama-model", "llama3.1", "ollama model tag") ollamaHost = flag.String("ollama-host", "http://xe-inference.flycast:80", "ollama host") ) @@ -22,6 +22,10 @@ func DataDir() string { return *dataDir } +func OllamaHost() string { + return *ollamaHost +} + func OllamaClient() *ollama.Client { return ollama.NewClient(*ollamaHost) } diff --git a/cmd/mimi/main.go b/cmd/mimi/main.go index 941d901..a5c8442 100644 --- a/cmd/mimi/main.go +++ b/cmd/mimi/main.go @@ -13,12 +13,14 @@ import ( "within.website/x/cmd/mimi/modules/discord" "within.website/x/cmd/mimi/modules/discord/flyio" "within.website/x/cmd/mimi/modules/discord/heic2jpeg" + "within.website/x/cmd/mimi/modules/discord/jufra" "within.website/x/cmd/mimi/modules/irc" ) var ( - grpcAddr = flag.String("grpc-addr", ":9001", "GRPC listen address") - httpAddr = flag.String("http-addr", ":9002", "HTTP listen address") + grpcAddr = flag.String("grpc-addr", ":9001", "GRPC listen address") + httpAddr = flag.String("http-addr", ":9002", "HTTP listen address") + ircEnabled = flag.Bool("irc-enabled", true, "enable IRC module") ) func main() { @@ -37,16 +39,14 @@ func main() { b := flyio.New() + juf := jufra.New(d.Session()) + _ = juf + d.Register(b) d.Register(heic2jpeg.New()) d.Open() - ircBot, err := irc.New(ctx, d.Session()) - if err != nil { - log.Fatalf("error creating irc module: %v", err) - } - slog.Info("bot started", "grpcAddr", *grpcAddr, "httpAddr", *httpAddr) gs := grpc.NewServer() @@ -58,7 +58,14 @@ func main() { }) b.RegisterHTTP(mux) - ircBot.RegisterHTTP(mux) + + if *ircEnabled { + ircBot, err := irc.New(ctx, d.Session()) + if err != nil { + log.Fatalf("error creating irc module: %v", err) + } + ircBot.RegisterHTTP(mux) + } go func() { log.Fatal(gs.Serve(lis)) diff --git a/cmd/mimi/manifest/kustomization.yaml b/cmd/mimi/manifest/kustomization.yaml index cf39140..0fbb155 100644 --- a/cmd/mimi/manifest/kustomization.yaml +++ b/cmd/mimi/manifest/kustomization.yaml @@ -6,6 +6,7 @@ resources: - deployment.yaml - service.yaml - ingress.yaml + - pvc.yaml namespace: mimi commonLabels: app.kubernetes.io/name: mimi diff --git a/cmd/mimi/manifest/pvc.yaml b/cmd/mimi/manifest/pvc.yaml new file mode 100644 index 0000000..f5c85dd --- /dev/null +++ b/cmd/mimi/manifest/pvc.yaml @@ -0,0 +1,11 @@ +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: mimi +spec: + accessModes: + - ReadWriteMany + storageClassName: longhorn + resources: + requests: + storage: 2Gi diff --git a/cmd/mimi/modules/discord/jufra/jufra.go b/cmd/mimi/modules/discord/jufra/jufra.go new file mode 100644 index 0000000..ff03781 --- /dev/null +++ b/cmd/mimi/modules/discord/jufra/jufra.go @@ -0,0 +1,177 @@ +// Package jufra lets Mimi have conversations with users. +// +// "jufra" means "utterance" in Lojban. +package jufra + +import ( + "context" + "flag" + "fmt" + "log/slog" + "strings" + "sync" + + "github.com/bwmarrin/discordgo" + "within.website/x/cmd/mimi/internal" + "within.website/x/web/ollama" + "within.website/x/web/ollama/llamaguard" + "within.website/x/web/openai/chatgpt" +) + +var ( + chatChannels = flag.String("jufra-chat-channels", "217096701771513856", "comma-separated list of channels to allow chat in") + llamaGuardModel = flag.String("jufra-llama-guard-model", "xe/llamaguard3", "ollama model tag for llama guard") + mimiModel = flag.String("jufra-mimi-model", "xe/mimi:llama3.1", "ollama model tag for mimi") +) + +type Module struct { + sess *discordgo.Session + cli chatgpt.Client + ollama *ollama.Client + + convHistory map[string][]ollama.Message + lock sync.Mutex +} + +func New(sess *discordgo.Session) *Module { + result := &Module{ + sess: sess, + cli: chatgpt.NewClient("").WithBaseURL(internal.OllamaHost()), + ollama: internal.OllamaClient(), + convHistory: make(map[string][]ollama.Message), + } + + sess.AddHandler(result.messageCreate) + + if _, err := sess.ApplicationCommandCreate("1251716018771066902", "", &discordgo.ApplicationCommand{ + Name: "clearconv", + Type: discordgo.ChatApplicationCommand, + Description: "Clear the conversation history for the current channel", + DefaultMemberPermissions: &[]int64{discordgo.PermissionSendMessages}[0], + }); err != nil { + slog.Error("error creating clearconv command", "err", err) + } + + sess.AddHandler(result.clearConv) + + return result +} + +func (m *Module) clearConv(s *discordgo.Session, i *discordgo.InteractionCreate) { + if i.ApplicationCommandData().Name != "clearconv" { + return + } + + m.lock.Lock() + defer m.lock.Unlock() + + delete(m.convHistory, i.ChannelID) + + s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ + Type: discordgo.InteractionResponseChannelMessageWithSource, + Data: &discordgo.InteractionResponseData{ + Content: "conversation history cleared", + }, + }) +} + +func (m *Module) messageCreate(s *discordgo.Session, mc *discordgo.MessageCreate) { + if mc.Author.Bot { + return + } + + if !strings.Contains(*chatChannels, mc.ChannelID) { + return + } + + if mc.Content == "" { + return + } + + m.lock.Lock() + defer m.lock.Unlock() + + conv := m.convHistory[mc.Author.ID] + + conv = append(conv, ollama.Message{ + Role: "user", + Content: mc.Content, + }) + + lgResp, err := m.llamaGuardCheck(context.Background(), "user", conv) + if err != nil { + slog.Error("error checking message", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID) + s.ChannelMessageSend(mc.ChannelID, "error checking message") + return + } + + if !lgResp.IsSafe { + msg, err := m.llamaGuardComplain(context.Background(), lgResp) + if err != nil { + slog.Error("error generating response", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID) + s.ChannelMessageSend(mc.ChannelID, "error generating response") + return + } + + s.ChannelMessageSend(mc.ChannelID, msg) + return + } + + cr := &ollama.CompleteRequest{ + Model: *mimiModel, + Messages: []ollama.Message{ + { + Role: "user", + Content: fmt.Sprintf("%s: %s", mc.Author.Username, mc.Content), + }, + }, + } + + resp, err := m.ollama.Chat(context.Background(), cr) + if err != nil { + slog.Error("error chatting", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID) + s.ChannelMessageSend(mc.ChannelID, "error chatting") + return + } + + conv = append(conv, resp.Message) + + lgResp, err = m.llamaGuardCheck(context.Background(), "mimi", conv) + if err != nil { + slog.Error("error checking message", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID) + s.ChannelMessageSend(mc.ChannelID, "error checking message") + return + } + + if !lgResp.IsSafe { + msg, err := m.llamaGuardComplain(context.Background(), lgResp) + if err != nil { + slog.Error("error generating response", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID) + s.ChannelMessageSend(mc.ChannelID, "error generating response") + return + } + + s.ChannelMessageSend(mc.ChannelID, msg) + return + } + + s.ChannelMessageSend(mc.ChannelID, resp.Message.Content) + + m.convHistory[mc.Author.ID] = conv +} + +func (m *Module) llamaGuardCheck(ctx context.Context, role string, messages []ollama.Message) (*llamaguard.Response, error) { + return llamaguard.Check(ctx, m.ollama, role, *llamaGuardModel, messages) +} + +func (m *Module) llamaGuardComplain(ctx context.Context, lgResp *llamaguard.Response) (string, error) { + var sb strings.Builder + sb.WriteString("⚠️ Rule violation detected ⚠️\n") + for _, cat := range lgResp.ViolationCategories { + sb.WriteString("- ") + sb.WriteString(cat.String()) + sb.WriteString("\n") + } + + return sb.String(), nil +} |
