aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXe Iaso <me@xeiaso.net>2023-12-09 11:13:00 -0500
committerXe Iaso <me@xeiaso.net>2023-12-09 11:13:00 -0500
commit030c5b365243f84a4a30f9a7030d588fa02ac954 (patch)
tree594f6f2354799de55d05f0a6e9e11d89a40f5919
parent1403d6130d772b05fdd8807f9aa09474e01051eb (diff)
downloadx-030c5b365243f84a4a30f9a7030d588fa02ac954.tar.xz
x-030c5b365243f84a4a30f9a7030d588fa02ac954.zip
cmd/mimi: add llava support
Signed-off-by: Xe Iaso <me@xeiaso.net>
-rw-r--r--cmd/mimi/main.go40
-rw-r--r--llm/llava/llava.go162
2 files changed, 202 insertions, 0 deletions
diff --git a/cmd/mimi/main.go b/cmd/mimi/main.go
index 2b3d640..6d9fb62 100644
--- a/cmd/mimi/main.go
+++ b/cmd/mimi/main.go
@@ -7,6 +7,7 @@ import (
"fmt"
"log"
"log/slog"
+ "net/http"
"os"
"os/signal"
"strings"
@@ -18,6 +19,7 @@ import (
"within.website/x/internal"
"within.website/x/llm"
"within.website/x/llm/llamaguard"
+ "within.website/x/llm/llava"
)
var (
@@ -26,6 +28,7 @@ var (
discordGuild = flag.String("discord-guild", "192289762302754817", "discord guild")
discordChannel = flag.String("discord-channel", "217096701771513856", "discord channel")
llamaguardHost = flag.String("llamaguard-host", "http://ontos:11434", "llamaguard host")
+ llavaHost = flag.String("llava-host", "http://localhost:8080", "llava host")
ollamaModel = flag.String("ollama-model", "xe/mimi:f16", "ollama model tag")
ollamaHost = flag.String("ollama-host", "http://kaine:11434", "ollama host")
openAIKey = flag.String("openai-api-key", "", "openai key")
@@ -119,6 +122,43 @@ func main() {
}
}
+ if len(m.Attachments) > 0 {
+ for i, a := range m.Attachments {
+ switch a.ContentType {
+ case "image/png", "image/jpeg", "image/gif":
+ default:
+ continue
+ }
+
+ resp, err := http.Get(a.URL)
+ if err != nil {
+ slog.Error("http get error", "error", err)
+ continue
+ }
+ defer resp.Body.Close()
+
+ lrq, err := llava.DefaultRequest(m.Content, resp.Body)
+ if err != nil {
+ slog.Error("llava error", "error", err)
+ continue
+ }
+
+ lresp, err := llava.Describe(context.Background(), *llavaHost+"/completion", lrq)
+ if err != nil {
+ slog.Error("llava error", "error", err)
+ continue
+ }
+
+ if err := json.NewEncoder(&prompt).Encode(map[string]any{
+ "image": i,
+ "desc": lresp.Content,
+ }); err != nil {
+ slog.Error("json encode error", "error", err)
+ continue
+ }
+ }
+ }
+
lock.Lock()
defer lock.Unlock()
diff --git a/llm/llava/llava.go b/llm/llava/llava.go
new file mode 100644
index 0000000..4b2b7b5
--- /dev/null
+++ b/llm/llava/llava.go
@@ -0,0 +1,162 @@
+package llava
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "strconv"
+ "sync"
+
+ "within.website/x/web"
+)
+
// Image is a single image attachment sent to the llava server. The ID is
// what the prompt references via an [img-ID] placeholder (see formatPrompt).
type Image struct {
	Data []byte `json:"data"` // raw image bytes; encoding/json emits []byte as base64
	ID   int    `json:"id"`   // matches the [img-${imageID}] marker in the prompt
}
+
// Request is the JSON request body sent to the llava llama.cpp server's
// /completion endpoint. The field names mirror the server's JSON
// parameters (sampling settings, stop sequences, image attachments);
// DefaultRequest fills in sensible defaults for all of them.
type Request struct {
	Stream           bool     `json:"stream"`
	NPredict         int      `json:"n_predict"`
	Temperature      float64  `json:"temperature"`
	Stop             []string `json:"stop"`
	RepeatLastN      int      `json:"repeat_last_n"`
	RepeatPenalty    float64  `json:"repeat_penalty"`
	TopK             int      `json:"top_k"`
	TopP             float64  `json:"top_p"`
	TfsZ             int      `json:"tfs_z"`
	TypicalP         int      `json:"typical_p"`
	PresencePenalty  int      `json:"presence_penalty"`
	FrequencyPenalty int      `json:"frequency_penalty"`
	Mirostat         int      `json:"mirostat"`
	MirostatTau      int      `json:"mirostat_tau"`
	MirostatEta      float64  `json:"mirostat_eta"`
	Grammar          string   `json:"grammar"`
	NProbs           int      `json:"n_probs"`
	// ImageData carries the images the prompt references through
	// [img-ID] placeholders (see Image and formatPrompt).
	ImageData   []Image `json:"image_data"`
	CachePrompt bool    `json:"cache_prompt"`
	// SlotID selects a server slot; DefaultRequest sets -1, presumably
	// meaning "any free slot" — TODO confirm against the server docs.
	SlotID int    `json:"slot_id"`
	Prompt string `json:"prompt"`
}
+
// imageID is a process-wide counter handing out a unique ID per image so
// the prompt's [img-ID] placeholder is unambiguous; imageLock guards it.
// It starts at 10 for no documented reason — presumably to keep clear of
// low IDs; TODO confirm.
var imageID = 10
var imageLock = sync.Mutex{}
+
+func DefaultRequest(prompt string, image io.Reader) (*Request, error) {
+ imageLock.Lock()
+ defer imageLock.Unlock()
+
+ imageID++
+
+ imageData, err := io.ReadAll(image)
+ if err != nil {
+ return nil, err
+ }
+
+ return &Request{
+ Stream: false,
+ NPredict: 400,
+ Temperature: 0.7,
+ Stop: []string{"</s>", "Mimi:", "User:"},
+ RepeatLastN: 256,
+ RepeatPenalty: 1.18,
+ TopK: 40,
+ TopP: 0.5,
+ TfsZ: 1,
+ TypicalP: 1,
+ PresencePenalty: 0,
+ FrequencyPenalty: 0,
+ Mirostat: 0,
+ MirostatTau: 5,
+ MirostatEta: 0.1,
+ Grammar: "",
+ NProbs: 0,
+ ImageData: []Image{
+ {
+ Data: imageData,
+ ID: imageID,
+ },
+ },
+ CachePrompt: true,
+ SlotID: -1,
+ Prompt: formatPrompt(prompt, imageID),
+ }, nil
+}
+
+func Describe(ctx context.Context, server string, req *Request) (*Response, error) {
+ var buf bytes.Buffer
+
+ if err := json.NewEncoder(&buf).Encode(req); err != nil {
+ return nil, err
+ }
+
+ r, err := http.NewRequestWithContext(ctx, http.MethodPost, server, &buf)
+ if err != nil {
+ return nil, err
+ }
+
+ r.Header.Set("Content-Type", "application/json")
+ r.Header.Set("Accept", "application/json")
+ r.Header.Set("User-Agent", "within.website/x/llm/llava")
+
+ resp, err := http.DefaultClient.Do(r)
+ if err != nil {
+ return nil, fmt.Errorf("llava: http request error: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, web.NewError(http.StatusOK, resp)
+ }
+
+ var llr Response
+ if err := json.NewDecoder(resp.Body).Decode(&llr); err != nil {
+ return nil, fmt.Errorf("llava: json decode error: %w", err)
+ }
+
+ return &llr, nil
+}
+
// formatPrompt renders the fixed llava chat template, substituting the
// user's prompt text and the [img-ID] image reference into the USER turn.
func formatPrompt(prompt string, imageID int) string {
	const basePrompt = `A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
USER:[img-${imageID}]${prompt}
ASSISTANT:`
	// os.Expand replaces ${key}; unknown keys map to "" via the map's
	// zero value, matching a default-empty substitution.
	vars := map[string]string{
		"prompt":  prompt,
		"imageID": strconv.Itoa(imageID),
	}
	return os.Expand(basePrompt, func(key string) string {
		return vars[key]
	})
}
+
// Response is the JSON reply from the llava server's /completion
// endpoint for a non-streaming request.
type Response struct {
	Content         string  `json:"content"` // the generated completion text
	Model           string  `json:"model"`
	Prompt          string  `json:"prompt"`
	SlotID          int     `json:"slot_id"`
	Stop            bool    `json:"stop"`
	Timings         Timings `json:"timings"`
	TokensCached    int     `json:"tokens_cached"`
	TokensEvaluated int     `json:"tokens_evaluated"`
	TokensPredicted int     `json:"tokens_predicted"`
	Truncated       bool    `json:"truncated"`
}
+
// Timings reports the server-side performance breakdown attached to a
// Response: prompt-evaluation and token-prediction times and rates.
type Timings struct {
	PredictedMs         float64 `json:"predicted_ms"`
	PredictedN          int     `json:"predicted_n"`
	PredictedPerSecond  float64 `json:"predicted_per_second"`
	PredictedPerTokenMs float64 `json:"predicted_per_token_ms"`
	PromptMs            float64 `json:"prompt_ms"`
	PromptN             int     `json:"prompt_n"`
	PromptPerSecond     float64 `json:"prompt_per_second"`
	PromptPerTokenMs    float64 `json:"prompt_per_token_ms"`
}