aboutsummaryrefslogtreecommitdiff
path: root/cmd/mimi/main.go
diff options
context:
space:
mode:
author Xe Iaso <me@xeiaso.net> 2024-01-30 18:26:23 -0500
committer Xe Iaso <me@xeiaso.net> 2024-01-30 18:28:29 -0500
commit cfae1ba727e97f8b9c5a2766ffcb79a72b283f8b (patch)
tree 7a74b7abd01c2ce02fc361cef05b5dea98c7c95e /cmd/mimi/main.go
parent 15eb817e4ca36a6240b6beacbeff455fc7e78e3c (diff)
download x-cfae1ba727e97f8b9c5a2766ffcb79a72b283f8b.tar.xz
x-cfae1ba727e97f8b9c5a2766ffcb79a72b283f8b.zip
cmd/mimi: start implementing the other plan
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd/mimi/main.go')
-rw-r--r-- cmd/mimi/main.go 301
1 files changed, 81 insertions, 220 deletions
diff --git a/cmd/mimi/main.go b/cmd/mimi/main.go
index 46f6b3e..f0e0a46 100644
--- a/cmd/mimi/main.go
+++ b/cmd/mimi/main.go
@@ -7,32 +7,21 @@ import (
"fmt"
"log"
"log/slog"
- "net/http"
"os"
"os/signal"
- "strings"
- "sync"
"syscall"
"github.com/bwmarrin/discordgo"
- "within.website/x/cmd/mimi/ollama"
"within.website/x/internal"
- "within.website/x/llm"
- "within.website/x/llm/llamaguard"
- "within.website/x/llm/llava"
+ "within.website/x/web/ollama"
)
var (
- dataDir = flag.String("data-dir", "./var", "data directory for the bot")
- discordToken = flag.String("discord-token", "", "discord token")
- discordGuild = flag.String("discord-guild", "192289762302754817", "discord guild")
- discordChannel = flag.String("discord-channel", "217096701771513856", "discord channel")
- llamaguardHost = flag.String("llamaguard-host", "http://ontos:11434", "llamaguard host")
- llavaHost = flag.String("llava-host", "http://localhost:8080", "llava host")
- ollamaModel = flag.String("ollama-model", "xe/mimi:f16", "ollama model tag")
- ollamaHost = flag.String("ollama-host", "http://kaine:11434", "ollama host")
- openAIKey = flag.String("openai-api-key", "", "openai key")
- openAITTSModel = flag.String("openai-tts-model", "nova", "openai tts model")
+ dataDir = flag.String("data-dir", "./var", "data directory for the bot")
+ discordToken = flag.String("discord-token", "", "discord token")
+ flyDiscordGuild = flag.String("fly-discord-guild", "1194719413732130866", "fly discord guild ID")
+ ollamaModel = flag.String("ollama-model", "nous-hermes2-mixtral:8x7b-dpo-q5_K_M", "ollama model tag")
+ ollamaHost = flag.String("ollama-host", "http://xe-inference.flycast:80", "ollama host")
)
func p[T any](t T) *T {
@@ -44,263 +33,135 @@ func main() {
os.Setenv("OLLAMA_HOST", *ollamaHost)
- cli, err := ollama.ClientFromEnvironment()
- if err != nil {
- log.Fatal(err)
- }
-
- //mc := moderation.New(http.DefaultClient, *openAIKey)
-
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- if err := cli.Pull(ctx,
- &ollama.PullRequest{
- Name: *ollamaModel,
- Stream: p(true),
- },
- func(pr ollama.ProgressResponse) error {
- slog.Debug("pull progress", "progress", pr.Total-pr.Completed, "total", pr.Total)
- return nil
- },
- ); err != nil {
- log.Fatal(err)
- }
-
dg, err := discordgo.New("Bot " + *discordToken)
if err != nil {
log.Fatal(err)
}
defer dg.Close()
+ b := NewBot(dg, ollama.NewClient(*ollamaHost))
+
dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) {
if m.Author.ID == s.State.User.ID {
return
}
- if m.GuildID != *discordGuild {
- return
- }
-
- if m.ChannelID != *discordChannel {
- return
- }
-
if m.Author.Bot {
return
}
- if m.Content == "!mimi" {
- s.ChannelMessageSend(m.ChannelID, "mimi!")
+ if m.GuildID != *flyDiscordGuild {
return
}
- if m.Content == "!mimi clear" {
- lock.Lock()
- delete(stateMap, m.ChannelID)
- lock.Unlock()
- s.ChannelMessageSend(m.ChannelID, "mimi state cleared")
+ if m.Content == "" {
return
}
- var sb strings.Builder
- var prompt strings.Builder
-
- if ns, ok := ParseNameslash(m.Content); ok {
- if err := json.NewEncoder(&prompt).Encode(map[string]any{
- "message": ns.Message,
- "user": ns.Name,
- "is_admin": m.Author.Username == "xeiaso",
- }); err != nil {
- slog.Error("json encode error", "error", err)
- }
- } else {
- if err := json.NewEncoder(&prompt).Encode(map[string]any{
- "message": m.Content,
- "user": m.Author.Username,
- "is_admin": m.Author.Username == "xeiaso",
- }); err != nil {
- slog.Error("json encode error", "error", err)
- }
- }
-
- if len(m.Attachments) > 0 {
- for i, a := range m.Attachments {
- switch a.ContentType {
- case "image/png", "image/jpeg", "image/gif":
- default:
- continue
- }
-
- resp, err := http.Get(a.URL)
- if err != nil {
- slog.Error("http get error", "error", err)
- continue
- }
- defer resp.Body.Close()
-
- lrq, err := llava.DefaultRequest(m.Content, resp.Body)
- if err != nil {
- slog.Error("llava error", "error", err)
- continue
- }
-
- lresp, err := llava.Describe(context.Background(), *llavaHost+"/completion", lrq)
- if err != nil {
- slog.Error("llava error", "error", err)
- continue
- }
-
- if err := json.NewEncoder(&prompt).Encode(map[string]any{
- "image": i,
- "desc": lresp.Content,
- }); err != nil {
- slog.Error("json encode error", "error", err)
- continue
- }
- }
- }
-
- lock.Lock()
- defer lock.Unlock()
-
- st, ok := stateMap[m.ChannelID]
- if !ok {
- st = &State{}
- /* Messages: []llm.Message{{
- Role: "user",
- Content: prompt.String(),
- }},
- }*/
-
- stateMap[m.ChannelID] = st
+ if len(m.Mentions) == 0 {
+ return
}
- gr, err := llamaguard.Check(*llamaguardHost, st.Messages)
+ aboutFlyIO, err := b.judgeIfAboutFlyIO(ctx, m.Content)
if err != nil {
- slog.Error("llamaguard error", "error", err)
- }
-
- if !gr.Safe {
- prompt.Reset()
- prompt.WriteString("Please write a detailed message explaining that the request violates rule ")
- for _, c := range gr.Categories {
- prompt.WriteString(c)
- prompt.WriteString(": ")
- prompt.WriteString(llamaguard.Rules[c])
- }
- prompt.WriteString(".\n\nAlso explain that the conversation will be reset.")
- defer delete(stateMap, m.ChannelID)
+ slog.Error("cannot judge message", "error", err)
+ return
}
- st.Messages = append(st.Messages, llm.Message{
- Role: "user",
- Content: prompt.String(),
- })
-
- err = cli.Generate(ctx,
- &ollama.GenerateRequest{
- Model: *ollamaModel,
- Context: st.Context,
- Prompt: prompt.String(),
- Stream: p(true),
- System: "Your name is Mimi, a helpful catgirl assistant.",
- }, func(gr ollama.GenerateResponse) error {
- fmt.Fprint(&sb, gr.Response)
-
- if gr.Done {
- st.Context = gr.Context
- st.Messages = append(st.Messages, llm.Message{
- Role: "assistant",
- Content: gr.Response,
- })
-
- slog.Info("generated message", "dur", gr.EvalDuration.String(), "tokens/sec", float64(gr.EvalCount)/gr.EvalDuration.Seconds())
- }
- return nil
- },
- )
- if err != nil {
- slog.Error("generate error", "error", err, "channel", m.ChannelID)
+ if !aboutFlyIO {
return
}
- slog.Debug("generated message", "msg", sb.String())
- gr, err = llamaguard.Check(*llamaguardHost, st.Messages)
+ resp, err := b.scoldMessage(ctx, m.Content)
if err != nil {
- slog.Error("llamaguard error", "error", err)
- s.ChannelMessageSend(m.ChannelID, "llamaguard error")
+ slog.Error("cannot fabricate scold message", "error", err)
return
}
- if !gr.Safe {
- sb.Reset()
- err = cli.Generate(ctx,
- &ollama.GenerateRequest{
- Model: *ollamaModel,
- Context: st.Context,
- Prompt: "Say that you're sorry and you can't help with that. The conversation will be reset.",
- Stream: p(true),
- System: "Your name is Mimi, a helpful catgirl assistant.",
- }, func(gr ollama.GenerateResponse) error {
- fmt.Fprint(&sb, gr.Response)
-
- if gr.Done {
- st.Context = gr.Context
- st.Messages = append(st.Messages, llm.Message{
- Role: "assistant",
- Content: gr.Response,
- })
-
- slog.Info("generated message", "dur", gr.EvalDuration.String(), "tokens/sec", float64(gr.EvalCount)/gr.EvalDuration.Seconds())
- }
- return nil
- },
- )
- if err != nil {
- slog.Error("generate error", "error", err, "channel", m.ChannelID)
- return
- }
-
- s.ChannelMessageSend(m.ChannelID, "🔀"+sb.String())
- defer delete(stateMap, m.ChannelID)
-
+ if _, err := s.ChannelMessageSendReply(m.ChannelID, resp, m.Reference()); err != nil {
+ slog.Error("cannot send scold message", "error", err)
return
}
-
- if _, err := s.ChannelMessageSend(m.ChannelID, sb.String()); err != nil {
- slog.Error("message send error", "err", err, "message", sb.String())
- }
- slog.Debug("context length", "len", len(st.Context))
})
if err := dg.Open(); err != nil {
log.Fatal(err)
}
+ slog.Info("bot started")
+
sc := make(chan os.Signal, 1)
signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
<-sc
cancel()
}
-var lock sync.Mutex
-var stateMap = map[string]*State{}
+type Bot struct {
+ dg *discordgo.Session
+ ola *ollama.Client
+}
-type State struct {
- Context []int
- Messages []llm.Message
+func NewBot(dg *discordgo.Session, ola *ollama.Client) *Bot {
+ return &Bot{
+ dg: dg,
+ ola: ola,
+ }
}
-type Nameslash struct {
- Name string `json:"name"`
- Message string `json:"message"`
+func (b *Bot) judgeIfAboutFlyIO(ctx context.Context, msg string) (bool, error) {
+ resp, err := b.ola.Chat(ctx, &ollama.CompleteRequest{
+ Model: *ollamaModel,
+ Messages: []ollama.Message{
+ {
+ Role: "system",
+ Content: "You will be given messages that may be about Fly.io or deploying apps to fly.io in programming lanugages such as Go. If a message is about Fly.io in some way, then reply with a JSON object {\"about_fly.io\": true}. If it is not, then reply {\"about_fly.io\": false}.",
+ },
+ {
+ Role: "user",
+ Content: fmt.Sprintf("Is this message about Fly.io?\n\n%s", msg),
+ },
+ },
+ Format: p("json"),
+ Stream: false,
+ })
+ if err != nil {
+ return false, fmt.Errorf("ollama: error chatting: %w", err)
+ }
+
+ type aboutFlyIO struct {
+ AboutFlyIO bool `json:"about_fly.io"`
+ }
+
+ var af aboutFlyIO
+ if err := json.Unmarshal([]byte(resp.Message.Content), &af); err != nil {
+ return false, fmt.Errorf("ollama: error unmarshaling response: %w", err)
+ }
+
+ slog.Debug("checked if about fly.io", "about_fly.io", af.AboutFlyIO, "message", msg)
+ return af.AboutFlyIO, nil
}
-func ParseNameslash(msg string) (Nameslash, bool) {
- parts := strings.Split(msg, "\\")
- if len(parts) != 2 {
- return Nameslash{}, false
+func (b *Bot) scoldMessage(ctx context.Context, content string) (string, error) {
+ resp, err := b.ola.Chat(ctx, &ollama.CompleteRequest{
+ Model: *ollamaModel,
+ Messages: []ollama.Message{
+ {
+ Role: "system",
+ Content: "Your job is to redirect questions about Fly.io to the community forums at https://community.fly.io. Don't include the link in your response, just tell the user to go there. Rephrase the question.",
+ },
+ {
+ Role: "user",
+ Content: fmt.Sprintf("Please redirect this question to the community forums:\n\n%s", content),
+ },
+ },
+ Stream: false,
+ })
+ if err != nil {
+ return "", fmt.Errorf("ollama: error chatting: %w", err)
}
- return Nameslash{parts[0], parts[1]}, true
+
+ return resp.Message.Content, nil
}