diff options
| author | Xe Iaso <me@xeiaso.net> | 2024-01-30 18:26:23 -0500 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2024-01-30 18:28:29 -0500 |
| commit | cfae1ba727e97f8b9c5a2766ffcb79a72b283f8b (patch) | |
| tree | 7a74b7abd01c2ce02fc361cef05b5dea98c7c95e /cmd/mimi/main.go | |
| parent | 15eb817e4ca36a6240b6beacbeff455fc7e78e3c (diff) | |
| download | x-cfae1ba727e97f8b9c5a2766ffcb79a72b283f8b.tar.xz x-cfae1ba727e97f8b9c5a2766ffcb79a72b283f8b.zip | |
cmd/mimi: start implementing the other plan
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd/mimi/main.go')
| -rw-r--r-- | cmd/mimi/main.go | 301 |
1 files changed, 81 insertions, 220 deletions
```diff
diff --git a/cmd/mimi/main.go b/cmd/mimi/main.go
index 46f6b3e..f0e0a46 100644
--- a/cmd/mimi/main.go
+++ b/cmd/mimi/main.go
@@ -7,32 +7,21 @@ import (
 	"fmt"
 	"log"
 	"log/slog"
-	"net/http"
 	"os"
 	"os/signal"
-	"strings"
-	"sync"
 	"syscall"

 	"github.com/bwmarrin/discordgo"
-	"within.website/x/cmd/mimi/ollama"
 	"within.website/x/internal"
-	"within.website/x/llm"
-	"within.website/x/llm/llamaguard"
-	"within.website/x/llm/llava"
+	"within.website/x/web/ollama"
 )

 var (
-	dataDir = flag.String("data-dir", "./var", "data directory for the bot")
-	discordToken = flag.String("discord-token", "", "discord token")
-	discordGuild = flag.String("discord-guild", "192289762302754817", "discord guild")
-	discordChannel = flag.String("discord-channel", "217096701771513856", "discord channel")
-	llamaguardHost = flag.String("llamaguard-host", "http://ontos:11434", "llamaguard host")
-	llavaHost = flag.String("llava-host", "http://localhost:8080", "llava host")
-	ollamaModel = flag.String("ollama-model", "xe/mimi:f16", "ollama model tag")
-	ollamaHost = flag.String("ollama-host", "http://kaine:11434", "ollama host")
-	openAIKey = flag.String("openai-api-key", "", "openai key")
-	openAITTSModel = flag.String("openai-tts-model", "nova", "openai tts model")
+	dataDir = flag.String("data-dir", "./var", "data directory for the bot")
+	discordToken = flag.String("discord-token", "", "discord token")
+	flyDiscordGuild = flag.String("fly-discord-guild", "1194719413732130866", "fly discord guild ID")
+	ollamaModel = flag.String("ollama-model", "nous-hermes2-mixtral:8x7b-dpo-q5_K_M", "ollama model tag")
+	ollamaHost = flag.String("ollama-host", "http://xe-inference.flycast:80", "ollama host")
 )

 func p[T any](t T) *T {
@@ -44,263 +33,135 @@ func main() {

 	os.Setenv("OLLAMA_HOST", *ollamaHost)

-	cli, err := ollama.ClientFromEnvironment()
-	if err != nil {
-		log.Fatal(err)
-	}
-
-	//mc := moderation.New(http.DefaultClient, *openAIKey)
-
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()

-	if err := cli.Pull(ctx,
-		&ollama.PullRequest{
-			Name: *ollamaModel,
-			Stream: p(true),
-		},
-		func(pr ollama.ProgressResponse) error {
-			slog.Debug("pull progress", "progress", pr.Total-pr.Completed, "total", pr.Total)
-			return nil
-		},
-	); err != nil {
-		log.Fatal(err)
-	}
-
 	dg, err := discordgo.New("Bot " + *discordToken)
 	if err != nil {
 		log.Fatal(err)
 	}
 	defer dg.Close()

+	b := NewBot(dg, ollama.NewClient(*ollamaHost))
+
 	dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) {
 		if m.Author.ID == s.State.User.ID {
 			return
 		}

-		if m.GuildID != *discordGuild {
-			return
-		}
-
-		if m.ChannelID != *discordChannel {
-			return
-		}
-
 		if m.Author.Bot {
 			return
 		}

-		if m.Content == "!mimi" {
-			s.ChannelMessageSend(m.ChannelID, "mimi!")
+		if m.GuildID != *flyDiscordGuild {
 			return
 		}

-		if m.Content == "!mimi clear" {
-			lock.Lock()
-			delete(stateMap, m.ChannelID)
-			lock.Unlock()
-			s.ChannelMessageSend(m.ChannelID, "mimi state cleared")
+		if m.Content == "" {
 			return
 		}

-		var sb strings.Builder
-		var prompt strings.Builder
-
-		if ns, ok := ParseNameslash(m.Content); ok {
-			if err := json.NewEncoder(&prompt).Encode(map[string]any{
-				"message": ns.Message,
-				"user": ns.Name,
-				"is_admin": m.Author.Username == "xeiaso",
-			}); err != nil {
-				slog.Error("json encode error", "error", err)
-			}
-		} else {
-			if err := json.NewEncoder(&prompt).Encode(map[string]any{
-				"message": m.Content,
-				"user": m.Author.Username,
-				"is_admin": m.Author.Username == "xeiaso",
-			}); err != nil {
-				slog.Error("json encode error", "error", err)
-			}
-		}
-
-		if len(m.Attachments) > 0 {
-			for i, a := range m.Attachments {
-				switch a.ContentType {
-				case "image/png", "image/jpeg", "image/gif":
-				default:
-					continue
-				}
-
-				resp, err := http.Get(a.URL)
-				if err != nil {
-					slog.Error("http get error", "error", err)
-					continue
-				}
-				defer resp.Body.Close()
-
-				lrq, err := llava.DefaultRequest(m.Content, resp.Body)
-				if err != nil {
-					slog.Error("llava error", "error", err)
-					continue
-				}
-
-				lresp, err := llava.Describe(context.Background(), *llavaHost+"/completion", lrq)
-				if err != nil {
-					slog.Error("llava error", "error", err)
-					continue
-				}
-
-				if err := json.NewEncoder(&prompt).Encode(map[string]any{
-					"image": i,
-					"desc": lresp.Content,
-				}); err != nil {
-					slog.Error("json encode error", "error", err)
-					continue
-				}
-			}
-		}
-
-		lock.Lock()
-		defer lock.Unlock()
-
-		st, ok := stateMap[m.ChannelID]
-		if !ok {
-			st = &State{}
-			/* Messages: []llm.Message{{
-				Role: "user",
-				Content: prompt.String(),
-			}},
-			}*/
-
-			stateMap[m.ChannelID] = st
+		if len(m.Mentions) == 0 {
+			return
 		}

-		gr, err := llamaguard.Check(*llamaguardHost, st.Messages)
+		aboutFlyIO, err := b.judgeIfAboutFlyIO(ctx, m.Content)
 		if err != nil {
-			slog.Error("llamaguard error", "error", err)
-		}
-
-		if !gr.Safe {
-			prompt.Reset()
-			prompt.WriteString("Please write a detailed message explaining that the request violates rule ")
-			for _, c := range gr.Categories {
-				prompt.WriteString(c)
-				prompt.WriteString(": ")
-				prompt.WriteString(llamaguard.Rules[c])
-			}
-			prompt.WriteString(".\n\nAlso explain that the conversation will be reset.")
-			defer delete(stateMap, m.ChannelID)
+			slog.Error("cannot judge message", "error", err)
+			return
 		}

-		st.Messages = append(st.Messages, llm.Message{
-			Role: "user",
-			Content: prompt.String(),
-		})
-
-		err = cli.Generate(ctx,
-			&ollama.GenerateRequest{
-				Model: *ollamaModel,
-				Context: st.Context,
-				Prompt: prompt.String(),
-				Stream: p(true),
-				System: "Your name is Mimi, a helpful catgirl assistant.",
-			}, func(gr ollama.GenerateResponse) error {
-				fmt.Fprint(&sb, gr.Response)
-
-				if gr.Done {
-					st.Context = gr.Context
-					st.Messages = append(st.Messages, llm.Message{
-						Role: "assistant",
-						Content: gr.Response,
-					})
-
-					slog.Info("generated message", "dur", gr.EvalDuration.String(), "tokens/sec", float64(gr.EvalCount)/gr.EvalDuration.Seconds())
-				}
-				return nil
-			},
-		)
-		if err != nil {
-			slog.Error("generate error", "error", err, "channel", m.ChannelID)
+		if !aboutFlyIO {
 			return
 		}
-		slog.Debug("generated message", "msg", sb.String())

-		gr, err = llamaguard.Check(*llamaguardHost, st.Messages)
+		resp, err := b.scoldMessage(ctx, m.Content)
 		if err != nil {
-			slog.Error("llamaguard error", "error", err)
-			s.ChannelMessageSend(m.ChannelID, "llamaguard error")
+			slog.Error("cannot fabricate scold message", "error", err)
 			return
 		}

-		if !gr.Safe {
-			sb.Reset()
-			err = cli.Generate(ctx,
-				&ollama.GenerateRequest{
-					Model: *ollamaModel,
-					Context: st.Context,
-					Prompt: "Say that you're sorry and you can't help with that. The conversation will be reset.",
-					Stream: p(true),
-					System: "Your name is Mimi, a helpful catgirl assistant.",
-				}, func(gr ollama.GenerateResponse) error {
-					fmt.Fprint(&sb, gr.Response)
-
-					if gr.Done {
-						st.Context = gr.Context
-						st.Messages = append(st.Messages, llm.Message{
-							Role: "assistant",
-							Content: gr.Response,
-						})
-
-						slog.Info("generated message", "dur", gr.EvalDuration.String(), "tokens/sec", float64(gr.EvalCount)/gr.EvalDuration.Seconds())
-					}
-					return nil
-				},
-			)
-			if err != nil {
-				slog.Error("generate error", "error", err, "channel", m.ChannelID)
-				return
-			}
-
-			s.ChannelMessageSend(m.ChannelID, "🔀"+sb.String())
-			defer delete(stateMap, m.ChannelID)
-
+		if _, err := s.ChannelMessageSendReply(m.ChannelID, resp, m.Reference()); err != nil {
+			slog.Error("cannot send scold message", "error", err)
 			return
 		}
-
-		if _, err := s.ChannelMessageSend(m.ChannelID, sb.String()); err != nil {
-			slog.Error("message send error", "err", err, "message", sb.String())
-		}
-		slog.Debug("context length", "len", len(st.Context))
 	})

 	if err := dg.Open(); err != nil {
 		log.Fatal(err)
 	}

+	slog.Info("bot started")
+
 	sc := make(chan os.Signal, 1)
 	signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
 	<-sc
 	cancel()
 }

-var lock sync.Mutex
-var stateMap = map[string]*State{}
+type Bot struct {
+	dg *discordgo.Session
+	ola *ollama.Client
+}

-type State struct {
-	Context []int
-	Messages []llm.Message
+func NewBot(dg *discordgo.Session, ola *ollama.Client) *Bot {
+	return &Bot{
+		dg: dg,
+		ola: ola,
+	}
 }

-type Nameslash struct {
-	Name string `json:"name"`
-	Message string `json:"message"`
+func (b *Bot) judgeIfAboutFlyIO(ctx context.Context, msg string) (bool, error) {
+	resp, err := b.ola.Chat(ctx, &ollama.CompleteRequest{
+		Model: *ollamaModel,
+		Messages: []ollama.Message{
+			{
+				Role: "system",
+				Content: "You will be given messages that may be about Fly.io or deploying apps to fly.io in programming lanugages such as Go. If a message is about Fly.io in some way, then reply with a JSON object {\"about_fly.io\": true}. If it is not, then reply {\"about_fly.io\": false}.",
+			},
+			{
+				Role: "user",
+				Content: fmt.Sprintf("Is this message about Fly.io?\n\n%s", msg),
+			},
+		},
+		Format: p("json"),
+		Stream: false,
+	})
+	if err != nil {
+		return false, fmt.Errorf("ollama: error chatting: %w", err)
+	}
+
+	type aboutFlyIO struct {
+		AboutFlyIO bool `json:"about_fly.io"`
+	}
+
+	var af aboutFlyIO
+	if err := json.Unmarshal([]byte(resp.Message.Content), &af); err != nil {
+		return false, fmt.Errorf("ollama: error unmarshaling response: %w", err)
+	}
+
+	slog.Debug("checked if about fly.io", "about_fly.io", af.AboutFlyIO, "message", msg)
+	return af.AboutFlyIO, nil
 }

-func ParseNameslash(msg string) (Nameslash, bool) {
-	parts := strings.Split(msg, "\\")
-	if len(parts) != 2 {
-		return Nameslash{}, false
+func (b *Bot) scoldMessage(ctx context.Context, content string) (string, error) {
+	resp, err := b.ola.Chat(ctx, &ollama.CompleteRequest{
+		Model: *ollamaModel,
+		Messages: []ollama.Message{
+			{
+				Role: "system",
+				Content: "Your job is to redirect questions about Fly.io to the community forums at https://community.fly.io. Don't include the link in your response, just tell the user to go there. Rephrase the question.",
+			},
+			{
+				Role: "user",
+				Content: fmt.Sprintf("Please redirect this question to the community forums:\n\n%s", content),
+			},
+		},
+		Stream: false,
+	})
+	if err != nil {
+		return "", fmt.Errorf("ollama: error chatting: %w", err)
 	}
-	return Nameslash{parts[0], parts[1]}, true
+
+	return resp.Message.Content, nil
 }
```
