aboutsummaryrefslogtreecommitdiff
path: root/cmd
diff options
context:
space:
mode:
authorXe Iaso <me@xeiaso.net>2023-12-08 22:47:08 -0500
committerXe Iaso <me@xeiaso.net>2023-12-08 22:47:08 -0500
commit1403d6130d772b05fdd8807f9aa09474e01051eb (patch)
treea9cb026ef236068732ad1e2df9adc068ebb654e3 /cmd
parentbfd9d18254891113b8ecd2d5fa86a32c0744711f (diff)
downloadx-1403d6130d772b05fdd8807f9aa09474e01051eb.tar.xz
x-1403d6130d772b05fdd8807f9aa09474e01051eb.zip
cmd/mimi: llamaguard filtering
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd')
-rw-r--r--cmd/llamaguard/main.go58
-rw-r--r--cmd/mimi/main.go83
2 files changed, 132 insertions, 9 deletions
diff --git a/cmd/llamaguard/main.go b/cmd/llamaguard/main.go
new file mode 100644
index 0000000..0603cf4
--- /dev/null
+++ b/cmd/llamaguard/main.go
@@ -0,0 +1,58 @@
+package main
+
+import (
+ "context"
+ "encoding/json"
+ "flag"
+ "fmt"
+ "log/slog"
+ "os"
+ "strings"
+
+ "within.website/x/cmd/mimi/ollama"
+ "within.website/x/internal"
+ "within.website/x/llm"
+ "within.website/x/llm/llamaguard"
+)
+
+var (
+ model = flag.String("model", "xe/llamaguard", "model to use")
+)
+
+func main() {
+ internal.HandleStartup()
+
+ var messages []llm.Message
+
+ if err := json.NewDecoder(os.Stdin).Decode(&messages); err != nil {
+ panic(err)
+ }
+
+ slog.Info("got messages", "num", len(messages))
+
+ out, err := llamaguard.Prompt(messages)
+ if err != nil {
+ panic(err)
+ }
+
+ fmt.Println(out)
+
+ oc, err := ollama.ClientFromEnvironment()
+ if err != nil {
+ panic(err)
+ }
+
+ var result strings.Builder
+ if err := oc.Generate(context.Background(), &ollama.GenerateRequest{
+ Model: *model,
+ Prompt: out,
+ Raw: true,
+ }, func(gr ollama.GenerateResponse) error {
+ result.WriteString(gr.Response)
+ return nil
+ }); err != nil {
+ panic(err)
+ }
+
+ fmt.Println(strings.TrimSpace(result.String()))
+}
diff --git a/cmd/mimi/main.go b/cmd/mimi/main.go
index fc72fa7..2b3d640 100644
--- a/cmd/mimi/main.go
+++ b/cmd/mimi/main.go
@@ -17,6 +17,7 @@ import (
"within.website/x/cmd/mimi/ollama"
"within.website/x/internal"
"within.website/x/llm"
+ "within.website/x/llm/llamaguard"
)
var (
@@ -24,6 +25,7 @@ var (
discordToken = flag.String("discord-token", "", "discord token")
discordGuild = flag.String("discord-guild", "192289762302754817", "discord guild")
discordChannel = flag.String("discord-channel", "217096701771513856", "discord channel")
+ llamaguardHost = flag.String("llamaguard-host", "http://ontos:11434", "llamaguard host")
ollamaModel = flag.String("ollama-model", "xe/mimi:f16", "ollama model tag")
ollamaHost = flag.String("ollama-host", "http://kaine:11434", "ollama host")
openAIKey = flag.String("openai-api-key", "", "openai key")
@@ -122,17 +124,37 @@ func main() {
st, ok := stateMap[m.ChannelID]
if !ok {
- st = &State{
- Messages: []llm.Message{{
- Role: "user",
- Content: prompt.String(),
- }},
- }
+ st = &State{}
+ /* Messages: []llm.Message{{
+ Role: "user",
+ Content: prompt.String(),
+ }},
+ }*/
stateMap[m.ChannelID] = st
}
- fmt.Println(Prompt(st.Messages))
+ gr, err := llamaguard.Check(*llamaguardHost, st.Messages)
+ if err != nil {
+ slog.Error("llamaguard error", "error", err)
+ }
+
+ if !gr.Safe {
+ prompt.Reset()
+ prompt.WriteString("Please write a detailed message explaining that the request violates rule ")
+ for _, c := range gr.Categories {
+ prompt.WriteString(c)
+ prompt.WriteString(": ")
+ prompt.WriteString(llamaguard.Rules[c])
+ }
+ prompt.WriteString(".\n\nAlso explain that the conversation will be reset.")
+ defer delete(stateMap, m.ChannelID)
+ }
+
+ st.Messages = append(st.Messages, llm.Message{
+ Role: "user",
+ Content: prompt.String(),
+ })
err = cli.Generate(ctx,
&ollama.GenerateRequest{
@@ -140,7 +162,7 @@ func main() {
Context: st.Context,
Prompt: prompt.String(),
Stream: p(true),
- System: "Your name is Mimi. You will answer questions from users when asked. You are an expert in programming and philosophy. You are a catgirl. You are relaxed, terse, and casual. Twilight Sparkle is best pony.",
+ System: "Your name is Mimi, a helpful catgirl assistant.",
}, func(gr ollama.GenerateResponse) error {
fmt.Fprint(&sb, gr.Response)
@@ -150,16 +172,59 @@ func main() {
Role: "assistant",
Content: gr.Response,
})
+
+ slog.Info("generated message", "dur", gr.EvalDuration.String(), "tokens/sec", float64(gr.EvalCount)/gr.EvalDuration.Seconds())
}
return nil
},
)
-
if err != nil {
slog.Error("generate error", "error", err, "channel", m.ChannelID)
return
}
+ gr, err = llamaguard.Check(*llamaguardHost, st.Messages)
+ if err != nil {
+ slog.Error("llamaguard error", "error", err)
+ s.ChannelMessageSend(m.ChannelID, "llamaguard error")
+ return
+ }
+
+ if !gr.Safe {
+ sb.Reset()
+ err = cli.Generate(ctx,
+ &ollama.GenerateRequest{
+ Model: *ollamaModel,
+ Context: st.Context,
+ Prompt: "Say that you're sorry and you can't help with that. The conversation will be reset.",
+ Stream: p(true),
+ System: "Your name is Mimi, a helpful catgirl assistant.",
+ }, func(gr ollama.GenerateResponse) error {
+ fmt.Fprint(&sb, gr.Response)
+
+ if gr.Done {
+ st.Context = gr.Context
+ st.Messages = append(st.Messages, llm.Message{
+ Role: "assistant",
+ Content: gr.Response,
+ })
+
+ slog.Info("generated message", "dur", gr.EvalDuration.String(), "tokens/sec", float64(gr.EvalCount)/gr.EvalDuration.Seconds())
+ }
+ return nil
+ },
+ )
+ if err != nil {
+ slog.Error("generate error", "error", err, "channel", m.ChannelID)
+ return
+ }
+
+ s.ChannelMessageSend(m.ChannelID, "🔀"+sb.String())
+ defer delete(stateMap, m.ChannelID)
+
+ return
+ }
+
if _, err := s.ChannelMessageSend(m.ChannelID, sb.String()); err != nil {
slog.Error("message send error", "err", err, "message", sb.String())
}