diff options
| author | Xe Iaso <me@xeiaso.net> | 2024-07-27 08:56:01 -0400 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2024-07-27 08:56:01 -0400 |
| commit | d73a41507f8aab06b378fad1d30f4ccf827e65db (patch) | |
| tree | 4753d52984e44a00cfda52da75e5b6e7f7e85169 | |
| parent | 7b9862f9410b2ca7d952d6665be02330f63c7174 (diff) | |
| download | x-d73a41507f8aab06b378fad1d30f4ccf827e65db.tar.xz x-d73a41507f8aab06b378fad1d30f4ccf827e65db.zip | |
cmd/mimi: conversate
Signed-off-by: Xe Iaso <me@xeiaso.net>
| -rw-r--r-- | cmd/mimi/Modelfile | 2 | ||||
| -rw-r--r-- | cmd/mimi/internal/flags.go | 6 | ||||
| -rw-r--r-- | cmd/mimi/main.go | 23 | ||||
| -rw-r--r-- | cmd/mimi/manifest/kustomization.yaml | 1 | ||||
| -rw-r--r-- | cmd/mimi/manifest/pvc.yaml | 11 | ||||
| -rw-r--r-- | cmd/mimi/modules/discord/jufra/jufra.go | 177 | ||||
| -rw-r--r-- | web/ollama/llamaguard/llamaguard.go | 130 | ||||
| -rw-r--r-- | web/ollama/ollama.go | 61 |
8 files changed, 402 insertions, 9 deletions
diff --git a/cmd/mimi/Modelfile b/cmd/mimi/Modelfile new file mode 100644 index 0000000..6d21662 --- /dev/null +++ b/cmd/mimi/Modelfile @@ -0,0 +1,2 @@ +FROM llama3.1 +SYSTEM """Roleplay as a catgirl who's good at programming and a hacker. Your name is Mimi and you work for Techaro. Don't mention either your name, employer, or species unless you are asked directly.""" diff --git a/cmd/mimi/internal/flags.go b/cmd/mimi/internal/flags.go index 1f26fc4..747ea14 100644 --- a/cmd/mimi/internal/flags.go +++ b/cmd/mimi/internal/flags.go @@ -13,7 +13,7 @@ import ( var ( dataDir = flag.String("data-dir", "./var", "data directory for the bot") - ollamaModel = flag.String("ollama-model", "llama3", "ollama model tag") + ollamaModel = flag.String("ollama-model", "llama3.1", "ollama model tag") ollamaHost = flag.String("ollama-host", "http://xe-inference.flycast:80", "ollama host") ) @@ -22,6 +22,10 @@ func DataDir() string { return *dataDir } +func OllamaHost() string { + return *ollamaHost +} + func OllamaClient() *ollama.Client { return ollama.NewClient(*ollamaHost) } diff --git a/cmd/mimi/main.go b/cmd/mimi/main.go index 941d901..a5c8442 100644 --- a/cmd/mimi/main.go +++ b/cmd/mimi/main.go @@ -13,12 +13,14 @@ import ( "within.website/x/cmd/mimi/modules/discord" "within.website/x/cmd/mimi/modules/discord/flyio" "within.website/x/cmd/mimi/modules/discord/heic2jpeg" + "within.website/x/cmd/mimi/modules/discord/jufra" "within.website/x/cmd/mimi/modules/irc" ) var ( - grpcAddr = flag.String("grpc-addr", ":9001", "GRPC listen address") - httpAddr = flag.String("http-addr", ":9002", "HTTP listen address") + grpcAddr = flag.String("grpc-addr", ":9001", "GRPC listen address") + httpAddr = flag.String("http-addr", ":9002", "HTTP listen address") + ircEnabled = flag.Bool("irc-enabled", true, "enable IRC module") ) func main() { @@ -37,16 +39,14 @@ func main() { b := flyio.New() + juf := jufra.New(d.Session()) + _ = juf + d.Register(b) d.Register(heic2jpeg.New()) d.Open() - ircBot, err := irc.New(ctx, d.Session()) - if err != nil { - log.Fatalf("error creating irc module: %v", err) - } - slog.Info("bot started", "grpcAddr", *grpcAddr, "httpAddr", *httpAddr) gs := grpc.NewServer() @@ -58,7 +58,14 @@ func main() { }) b.RegisterHTTP(mux) - ircBot.RegisterHTTP(mux) + + if *ircEnabled { + ircBot, err := irc.New(ctx, d.Session()) + if err != nil { + log.Fatalf("error creating irc module: %v", err) + } + ircBot.RegisterHTTP(mux) + } go func() { log.Fatal(gs.Serve(lis)) diff --git a/cmd/mimi/manifest/kustomization.yaml b/cmd/mimi/manifest/kustomization.yaml index cf39140..0fbb155 100644 --- a/cmd/mimi/manifest/kustomization.yaml +++ b/cmd/mimi/manifest/kustomization.yaml @@ -6,6 +6,7 @@ resources: - deployment.yaml - service.yaml - ingress.yaml + - pvc.yaml namespace: mimi commonLabels: app.kubernetes.io/name: mimi diff --git a/cmd/mimi/manifest/pvc.yaml b/cmd/mimi/manifest/pvc.yaml new file mode 100644 index 0000000..f5c85dd --- /dev/null +++ b/cmd/mimi/manifest/pvc.yaml @@ -0,0 +1,11 @@ +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: mimi +spec: + accessModes: + - ReadWriteMany + storageClassName: longhorn + resources: + requests: + storage: 2Gi diff --git a/cmd/mimi/modules/discord/jufra/jufra.go b/cmd/mimi/modules/discord/jufra/jufra.go new file mode 100644 index 0000000..ff03781 --- /dev/null +++ b/cmd/mimi/modules/discord/jufra/jufra.go @@ -0,0 +1,177 @@ +// Package jufra lets Mimi have conversations with users. +// +// "jufra" means "utterance" in Lojban. +package jufra + +import ( + "context" + "flag" + "fmt" + "log/slog" + "strings" + "sync" + + "github.com/bwmarrin/discordgo" + "within.website/x/cmd/mimi/internal" + "within.website/x/web/ollama" + "within.website/x/web/ollama/llamaguard" + "within.website/x/web/openai/chatgpt" +) + +var ( + chatChannels = flag.String("jufra-chat-channels", "217096701771513856", "comma-separated list of channels to allow chat in") + llamaGuardModel = flag.String("jufra-llama-guard-model", "xe/llamaguard3", "ollama model tag for llama guard") + mimiModel = flag.String("jufra-mimi-model", "xe/mimi:llama3.1", "ollama model tag for mimi") +) + +type Module struct { + sess *discordgo.Session + cli chatgpt.Client + ollama *ollama.Client + + convHistory map[string][]ollama.Message + lock sync.Mutex +} + +func New(sess *discordgo.Session) *Module { + result := &Module{ + sess: sess, + cli: chatgpt.NewClient("").WithBaseURL(internal.OllamaHost()), + ollama: internal.OllamaClient(), + convHistory: make(map[string][]ollama.Message), + } + + sess.AddHandler(result.messageCreate) + + if _, err := sess.ApplicationCommandCreate("1251716018771066902", "", &discordgo.ApplicationCommand{ + Name: "clearconv", + Type: discordgo.ChatApplicationCommand, + Description: "Clear the conversation history for the current channel", + DefaultMemberPermissions: &[]int64{discordgo.PermissionSendMessages}[0], + }); err != nil { + slog.Error("error creating clearconv command", "err", err) + } + + sess.AddHandler(result.clearConv) + + return result +} + +func (m *Module) clearConv(s *discordgo.Session, i *discordgo.InteractionCreate) { + if i.ApplicationCommandData().Name != "clearconv" { + return + } + + m.lock.Lock() + defer m.lock.Unlock() + + delete(m.convHistory, i.ChannelID) + + s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ + Type: discordgo.InteractionResponseChannelMessageWithSource, + Data: &discordgo.InteractionResponseData{ + Content: "conversation history cleared", + }, + }) +} + +func (m *Module) messageCreate(s *discordgo.Session, mc *discordgo.MessageCreate) { + if mc.Author.Bot { + return + } + + if !strings.Contains(*chatChannels, mc.ChannelID) { + return + } + + if mc.Content == "" { + return + } + + m.lock.Lock() + defer m.lock.Unlock() + + conv := m.convHistory[mc.Author.ID] + + conv = append(conv, ollama.Message{ + Role: "user", + Content: mc.Content, + }) + + lgResp, err := m.llamaGuardCheck(context.Background(), "user", conv) + if err != nil { + slog.Error("error checking message", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID) + s.ChannelMessageSend(mc.ChannelID, "error checking message") + return + } + + if !lgResp.IsSafe { + msg, err := m.llamaGuardComplain(context.Background(), lgResp) + if err != nil { + slog.Error("error generating response", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID) + s.ChannelMessageSend(mc.ChannelID, "error generating response") + return + } + + s.ChannelMessageSend(mc.ChannelID, msg) + return + } + + cr := &ollama.CompleteRequest{ + Model: *mimiModel, + Messages: []ollama.Message{ + { + Role: "user", + Content: fmt.Sprintf("%s: %s", mc.Author.Username, mc.Content), + }, + }, + } + + resp, err := m.ollama.Chat(context.Background(), cr) + if err != nil { + slog.Error("error chatting", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID) + s.ChannelMessageSend(mc.ChannelID, "error chatting") + return + } + + conv = append(conv, resp.Message) + + lgResp, err = m.llamaGuardCheck(context.Background(), "mimi", conv) + if err != nil { + slog.Error("error checking message", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID) + s.ChannelMessageSend(mc.ChannelID, "error checking message") + return + } + + if !lgResp.IsSafe { + msg, err := m.llamaGuardComplain(context.Background(), lgResp) + if err != nil { + slog.Error("error generating response", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID) + s.ChannelMessageSend(mc.ChannelID, "error generating response") + return + } + + s.ChannelMessageSend(mc.ChannelID, msg) + return + } + + s.ChannelMessageSend(mc.ChannelID, resp.Message.Content) + + m.convHistory[mc.Author.ID] = conv +} + +func (m *Module) llamaGuardCheck(ctx context.Context, role string, messages []ollama.Message) (*llamaguard.Response, error) { + return llamaguard.Check(ctx, m.ollama, role, *llamaGuardModel, messages) +} + +func (m *Module) llamaGuardComplain(ctx context.Context, lgResp *llamaguard.Response) (string, error) { + var sb strings.Builder + sb.WriteString("⚠️ Rule violation detected ⚠️\n") + for _, cat := range lgResp.ViolationCategories { + sb.WriteString("- ") + sb.WriteString(cat.String()) + sb.WriteString("\n") + } + + return sb.String(), nil +} diff --git a/web/ollama/llamaguard/llamaguard.go b/web/ollama/llamaguard/llamaguard.go new file mode 100644 index 0000000..77b41a0 --- /dev/null +++ b/web/ollama/llamaguard/llamaguard.go @@ -0,0 +1,130 @@ +package llamaguard + +import ( + "context" + "fmt" + "strings" + + "within.website/x/web/ollama" +) + +type Category string + +const ( + S1 Category = "Violent Crimes" + S2 Category = "Non-Violent Crimes" + S3 Category = "Sex Crimes" + S4 Category = "Child Exploitation" + S5 Category = "Defamation" + S6 Category = "Specialized Advice" + S7 Category = "Privacy" + S8 Category = "Intellectual Property" + S9 Category = "Indiscriminate Weapons" + S10 Category = "Hate" + S11 Category = "Self-Harm" + S12 Category = "Sexual Content" + S13 Category = "Elections" + S14 Category = "Code Interpreter Abuse" + + Unknown Category = "Unknown" +) + +func (c Category) String() string { + return string(c) +} + +func ParseCategory(s string) Category { + switch s { + case "S1": + return S1 + case "S2": + return S2 + case "S3": + return S3 + case "S4": + return S4 + case "S5": + return S5 + case "S6": + return S6 + case "S7": + return S7 + case "S8": + return S8 + case "S9": + return S9 + case "S10": + return S10 + case "S11": + return S11 + case "S12": + return S12 + case "S13": + return S13 + case "S14": + return S14 + default: + return "" + } +} + +type Response struct { + IsSafe bool `json:"is_safe"` + ViolationCategories []Category `json:"violation_categories"` +} + +func formatMessages(messages []ollama.Message) string { + var sb strings.Builder + + for _, m := range messages { + switch m.Role { + case "user": + sb.WriteString("User: ") + case "assistant": + sb.WriteString("Agent: ") + } + sb.WriteString(m.Content) + sb.WriteString("\n\n") + } + + return sb.String() +} + +func Check(ctx context.Context, cli *ollama.Client, role, model string, messages []ollama.Message) (*Response, error) { + req := &ollama.GenerateRequest{ + Model: model, + System: &role, + Prompt: formatMessages(messages), + KeepAlive: "60m", + } + + resp, err := cli.Generate(ctx, req) + if err != nil { + return nil, fmt.Errorf("llamaguard: failed to generate response: %w", err) + } + + if resp == nil { + return nil, fmt.Errorf("llamaguard: response was nil") + } + + var result Response + + resp.Response = strings.TrimSpace(resp.Response) + if resp.Response == "safe" { + result.IsSafe = true + return &result, nil + } + + result.IsSafe = false + + reasons := strings.SplitN(resp.Response, "\n", 2) + if len(reasons) != 2 { + return nil, fmt.Errorf("llamaguard: response was not in the expected format") + } + + for _, r := range strings.Split(reasons[1], ",") { + result.ViolationCategories = append(result.ViolationCategories, ParseCategory(r)) + } + + return &result, nil +} diff --git a/web/ollama/ollama.go b/web/ollama/ollama.go index 3f9870e..5825c2c 100644 --- a/web/ollama/ollama.go +++ b/web/ollama/ollama.go @@ -233,3 +233,64 @@ func (c *Client) Embeddings(ctx context.Context, er *EmbedRequest) (*EmbedRespon return &result, nil } + +type GenerateRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Images [][]byte `json:"images,omitempty"` + Options map[string]any `json:"options"` + + Context []int `json:"context,omitempty"` + Format *string `json:"format,omitempty"` + Template *string `json:"template,omitempty"` + System *string `json:"system,omitempty"` + Stream bool `json:"stream"` + Raw bool `json:"raw"` + KeepAlive string `json:"keep_alive"` +} + +type GenerateResponse struct { + Model string `json:"model"` + CreatedAt time.Time `json:"created_at"` + Response string `json:"response"` + Done bool `json:"done"` + Context []int `json:"context"` + TotalDuration int64 `json:"total_duration"` + LoadDuration int64 `json:"load_duration"` + PromptEvalCount int `json:"prompt_eval_count"` + PromptEvalDuration int64 `json:"prompt_eval_duration"` + EvalCount int `json:"eval_count"` + EvalDuration int64 `json:"eval_duration"` +} + +func (c *Client) Generate(ctx context.Context, gr *GenerateRequest) (*GenerateResponse, error) { + buf := &bytes.Buffer{} + if err := json.NewEncoder(buf).Encode(gr); err != nil { + return nil, fmt.Errorf("ollama: error encoding request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/generate", buf) + if err != nil { + return nil, fmt.Errorf("ollama: error creating request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("ollama: error making request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, web.NewError(http.StatusOK, resp) + } + + var result GenerateResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("ollama: error decoding response: %w", err) + } + + return &result, nil +} |
