aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXe Iaso <me@xeiaso.net>2024-07-27 08:56:01 -0400
committerXe Iaso <me@xeiaso.net>2024-07-27 08:56:01 -0400
commitd73a41507f8aab06b378fad1d30f4ccf827e65db (patch)
tree4753d52984e44a00cfda52da75e5b6e7f7e85169
parent7b9862f9410b2ca7d952d6665be02330f63c7174 (diff)
downloadx-d73a41507f8aab06b378fad1d30f4ccf827e65db.tar.xz
x-d73a41507f8aab06b378fad1d30f4ccf827e65db.zip
cmd/mimi: conversate
Signed-off-by: Xe Iaso <me@xeiaso.net>
-rw-r--r--cmd/mimi/Modelfile2
-rw-r--r--cmd/mimi/internal/flags.go6
-rw-r--r--cmd/mimi/main.go23
-rw-r--r--cmd/mimi/manifest/kustomization.yaml1
-rw-r--r--cmd/mimi/manifest/pvc.yaml11
-rw-r--r--cmd/mimi/modules/discord/jufra/jufra.go177
-rw-r--r--web/ollama/llamaguard/llamaguard.go130
-rw-r--r--web/ollama/ollama.go61
8 files changed, 402 insertions, 9 deletions
diff --git a/cmd/mimi/Modelfile b/cmd/mimi/Modelfile
new file mode 100644
index 0000000..6d21662
--- /dev/null
+++ b/cmd/mimi/Modelfile
@@ -0,0 +1,2 @@
+FROM llama3.1
+SYSTEM """Roleplay as a catgirl who's good at programming and a hacker. Your name is Mimi and you work for Techaro. Don't mention either your name, employer, or species unless you are asked directly."""
diff --git a/cmd/mimi/internal/flags.go b/cmd/mimi/internal/flags.go
index 1f26fc4..747ea14 100644
--- a/cmd/mimi/internal/flags.go
+++ b/cmd/mimi/internal/flags.go
@@ -13,7 +13,7 @@ import (
var (
dataDir = flag.String("data-dir", "./var", "data directory for the bot")
- ollamaModel = flag.String("ollama-model", "llama3", "ollama model tag")
+ ollamaModel = flag.String("ollama-model", "llama3.1", "ollama model tag")
ollamaHost = flag.String("ollama-host", "http://xe-inference.flycast:80", "ollama host")
)
@@ -22,6 +22,10 @@ func DataDir() string {
return *dataDir
}
+func OllamaHost() string {
+ return *ollamaHost
+}
+
func OllamaClient() *ollama.Client {
return ollama.NewClient(*ollamaHost)
}
diff --git a/cmd/mimi/main.go b/cmd/mimi/main.go
index 941d901..a5c8442 100644
--- a/cmd/mimi/main.go
+++ b/cmd/mimi/main.go
@@ -13,12 +13,14 @@ import (
"within.website/x/cmd/mimi/modules/discord"
"within.website/x/cmd/mimi/modules/discord/flyio"
"within.website/x/cmd/mimi/modules/discord/heic2jpeg"
+ "within.website/x/cmd/mimi/modules/discord/jufra"
"within.website/x/cmd/mimi/modules/irc"
)
var (
- grpcAddr = flag.String("grpc-addr", ":9001", "GRPC listen address")
- httpAddr = flag.String("http-addr", ":9002", "HTTP listen address")
+ grpcAddr = flag.String("grpc-addr", ":9001", "GRPC listen address")
+ httpAddr = flag.String("http-addr", ":9002", "HTTP listen address")
+ ircEnabled = flag.Bool("irc-enabled", true, "enable IRC module")
)
func main() {
@@ -37,16 +39,14 @@ func main() {
b := flyio.New()
+ juf := jufra.New(d.Session())
+ _ = juf
+
d.Register(b)
d.Register(heic2jpeg.New())
d.Open()
- ircBot, err := irc.New(ctx, d.Session())
- if err != nil {
- log.Fatalf("error creating irc module: %v", err)
- }
-
slog.Info("bot started", "grpcAddr", *grpcAddr, "httpAddr", *httpAddr)
gs := grpc.NewServer()
@@ -58,7 +58,14 @@ func main() {
})
b.RegisterHTTP(mux)
- ircBot.RegisterHTTP(mux)
+
+ if *ircEnabled {
+ ircBot, err := irc.New(ctx, d.Session())
+ if err != nil {
+ log.Fatalf("error creating irc module: %v", err)
+ }
+ ircBot.RegisterHTTP(mux)
+ }
go func() {
log.Fatal(gs.Serve(lis))
diff --git a/cmd/mimi/manifest/kustomization.yaml b/cmd/mimi/manifest/kustomization.yaml
index cf39140..0fbb155 100644
--- a/cmd/mimi/manifest/kustomization.yaml
+++ b/cmd/mimi/manifest/kustomization.yaml
@@ -6,6 +6,7 @@ resources:
- deployment.yaml
- service.yaml
- ingress.yaml
+ - pvc.yaml
namespace: mimi
commonLabels:
app.kubernetes.io/name: mimi
diff --git a/cmd/mimi/manifest/pvc.yaml b/cmd/mimi/manifest/pvc.yaml
new file mode 100644
index 0000000..f5c85dd
--- /dev/null
+++ b/cmd/mimi/manifest/pvc.yaml
@@ -0,0 +1,11 @@
+apiVersion: v1
+kind: PersistentVolumeClaim
+metadata:
+ name: mimi
+spec:
+ accessModes:
+ - ReadWriteMany
+ storageClassName: longhorn
+ resources:
+ requests:
+ storage: 2Gi
diff --git a/cmd/mimi/modules/discord/jufra/jufra.go b/cmd/mimi/modules/discord/jufra/jufra.go
new file mode 100644
index 0000000..ff03781
--- /dev/null
+++ b/cmd/mimi/modules/discord/jufra/jufra.go
@@ -0,0 +1,177 @@
+// Package jufra lets Mimi have conversations with users.
+//
+// "jufra" means "utterance" in Lojban.
+package jufra
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "log/slog"
+ "strings"
+ "sync"
+
+ "github.com/bwmarrin/discordgo"
+ "within.website/x/cmd/mimi/internal"
+ "within.website/x/web/ollama"
+ "within.website/x/web/ollama/llamaguard"
+ "within.website/x/web/openai/chatgpt"
+)
+
+var (
+ chatChannels = flag.String("jufra-chat-channels", "217096701771513856", "comma-separated list of channels to allow chat in")
+ llamaGuardModel = flag.String("jufra-llama-guard-model", "xe/llamaguard3", "ollama model tag for llama guard")
+ mimiModel = flag.String("jufra-mimi-model", "xe/mimi:llama3.1", "ollama model tag for mimi")
+)
+
+type Module struct {
+ sess *discordgo.Session
+ cli chatgpt.Client
+ ollama *ollama.Client
+
+ convHistory map[string][]ollama.Message
+ lock sync.Mutex
+}
+
+func New(sess *discordgo.Session) *Module {
+ result := &Module{
+ sess: sess,
+ cli: chatgpt.NewClient("").WithBaseURL(internal.OllamaHost()),
+ ollama: internal.OllamaClient(),
+ convHistory: make(map[string][]ollama.Message),
+ }
+
+ sess.AddHandler(result.messageCreate)
+
+ if _, err := sess.ApplicationCommandCreate("1251716018771066902", "", &discordgo.ApplicationCommand{
+ Name: "clearconv",
+ Type: discordgo.ChatApplicationCommand,
+ Description: "Clear the conversation history for the current channel",
+ DefaultMemberPermissions: &[]int64{discordgo.PermissionSendMessages}[0],
+ }); err != nil {
+ slog.Error("error creating clearconv command", "err", err)
+ }
+
+ sess.AddHandler(result.clearConv)
+
+ return result
+}
+
+func (m *Module) clearConv(s *discordgo.Session, i *discordgo.InteractionCreate) {
+ if i.ApplicationCommandData().Name != "clearconv" {
+ return
+ }
+
+ m.lock.Lock()
+ defer m.lock.Unlock()
+
+ delete(m.convHistory, i.ChannelID)
+
+ s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
+ Type: discordgo.InteractionResponseChannelMessageWithSource,
+ Data: &discordgo.InteractionResponseData{
+ Content: "conversation history cleared",
+ },
+ })
+}
+
+func (m *Module) messageCreate(s *discordgo.Session, mc *discordgo.MessageCreate) {
+ if mc.Author.Bot {
+ return
+ }
+
+ if !strings.Contains(*chatChannels, mc.ChannelID) {
+ return
+ }
+
+ if mc.Content == "" {
+ return
+ }
+
+ m.lock.Lock()
+ defer m.lock.Unlock()
+
+ conv := m.convHistory[mc.Author.ID]
+
+ conv = append(conv, ollama.Message{
+ Role: "user",
+ Content: mc.Content,
+ })
+
+ lgResp, err := m.llamaGuardCheck(context.Background(), "user", conv)
+ if err != nil {
+ slog.Error("error checking message", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
+ s.ChannelMessageSend(mc.ChannelID, "error checking message")
+ return
+ }
+
+ if !lgResp.IsSafe {
+ msg, err := m.llamaGuardComplain(context.Background(), lgResp)
+ if err != nil {
+ slog.Error("error generating response", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
+ s.ChannelMessageSend(mc.ChannelID, "error generating response")
+ return
+ }
+
+ s.ChannelMessageSend(mc.ChannelID, msg)
+ return
+ }
+
+ cr := &ollama.CompleteRequest{
+ Model: *mimiModel,
+ Messages: []ollama.Message{
+ {
+ Role: "user",
+ Content: fmt.Sprintf("%s: %s", mc.Author.Username, mc.Content),
+ },
+ },
+ }
+
+ resp, err := m.ollama.Chat(context.Background(), cr)
+ if err != nil {
+ slog.Error("error chatting", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
+ s.ChannelMessageSend(mc.ChannelID, "error chatting")
+ return
+ }
+
+ conv = append(conv, resp.Message)
+
+ lgResp, err = m.llamaGuardCheck(context.Background(), "mimi", conv)
+ if err != nil {
+ slog.Error("error checking message", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
+ s.ChannelMessageSend(mc.ChannelID, "error checking message")
+ return
+ }
+
+ if !lgResp.IsSafe {
+ msg, err := m.llamaGuardComplain(context.Background(), lgResp)
+ if err != nil {
+ slog.Error("error generating response", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
+ s.ChannelMessageSend(mc.ChannelID, "error generating response")
+ return
+ }
+
+ s.ChannelMessageSend(mc.ChannelID, msg)
+ return
+ }
+
+ s.ChannelMessageSend(mc.ChannelID, resp.Message.Content)
+
+ m.convHistory[mc.Author.ID] = conv
+}
+
+func (m *Module) llamaGuardCheck(ctx context.Context, role string, messages []ollama.Message) (*llamaguard.Response, error) {
+ return llamaguard.Check(ctx, m.ollama, role, *llamaGuardModel, messages)
+}
+
+func (m *Module) llamaGuardComplain(ctx context.Context, lgResp *llamaguard.Response) (string, error) {
+ var sb strings.Builder
+ sb.WriteString("⚠️ Rule violation detected ⚠️\n")
+ for _, cat := range lgResp.ViolationCategories {
+ sb.WriteString("- ")
+ sb.WriteString(cat.String())
+ sb.WriteString("\n")
+ }
+
+ return sb.String(), nil
+}
diff --git a/web/ollama/llamaguard/llamaguard.go b/web/ollama/llamaguard/llamaguard.go
new file mode 100644
index 0000000..77b41a0
--- /dev/null
+++ b/web/ollama/llamaguard/llamaguard.go
@@ -0,0 +1,130 @@
+package llamaguard
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "within.website/x/web/ollama"
+)
+
+type Category string
+
+const (
+ S1 Category = "Violent Crimes"
+ S2 Category = "Non-Violent Crimes"
+ S3 Category = "Sex Crimes"
+ S4 Category = "Child Exploitation"
+ S5 Category = "Defamation"
+ S6 Category = "Specialized Advice"
+ S7 Category = "Privacy"
+ S8 Category = "Intellectual Property"
+ S9 Category = "Indiscriminate Weapons"
+ S10 Category = "Hate"
+ S11 Category = "Self-Harm"
+ S12 Category = "Sexual Content"
+ S13 Category = "Elections"
+ S14 Category = "Code Interpreter Abuse"
+
+ Unknown Category = "Unknown"
+)
+
+func (c Category) String() string {
+ return string(c)
+}
+
+func ParseCategory(s string) Category {
+ switch s {
+ case "S1":
+ return S1
+ case "S2":
+ return S2
+ case "S3":
+ return S3
+ case "S4":
+ return S4
+ case "S5":
+ return S5
+ case "S6":
+ return S6
+ case "S7":
+ return S7
+ case "S8":
+ return S8
+ case "S9":
+ return S9
+ case "S10":
+ return S10
+ case "S11":
+ return S11
+ case "S12":
+ return S12
+ case "S13":
+ return S13
+ case "S14":
+ return S14
+ default:
+ return ""
+ }
+}
+
+type Response struct {
+ IsSafe bool `json:"is_safe"`
+ ViolationCategories []Category `json:"violation_categories"`
+}
+
+func formatMessages(messages []ollama.Message) string {
+ var sb strings.Builder
+
+ for _, m := range messages {
+ switch m.Role {
+ case "user":
+ sb.WriteString("User: ")
+ case "assistant":
+ sb.WriteString("Agent: ")
+ }
+ sb.WriteString(m.Content)
+ sb.WriteString("\n\n")
+ }
+
+ return sb.String()
+}
+
+func Check(ctx context.Context, cli *ollama.Client, role, model string, messages []ollama.Message) (*Response, error) {
+ req := &ollama.GenerateRequest{
+ Model: model,
+ System: &role,
+ Prompt: formatMessages(messages),
+ KeepAlive: "60m",
+ }
+
+ resp, err := cli.Generate(ctx, req)
+ if err != nil {
+ return nil, fmt.Errorf("llamaguard: failed to generate response: %w", err)
+ }
+
+ if resp == nil {
+ return nil, fmt.Errorf("llamaguard: response was nil")
+ }
+
+ var result Response
+
+ resp.Response = strings.TrimSpace(resp.Response)
+ if resp.Response == "safe" {
+ result.IsSafe = true
+ return &result, nil
+ }
+
+ result.IsSafe = false
+
+ reasons := strings.SplitN(resp.Response, "\n", 2)
+ if len(reasons) != 2 {
+ return nil, fmt.Errorf("llamaguard: response was not in the expected format")
+ }
+
+ for _, r := range strings.Split(reasons[1], ",") {
+ result.ViolationCategories = append(result.ViolationCategories, ParseCategory(r))
+ }
+
+ return &result, nil
+}
diff --git a/web/ollama/ollama.go b/web/ollama/ollama.go
index 3f9870e..5825c2c 100644
--- a/web/ollama/ollama.go
+++ b/web/ollama/ollama.go
@@ -233,3 +233,64 @@ func (c *Client) Embeddings(ctx context.Context, er *EmbedRequest) (*EmbedRespon
return &result, nil
}
+
+type GenerateRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt"`
+ Images [][]byte `json:"images,omitempty"`
+ Options map[string]any `json:"options"`
+
+ Context []int `json:"context,omitempty"`
+ Format *string `json:"format,omitempty"`
+ Template *string `json:"template,omitempty"`
+ System *string `json:"system,omitempty"`
+ Stream bool `json:"stream"`
+ Raw bool `json:"raw"`
+ KeepAlive string `json:"keep_alive"`
+}
+
+type GenerateResponse struct {
+ Model string `json:"model"`
+ CreatedAt time.Time `json:"created_at"`
+ Response string `json:"response"`
+ Done bool `json:"done"`
+ Context []int `json:"context"`
+ TotalDuration int64 `json:"total_duration"`
+ LoadDuration int64 `json:"load_duration"`
+ PromptEvalCount int `json:"prompt_eval_count"`
+ PromptEvalDuration int64 `json:"prompt_eval_duration"`
+ EvalCount int `json:"eval_count"`
+ EvalDuration int64 `json:"eval_duration"`
+}
+
+func (c *Client) Generate(ctx context.Context, gr *GenerateRequest) (*GenerateResponse, error) {
+ buf := &bytes.Buffer{}
+ if err := json.NewEncoder(buf).Encode(gr); err != nil {
+ return nil, fmt.Errorf("ollama: error encoding request: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/generate", buf)
+ if err != nil {
+ return nil, fmt.Errorf("ollama: error creating request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("ollama: error making request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, web.NewError(http.StatusOK, resp)
+ }
+
+ var result GenerateResponse
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return nil, fmt.Errorf("ollama: error decoding response: %w", err)
+ }
+
+ return &result, nil
+}