diff options
| author | Xe Iaso <me@xeiaso.net> | 2024-08-01 19:25:46 -0400 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2024-08-01 19:25:46 -0400 |
| commit | bdcd9eb26211bb5b10b1c41a2ffc609933d46033 (patch) | |
| tree | 3a69390ea9c377518e8e4ab44893918c4bebd969 /cmd | |
| parent | 7041695d49386f6e2cf660179c55f7f8a93b5ed3 (diff) | |
| download | x-bdcd9eb26211bb5b10b1c41a2ffc609933d46033.tar.xz x-bdcd9eb26211bb5b10b1c41a2ffc609933d46033.zip | |
cmd/mimi: try having mimi run Python
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/mimi/modules/discord/jufra/jufra.go | 32 | ||||
| -rw-r--r-- | cmd/mimi/modules/discord/jufra/tools.go | 76 |
2 files changed, 108 insertions, 0 deletions
diff --git a/cmd/mimi/modules/discord/jufra/jufra.go b/cmd/mimi/modules/discord/jufra/jufra.go index 43caeb3..6772a07 100644 --- a/cmd/mimi/modules/discord/jufra/jufra.go +++ b/cmd/mimi/modules/discord/jufra/jufra.go @@ -225,6 +225,7 @@ func (m *Module) messageCreate(s *discordgo.Session, mc *discordgo.MessageCreate Options: map[string]any{ "num_ctx": 131072, }, + Tools: m.getTools(), } resp, err := m.ollama.Chat(context.Background(), cr) @@ -236,6 +237,37 @@ func (m *Module) messageCreate(s *discordgo.Session, mc *discordgo.MessageCreate conv = append(conv, resp.Message) + if len(resp.Message.ToolCalls) != 0 { + for _, tc := range resp.Message.ToolCalls { + if tc.Name == "run_python_code" { + msg, err := m.runPythonCode(context.Background(), tc) + if err != nil { + slog.Error("error running python code", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID) + s.ChannelMessageSend(mc.ChannelID, "error running python code") + return + } + + conv = append(conv, *msg) + + resp, err = m.ollama.Chat(context.Background(), &ollama.CompleteRequest{ + Model: *mimiModel, + Messages: conv, + Options: map[string]any{ + "num_ctx": 131072, + }, + Tools: m.getTools(), + }) + if err != nil { + slog.Error("error chatting", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID) + s.ChannelMessageSend(mc.ChannelID, "error chatting") + return + } + + conv = append(conv, resp.Message) + } + } + } + if !*disableLlamaguard { lgResp, err := m.llamaGuardCheck(context.Background(), "assistant", conv) if err != nil { diff --git a/cmd/mimi/modules/discord/jufra/tools.go b/cmd/mimi/modules/discord/jufra/tools.go index 9fdcaca..4c82c68 100644 --- a/cmd/mimi/modules/discord/jufra/tools.go +++ b/cmd/mimi/modules/discord/jufra/tools.go @@ -1 +1,77 @@ package jufra + +import ( + "context" + "encoding/json" + "errors" + "os" + + "within.website/x/llm/codeinterpreter/python" + "within.website/x/web/ollama" +) + +var normalTools = []ollama.Function{ + { + Name: "run_python_code", + Description: "Run the given Python code in a sandboxed environment", + Parameters: ollama.Param{ + Type: "object", + Properties: ollama.Properties{ + "code": { + Type: "string", + Description: "The Python code to run", + }, + }, + Required: []string{"code"}, + }, + }, +} + +type pythonCodeArgs struct { + Code string `json:"code"` +} + +func (pca *pythonCodeArgs) Valid() error { + if pca.Code == "" { + return errors.New("missing code parameter") + } + + return nil +} + +func (m *Module) runPythonCode(ctx context.Context, tc ollama.ToolCall) (*ollama.Message, error) { + var args pythonCodeArgs + if err := json.Unmarshal(tc.Arguments, &args); err != nil { + return nil, err + } + + tmpdir, err := os.MkdirTemp("", "mimi-python-*") + if err != nil { + return nil, err + } + + defer os.RemoveAll(tmpdir) + + res, err := python.Run(ctx, tmpdir, args.Code) + if err != nil { + return nil, nil + } + + return &ollama.Message{ + Role: "tool", + Content: jsonString(res), + }, nil +} + +func (m *Module) getTools() []ollama.Tool { + var result []ollama.Tool + + for _, tool := range normalTools { + result = append(result, ollama.Tool{ + Type: "function", + Function: tool, + }) + } + + return result +} |
