diff options
| author | Xe Iaso <me@xeiaso.net> | 2024-10-05 13:59:43 -0400 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2024-10-05 13:59:43 -0400 |
| commit | fda09b55e316b7f162371cf52bc401fa742913e8 (patch) | |
| tree | 5bc242b7a47d83d9cfdbfe301d71850b0a64cffd /cmd/mimi/modules/discord | |
| parent | accf38874112579f24bef36119a64f8ed8972cdb (diff) | |
| download | x-fda09b55e316b7f162371cf52bc401fa742913e8.tar.xz x-fda09b55e316b7f162371cf52bc401fa742913e8.zip | |
cmd/mimi: use falin for generating images
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd/mimi/modules/discord')
| -rw-r--r-- | cmd/mimi/modules/discord/jufra/jufra.go | 11 | ||||
| -rw-r--r-- | cmd/mimi/modules/discord/jufra/tools.go | 28 |
2 files changed, 17 insertions, 22 deletions
diff --git a/cmd/mimi/modules/discord/jufra/jufra.go b/cmd/mimi/modules/discord/jufra/jufra.go index 78dbd64..6349207 100644 --- a/cmd/mimi/modules/discord/jufra/jufra.go +++ b/cmd/mimi/modules/discord/jufra/jufra.go @@ -10,13 +10,15 @@ import ( "encoding/json" "flag" "log/slog" + "net/http" "strings" "sync" "time" + "connectrpc.com/connect" "github.com/bwmarrin/discordgo" "within.website/x/cmd/mimi/internal" - "within.website/x/web/flux" + falinconnect "within.website/x/migroserbices/falin/gen/genconnect" "within.website/x/web/ollama" "within.website/x/web/ollama/llamaguard" "within.website/x/web/openai/chatgpt" @@ -44,7 +46,8 @@ var ( mimiModel = flag.String("jufra-mimi-model", "hermes3", "ollama model tag for mimi") mimiNames = flag.String("jufra-mimi-names", "mimi", "comma-separated list of names for mimi") disableLlamaguard = flag.Bool("jufra-unsafe-disable-llamaguard", true, "disable llamaguard") - fluxHost = flag.String("jufra-flux-host", "http://xe-flux.flycast", "host for flux") + falinHost = flag.String("jufra-falin-host", "http://localhost:8080", "host for falin") + falinModel = flag.String("jufra-falin-model", "fal-ai/flux-pro/v1.1", "model to use for Falin generations") contextWindow = flag.Int("jufra-context-window", 32768, "context window size for mimi") //go:embed system-prompt.txt @@ -56,7 +59,7 @@ type Module struct { cli chatgpt.Client ollama *ollama.Client lg *ollama.Client - flux *flux.Client + falin falinconnect.ImageServiceClient convHistory map[string]state lock sync.Mutex @@ -73,7 +76,7 @@ func New(sess *discordgo.Session) *Module { cli: chatgpt.NewClient("").WithBaseURL(internal.OllamaHost()), ollama: internal.OllamaClient(), lg: ollama.NewClient(*llamaGuardHost), - flux: flux.NewClient(*fluxHost), + falin: falinconnect.NewImageServiceClient(http.DefaultClient, *falinHost, connect.WithProtoJSON()), convHistory: make(map[string]state), } diff --git a/cmd/mimi/modules/discord/jufra/tools.go b/cmd/mimi/modules/discord/jufra/tools.go index 8b7d674..46062fe 100644 --- a/cmd/mimi/modules/discord/jufra/tools.go +++ b/cmd/mimi/modules/discord/jufra/tools.go @@ -12,9 +12,9 @@ import ( "path/filepath" "time" - "github.com/google/uuid" + "connectrpc.com/connect" "within.website/x/llm/codeinterpreter/python" - "within.website/x/web/flux" + falin "within.website/x/migroserbices/falin/gen" "within.website/x/web/ollama" ) @@ -163,30 +163,22 @@ func (m *Module) eventuallySendImage(ctx context.Context, channelID string, prom } defer os.RemoveAll(tempDir) - pr, err := m.flux.PredictIdempotent(uuid.NewString(), flux.PredictionRequest{ - Input: flux.Input{ - Prompt: "an anime depiction of " + prompt, - AspectRatio: "16:9", - NumInferenceSteps: 50, - GuidanceScale: 3.5, - OutputFormat: "webp", - NumOutputs: 1, - MaxSequenceLength: 512, - OutputQuality: 95, - Seed: &[]int{420}[0], - }, - }) + ir, err := m.falin.GenerateImage(ctx, connect.NewRequest(&falin.GenerateImageRequest{ + Prompt: "an anime depiction of " + prompt, + Model: *falinModel, + NumImages: 1, + })) if err != nil { return fmt.Errorf("failed to predict: %w", err) } - resp, err := http.Get(pr.Output[0]) + resp, err := http.Get(ir.Msg.Images[0].Url) if err != nil { return fmt.Errorf("failed to get image: %w", err) } defer resp.Body.Close() - imgPath := filepath.Join(tempDir, "image.webp") + imgPath := filepath.Join(tempDir, "image.jpg") imgFile, err := os.Create(imgPath) if err != nil { return fmt.Errorf("failed to create image file: %w", err) @@ -200,7 +192,7 @@ func (m *Module) eventuallySendImage(ctx context.Context, channelID string, prom return fmt.Errorf("failed to seek image file: %w", err) } - msg, err := m.sess.ChannelFileSendWithMessage(channelID, "Here's the image!\n\n```"+prompt+"\n```", "image.webp", imgFile) + msg, err := m.sess.ChannelFileSendWithMessage(channelID, "Here's the image!\n\n```"+prompt+"\n```", "image.jpg", imgFile) if err != nil { return fmt.Errorf("failed to send image: %w", err) } |
