| author | Xe Iaso <me@xeiaso.net> | 2023-12-09 11:13:00 -0500 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2023-12-09 11:13:00 -0500 |
| commit | 030c5b365243f84a4a30f9a7030d588fa02ac954 (patch) | |
| tree | 594f6f2354799de55d05f0a6e9e11d89a40f5919 | |
| parent | 1403d6130d772b05fdd8807f9aa09474e01051eb (diff) | |
cmd/mimi: add llava support
Signed-off-by: Xe Iaso <me@xeiaso.net>
| -rw-r--r-- | cmd/mimi/main.go | 40 |
| -rw-r--r-- | llm/llava/llava.go | 162 |

2 files changed, 202 insertions(+), 0 deletions(-)
diff --git a/cmd/mimi/main.go b/cmd/mimi/main.go
index 2b3d640..6d9fb62 100644
--- a/cmd/mimi/main.go
+++ b/cmd/mimi/main.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"log"
 	"log/slog"
+	"net/http"
 	"os"
 	"os/signal"
 	"strings"
@@ -18,6 +19,7 @@ import (
 	"within.website/x/internal"
 	"within.website/x/llm"
 	"within.website/x/llm/llamaguard"
+	"within.website/x/llm/llava"
 )
 
 var (
@@ -26,6 +28,7 @@ var (
 	discordGuild   = flag.String("discord-guild", "192289762302754817", "discord guild")
 	discordChannel = flag.String("discord-channel", "217096701771513856", "discord channel")
 	llamaguardHost = flag.String("llamaguard-host", "http://ontos:11434", "llamaguard host")
+	llavaHost      = flag.String("llava-host", "http://localhost:8080", "llava host")
 	ollamaModel    = flag.String("ollama-model", "xe/mimi:f16", "ollama model tag")
 	ollamaHost     = flag.String("ollama-host", "http://kaine:11434", "ollama host")
 	openAIKey      = flag.String("openai-api-key", "", "openai key")
@@ -119,6 +122,43 @@ func main() {
 		}
 	}
 
+	if len(m.Attachments) > 0 {
+		for i, a := range m.Attachments {
+			switch a.ContentType {
+			case "image/png", "image/jpeg", "image/gif":
+			default:
+				continue
+			}
+
+			resp, err := http.Get(a.URL)
+			if err != nil {
+				slog.Error("http get error", "error", err)
+				continue
+			}
+			defer resp.Body.Close()
+
+			lrq, err := llava.DefaultRequest(m.Content, resp.Body)
+			if err != nil {
+				slog.Error("llava error", "error", err)
+				continue
+			}
+
+			lresp, err := llava.Describe(context.Background(), *llavaHost+"/completion", lrq)
+			if err != nil {
+				slog.Error("llava error", "error", err)
+				continue
+			}
+
+			if err := json.NewEncoder(&prompt).Encode(map[string]any{
+				"image": i,
+				"desc":  lresp.Content,
+			}); err != nil {
+				slog.Error("json encode error", "error", err)
+				continue
+			}
+		}
+	}
+
 	lock.Lock()
 	defer lock.Unlock()
diff --git a/llm/llava/llava.go b/llm/llava/llava.go
new file mode 100644
index 0000000..4b2b7b5
--- /dev/null
+++ b/llm/llava/llava.go
@@ -0,0 +1,162 @@
+package llava
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"strconv"
+	"sync"
+
+	"within.website/x/web"
+)
+
+type Image struct {
+	Data []byte `json:"data"`
+	ID   int    `json:"id"`
+}
+
+type Request struct {
+	Stream           bool     `json:"stream"`
+	NPredict         int      `json:"n_predict"`
+	Temperature      float64  `json:"temperature"`
+	Stop             []string `json:"stop"`
+	RepeatLastN      int      `json:"repeat_last_n"`
+	RepeatPenalty    float64  `json:"repeat_penalty"`
+	TopK             int      `json:"top_k"`
+	TopP             float64  `json:"top_p"`
+	TfsZ             int      `json:"tfs_z"`
+	TypicalP         int      `json:"typical_p"`
+	PresencePenalty  int      `json:"presence_penalty"`
+	FrequencyPenalty int      `json:"frequency_penalty"`
+	Mirostat         int      `json:"mirostat"`
+	MirostatTau      int      `json:"mirostat_tau"`
+	MirostatEta      float64  `json:"mirostat_eta"`
+	Grammar          string   `json:"grammar"`
+	NProbs           int      `json:"n_probs"`
+	ImageData        []Image  `json:"image_data"`
+	CachePrompt      bool     `json:"cache_prompt"`
+	SlotID           int      `json:"slot_id"`
+	Prompt           string   `json:"prompt"`
+}
+
+var imageID = 10
+var imageLock = sync.Mutex{}
+
+func DefaultRequest(prompt string, image io.Reader) (*Request, error) {
+	imageLock.Lock()
+	defer imageLock.Unlock()
+
+	imageID++
+
+	imageData, err := io.ReadAll(image)
+	if err != nil {
+		return nil, err
+	}
+
+	return &Request{
+		Stream:           false,
+		NPredict:         400,
+		Temperature:      0.7,
+		Stop:             []string{"</s>", "Mimi:", "User:"},
+		RepeatLastN:      256,
+		RepeatPenalty:    1.18,
+		TopK:             40,
+		TopP:             0.5,
+		TfsZ:             1,
+		TypicalP:         1,
+		PresencePenalty:  0,
+		FrequencyPenalty: 0,
+		Mirostat:         0,
+		MirostatTau:      5,
+		MirostatEta:      0.1,
+		Grammar:          "",
+		NProbs:           0,
+		ImageData: []Image{
+			{
+				Data: imageData,
+				ID:   imageID,
+			},
+		},
+		CachePrompt: true,
+		SlotID:      -1,
+		Prompt:      formatPrompt(prompt, imageID),
+	}, nil
+}
+
+func Describe(ctx context.Context, server string, req *Request) (*Response, error) {
+	var buf bytes.Buffer
+
+	if err := json.NewEncoder(&buf).Encode(req); err != nil {
+		return nil, err
+	}
+
+	r, err := http.NewRequestWithContext(ctx, http.MethodPost, server, &buf)
+	if err != nil {
+		return nil, err
+	}
+
+	r.Header.Set("Content-Type", "application/json")
+	r.Header.Set("Accept", "application/json")
+	r.Header.Set("User-Agent", "within.website/x/llm/llava")
+
+	resp, err := http.DefaultClient.Do(r)
+	if err != nil {
+		return nil, fmt.Errorf("llava: http request error: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, web.NewError(http.StatusOK, resp)
+	}
+
+	var llr Response
+	if err := json.NewDecoder(resp.Body).Decode(&llr); err != nil {
+		return nil, fmt.Errorf("llava: json decode error: %w", err)
+	}
+
+	return &llr, nil
+}
+
+func formatPrompt(prompt string, imageID int) string {
+	const basePrompt = `A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
+USER:[img-${imageID}]${prompt}
+ASSISTANT:`
+	return os.Expand(basePrompt, func(key string) string {
+		switch key {
+		case "prompt":
+			return prompt
+		case "imageID":
+			return strconv.Itoa(imageID)
+		default:
+			return ""
+		}
+	})
+}
+
+type Response struct {
+	Content         string  `json:"content"`
+	Model           string  `json:"model"`
+	Prompt          string  `json:"prompt"`
+	SlotID          int     `json:"slot_id"`
+	Stop            bool    `json:"stop"`
+	Timings         Timings `json:"timings"`
+	TokensCached    int     `json:"tokens_cached"`
+	TokensEvaluated int     `json:"tokens_evaluated"`
+	TokensPredicted int     `json:"tokens_predicted"`
+	Truncated       bool    `json:"truncated"`
+}
+
+type Timings struct {
+	PredictedMs         float64 `json:"predicted_ms"`
+	PredictedN          int     `json:"predicted_n"`
+	PredictedPerSecond  float64 `json:"predicted_per_second"`
+	PredictedPerTokenMs float64 `json:"predicted_per_token_ms"`
+	PromptMs            float64 `json:"prompt_ms"`
+	PromptN             int     `json:"prompt_n"`
+	PromptPerSecond     float64 `json:"prompt_per_second"`
+	PromptPerTokenMs    float64 `json:"prompt_per_token_ms"`
+}
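
As a usage sketch (not part of the commit): the new llm/llava package is a small client for the /completion endpoint of a llava-capable llama.cpp server. Assuming a server is listening on the default -llava-host address (http://localhost:8080), and with a placeholder image path and prompt, driving it outside the bot could look roughly like this:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"within.website/x/llm/llava"
)

func main() {
	// photo.png is a placeholder; the bot only forwards image/png, image/jpeg,
	// and image/gif attachments, and DefaultRequest just reads the raw bytes.
	fin, err := os.Open("photo.png")
	if err != nil {
		log.Fatal(err)
	}
	defer fin.Close()

	// DefaultRequest applies the sampling defaults from the commit
	// (n_predict 400, temperature 0.7, repeat_penalty 1.18, ...) and attaches
	// the image under a fresh image ID that the prompt references.
	req, err := llava.DefaultRequest("What is in this picture?", fin)
	if err != nil {
		log.Fatal(err)
	}

	// Describe POSTs the request as JSON to the server and decodes the reply.
	resp, err := llava.Describe(context.Background(), "http://localhost:8080/completion", req)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(resp.Content)
}
```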

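For reference on the prompt format: formatPrompt runs the template through os.Expand, substituting ${imageID} and ${prompt}. Since the package-level counter starts at 10 and DefaultRequest increments it before use, the first request in a fresh process would carry image ID 11 and a prompt along the lines of:

```
A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
USER:[img-11]What is in this picture?
ASSISTANT:
```

The [img-11] tag is how the llama.cpp server associates the prompt with the matching ID in image_data; slot_id -1 lets the server pick any free slot, and cache_prompt reuses the cached prompt prefix across requests.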