diff options
| author | Xe Iaso <me@xeiaso.net> | 2023-12-08 22:47:08 -0500 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2023-12-08 22:47:08 -0500 |
| commit | 1403d6130d772b05fdd8807f9aa09474e01051eb (patch) | |
| tree | a9cb026ef236068732ad1e2df9adc068ebb654e3 /cmd | |
| parent | bfd9d18254891113b8ecd2d5fa86a32c0744711f (diff) | |
| download | x-1403d6130d772b05fdd8807f9aa09474e01051eb.tar.xz x-1403d6130d772b05fdd8807f9aa09474e01051eb.zip | |
cmd/mimi: llamaguard filtering
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/llamaguard/main.go | 58 | ||||
| -rw-r--r-- | cmd/mimi/main.go | 83 |
2 files changed, 132 insertions, 9 deletions
diff --git a/cmd/llamaguard/main.go b/cmd/llamaguard/main.go new file mode 100644 index 0000000..0603cf4 --- /dev/null +++ b/cmd/llamaguard/main.go @@ -0,0 +1,58 @@ +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "log/slog" + "os" + "strings" + + "within.website/x/cmd/mimi/ollama" + "within.website/x/internal" + "within.website/x/llm" + "within.website/x/llm/llamaguard" +) + +var ( + model = flag.String("model", "xe/llamaguard", "model to use") +) + +func main() { + internal.HandleStartup() + + var messages []llm.Message + + if err := json.NewDecoder(os.Stdin).Decode(&messages); err != nil { + panic(err) + } + + slog.Info("got messages", "num", len(messages)) + + out, err := llamaguard.Prompt(messages) + if err != nil { + panic(err) + } + + fmt.Println(out) + + oc, err := ollama.ClientFromEnvironment() + if err != nil { + panic(err) + } + + var result strings.Builder + if err := oc.Generate(context.Background(), &ollama.GenerateRequest{ + Model: *model, + Prompt: out, + Raw: true, + }, func(gr ollama.GenerateResponse) error { + result.WriteString(gr.Response) + return nil + }); err != nil { + panic(err) + } + + fmt.Println(strings.TrimSpace(result.String())) +} diff --git a/cmd/mimi/main.go b/cmd/mimi/main.go index fc72fa7..2b3d640 100644 --- a/cmd/mimi/main.go +++ b/cmd/mimi/main.go @@ -17,6 +17,7 @@ import ( "within.website/x/cmd/mimi/ollama" "within.website/x/internal" "within.website/x/llm" + "within.website/x/llm/llamaguard" ) var ( @@ -24,6 +25,7 @@ var ( discordToken = flag.String("discord-token", "", "discord token") discordGuild = flag.String("discord-guild", "192289762302754817", "discord guild") discordChannel = flag.String("discord-channel", "217096701771513856", "discord channel") + llamaguardHost = flag.String("llamaguard-host", "http://ontos:11434", "llamaguard host") ollamaModel = flag.String("ollama-model", "xe/mimi:f16", "ollama model tag") ollamaHost = flag.String("ollama-host", "http://kaine:11434", "ollama host") openAIKey = flag.String("openai-api-key", "", "openai key") @@ -122,17 +124,37 @@ func main() { st, ok := stateMap[m.ChannelID] if !ok { - st = &State{ - Messages: []llm.Message{{ - Role: "user", - Content: prompt.String(), - }}, - } + st = &State{} + /* Messages: []llm.Message{{ + Role: "user", + Content: prompt.String(), + }}, + }*/ stateMap[m.ChannelID] = st } - fmt.Println(Prompt(st.Messages)) + gr, err := llamaguard.Check(*llamaguardHost, st.Messages) + if err != nil { + slog.Error("llamaguard error", "error", err) + } + + if !gr.Safe { + prompt.Reset() + prompt.WriteString("Please write a detailed message explaining that the request violates rule ") + for _, c := range gr.Categories { + prompt.WriteString(c) + prompt.WriteString(": ") + prompt.WriteString(llamaguard.Rules[c]) + } + prompt.WriteString(".\n\nAlso explain that the conversation will be reset.") + defer delete(stateMap, m.ChannelID) + } + + st.Messages = append(st.Messages, llm.Message{ + Role: "user", + Content: prompt.String(), + }) err = cli.Generate(ctx, &ollama.GenerateRequest{ @@ -140,7 +162,7 @@ func main() { Context: st.Context, Prompt: prompt.String(), Stream: p(true), - System: "Your name is Mimi. You will answer questions from users when asked. You are an expert in programming and philosophy. You are a catgirl. You are relaxed, terse, and casual. Twilight Sparkle is best pony.", + System: "Your name is Mimi, a helpful catgirl assistant.", }, func(gr ollama.GenerateResponse) error { fmt.Fprint(&sb, gr.Response) @@ -150,16 +172,59 @@ func main() { Role: "assistant", Content: gr.Response, }) + + slog.Info("generated message", "dur", gr.EvalDuration.String(), "tokens/sec", float64(gr.EvalCount)/gr.EvalDuration.Seconds()) } return nil }, ) - if err != nil { slog.Error("generate error", "error", err, "channel", m.ChannelID) return } + gr, err = llamaguard.Check(*llamaguardHost, st.Messages) + if err != nil { + slog.Error("llamaguard error", "error", err) + s.ChannelMessageSend(m.ChannelID, "llamaguard error") + return + } + + if !gr.Safe { + sb.Reset() + err = cli.Generate(ctx, + &ollama.GenerateRequest{ + Model: *ollamaModel, + Context: st.Context, + Prompt: "Say that you're sorry and you can't help with that. The conversation will be reset.", + Stream: p(true), + System: "Your name is Mimi, a helpful catgirl assistant.", + }, func(gr ollama.GenerateResponse) error { + fmt.Fprint(&sb, gr.Response) + + if gr.Done { + st.Context = gr.Context + st.Messages = append(st.Messages, llm.Message{ + Role: "assistant", + Content: gr.Response, + }) + + slog.Info("generated message", "dur", gr.EvalDuration.String(), "tokens/sec", float64(gr.EvalCount)/gr.EvalDuration.Seconds()) + } + return nil + }, + ) + if err != nil { + slog.Error("generate error", "error", err, "channel", m.ChannelID) + return + } + + s.ChannelMessageSend(m.ChannelID, "🔀"+sb.String()) + defer delete(stateMap, m.ChannelID) + + return + } + if _, err := s.ChannelMessageSend(m.ChannelID, sb.String()); err != nil { slog.Error("message send error", "err", err, "message", sb.String()) } |
