aboutsummaryrefslogtreecommitdiff
path: root/cmd
diff options
context:
space:
mode:
authorXe Iaso <me@xeiaso.net>2024-08-01 19:25:46 -0400
committerXe Iaso <me@xeiaso.net>2024-08-01 19:25:46 -0400
commitbdcd9eb26211bb5b10b1c41a2ffc609933d46033 (patch)
tree3a69390ea9c377518e8e4ab44893918c4bebd969 /cmd
parent7041695d49386f6e2cf660179c55f7f8a93b5ed3 (diff)
downloadx-bdcd9eb26211bb5b10b1c41a2ffc609933d46033.tar.xz
x-bdcd9eb26211bb5b10b1c41a2ffc609933d46033.zip
cmd/mimi: try having mimi run Python
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd')
-rw-r--r--cmd/mimi/modules/discord/jufra/jufra.go32
-rw-r--r--cmd/mimi/modules/discord/jufra/tools.go76
2 files changed, 108 insertions, 0 deletions
diff --git a/cmd/mimi/modules/discord/jufra/jufra.go b/cmd/mimi/modules/discord/jufra/jufra.go
index 43caeb3..6772a07 100644
--- a/cmd/mimi/modules/discord/jufra/jufra.go
+++ b/cmd/mimi/modules/discord/jufra/jufra.go
@@ -225,6 +225,7 @@ func (m *Module) messageCreate(s *discordgo.Session, mc *discordgo.MessageCreate
Options: map[string]any{
"num_ctx": 131072,
},
+ Tools: m.getTools(),
}
resp, err := m.ollama.Chat(context.Background(), cr)
@@ -236,6 +237,37 @@ func (m *Module) messageCreate(s *discordgo.Session, mc *discordgo.MessageCreate
conv = append(conv, resp.Message)
+ if len(resp.Message.ToolCalls) != 0 {
+ for _, tc := range resp.Message.ToolCalls {
+ if tc.Name == "run_python_code" {
+ msg, err := m.runPythonCode(context.Background(), tc)
+ if err != nil {
+ slog.Error("error running python code", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
+ s.ChannelMessageSend(mc.ChannelID, "error running python code")
+ return
+ }
+
+ conv = append(conv, *msg)
+
+ resp, err = m.ollama.Chat(context.Background(), &ollama.CompleteRequest{
+ Model: *mimiModel,
+ Messages: conv,
+ Options: map[string]any{
+ "num_ctx": 131072,
+ },
+ Tools: m.getTools(),
+ })
+ if err != nil {
+ slog.Error("error chatting", "err", err, "message_id", mc.ID, "channel_id", mc.ChannelID)
+ s.ChannelMessageSend(mc.ChannelID, "error chatting")
+ return
+ }
+
+ conv = append(conv, resp.Message)
+ }
+ }
+ }
+
if !*disableLlamaguard {
lgResp, err := m.llamaGuardCheck(context.Background(), "assistant", conv)
if err != nil {
diff --git a/cmd/mimi/modules/discord/jufra/tools.go b/cmd/mimi/modules/discord/jufra/tools.go
index 9fdcaca..4c82c68 100644
--- a/cmd/mimi/modules/discord/jufra/tools.go
+++ b/cmd/mimi/modules/discord/jufra/tools.go
@@ -1 +1,77 @@
package jufra
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "os"
+
+ "within.website/x/llm/codeinterpreter/python"
+ "within.website/x/web/ollama"
+)
+
+var normalTools = []ollama.Function{
+ {
+ Name: "run_python_code",
+ Description: "Run the given Python code in a sandboxed environment",
+ Parameters: ollama.Param{
+ Type: "object",
+ Properties: ollama.Properties{
+ "code": {
+ Type: "string",
+ Description: "The Python code to run",
+ },
+ },
+ Required: []string{"code"},
+ },
+ },
+}
+
+type pythonCodeArgs struct {
+ Code string `json:"code"`
+}
+
+func (pca *pythonCodeArgs) Valid() error {
+ if pca.Code == "" {
+ return errors.New("missing code parameter")
+ }
+
+ return nil
+}
+
+func (m *Module) runPythonCode(ctx context.Context, tc ollama.ToolCall) (*ollama.Message, error) {
+ var args pythonCodeArgs
+ if err := json.Unmarshal(tc.Arguments, &args); err != nil {
+ return nil, err
+ }
+
+ tmpdir, err := os.MkdirTemp("", "mimi-python-*")
+ if err != nil {
+ return nil, err
+ }
+
+ defer os.RemoveAll(tmpdir)
+
+ res, err := python.Run(ctx, tmpdir, args.Code)
+ if err != nil {
+ return nil, nil
+ }
+
+ return &ollama.Message{
+ Role: "tool",
+ Content: jsonString(res),
+ }, nil
+}
+
+func (m *Module) getTools() []ollama.Tool {
+ var result []ollama.Tool
+
+ for _, tool := range normalTools {
+ result = append(result, ollama.Tool{
+ Type: "function",
+ Function: tool,
+ })
+ }
+
+ return result
+}