aboutsummaryrefslogtreecommitdiff
path: root/cmd/mimi/modules/discord
diff options
context:
space:
mode:
authorXe Iaso <me@xeiaso.net>2024-10-05 13:59:43 -0400
committerXe Iaso <me@xeiaso.net>2024-10-05 13:59:43 -0400
commitfda09b55e316b7f162371cf52bc401fa742913e8 (patch)
tree5bc242b7a47d83d9cfdbfe301d71850b0a64cffd /cmd/mimi/modules/discord
parentaccf38874112579f24bef36119a64f8ed8972cdb (diff)
downloadx-fda09b55e316b7f162371cf52bc401fa742913e8.tar.xz
x-fda09b55e316b7f162371cf52bc401fa742913e8.zip
cmd/mimi: use falin for generating images
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd/mimi/modules/discord')
-rw-r--r--cmd/mimi/modules/discord/jufra/jufra.go11
-rw-r--r--cmd/mimi/modules/discord/jufra/tools.go28
2 files changed, 17 insertions, 22 deletions
diff --git a/cmd/mimi/modules/discord/jufra/jufra.go b/cmd/mimi/modules/discord/jufra/jufra.go
index 78dbd64..6349207 100644
--- a/cmd/mimi/modules/discord/jufra/jufra.go
+++ b/cmd/mimi/modules/discord/jufra/jufra.go
@@ -10,13 +10,15 @@ import (
"encoding/json"
"flag"
"log/slog"
+ "net/http"
"strings"
"sync"
"time"
+ "connectrpc.com/connect"
"github.com/bwmarrin/discordgo"
"within.website/x/cmd/mimi/internal"
- "within.website/x/web/flux"
+ falinconnect "within.website/x/migroserbices/falin/gen/genconnect"
"within.website/x/web/ollama"
"within.website/x/web/ollama/llamaguard"
"within.website/x/web/openai/chatgpt"
@@ -44,7 +46,8 @@ var (
mimiModel = flag.String("jufra-mimi-model", "hermes3", "ollama model tag for mimi")
mimiNames = flag.String("jufra-mimi-names", "mimi", "comma-separated list of names for mimi")
disableLlamaguard = flag.Bool("jufra-unsafe-disable-llamaguard", true, "disable llamaguard")
- fluxHost = flag.String("jufra-flux-host", "http://xe-flux.flycast", "host for flux")
+ falinHost = flag.String("jufra-falin-host", "http://localhost:8080", "host for falin")
+ falinModel = flag.String("jufra-falin-model", "fal-ai/flux-pro/v1.1", "model to use for Falin generations")
contextWindow = flag.Int("jufra-context-window", 32768, "context window size for mimi")
//go:embed system-prompt.txt
@@ -56,7 +59,7 @@ type Module struct {
cli chatgpt.Client
ollama *ollama.Client
lg *ollama.Client
- flux *flux.Client
+ falin falinconnect.ImageServiceClient
convHistory map[string]state
lock sync.Mutex
@@ -73,7 +76,7 @@ func New(sess *discordgo.Session) *Module {
cli: chatgpt.NewClient("").WithBaseURL(internal.OllamaHost()),
ollama: internal.OllamaClient(),
lg: ollama.NewClient(*llamaGuardHost),
- flux: flux.NewClient(*fluxHost),
+ falin: falinconnect.NewImageServiceClient(http.DefaultClient, *falinHost, connect.WithProtoJSON()),
convHistory: make(map[string]state),
}
diff --git a/cmd/mimi/modules/discord/jufra/tools.go b/cmd/mimi/modules/discord/jufra/tools.go
index 8b7d674..46062fe 100644
--- a/cmd/mimi/modules/discord/jufra/tools.go
+++ b/cmd/mimi/modules/discord/jufra/tools.go
@@ -12,9 +12,9 @@ import (
"path/filepath"
"time"
- "github.com/google/uuid"
+ "connectrpc.com/connect"
"within.website/x/llm/codeinterpreter/python"
- "within.website/x/web/flux"
+ falin "within.website/x/migroserbices/falin/gen"
"within.website/x/web/ollama"
)
@@ -163,30 +163,22 @@ func (m *Module) eventuallySendImage(ctx context.Context, channelID string, prom
}
defer os.RemoveAll(tempDir)
- pr, err := m.flux.PredictIdempotent(uuid.NewString(), flux.PredictionRequest{
- Input: flux.Input{
- Prompt: "an anime depiction of " + prompt,
- AspectRatio: "16:9",
- NumInferenceSteps: 50,
- GuidanceScale: 3.5,
- OutputFormat: "webp",
- NumOutputs: 1,
- MaxSequenceLength: 512,
- OutputQuality: 95,
- Seed: &[]int{420}[0],
- },
- })
+ ir, err := m.falin.GenerateImage(ctx, connect.NewRequest(&falin.GenerateImageRequest{
+ Prompt: "an anime depiction of " + prompt,
+ Model: *falinModel,
+ NumImages: 1,
+ }))
if err != nil {
return fmt.Errorf("failed to predict: %w", err)
}
- resp, err := http.Get(pr.Output[0])
+ resp, err := http.Get(ir.Msg.Images[0].Url)
if err != nil {
return fmt.Errorf("failed to get image: %w", err)
}
defer resp.Body.Close()
- imgPath := filepath.Join(tempDir, "image.webp")
+ imgPath := filepath.Join(tempDir, "image.jpg")
imgFile, err := os.Create(imgPath)
if err != nil {
return fmt.Errorf("failed to create image file: %w", err)
@@ -200,7 +192,7 @@ func (m *Module) eventuallySendImage(ctx context.Context, channelID string, prom
return fmt.Errorf("failed to seek image file: %w", err)
}
- msg, err := m.sess.ChannelFileSendWithMessage(channelID, "Here's the image!\n\n```"+prompt+"\n```", "image.webp", imgFile)
+ msg, err := m.sess.ChannelFileSendWithMessage(channelID, "Here's the image!\n\n```"+prompt+"\n```", "image.jpg", imgFile)
if err != nil {
return fmt.Errorf("failed to send image: %w", err)
}