aboutsummaryrefslogtreecommitdiff
path: root/cmd
diff options
context:
space:
mode:
author Xe Iaso <me@xeiaso.net> 2023-12-07 19:13:29 -0500
committer Xe Iaso <me@xeiaso.net> 2023-12-07 19:13:50 -0500
commit bfd9d18254891113b8ecd2d5fa86a32c0744711f (patch)
tree 949802f0b9a7dd0922393307414fc751bdb5e873 /cmd
parent 31b52b72c3c3ebbc1f04288715bfb5f58b6f2de1 (diff)
download x-bfd9d18254891113b8ecd2d5fa86a32c0744711f.tar.xz
x-bfd9d18254891113b8ecd2d5fa86a32c0744711f.zip
cmd: add mimi
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd')
-rw-r--r-- cmd/mimi/main.go 198
-rw-r--r-- cmd/mimi/ollama/client.go 311
-rw-r--r-- cmd/mimi/ollama/client_test.go 43
-rw-r--r-- cmd/mimi/ollama/types.go 362
-rw-r--r-- cmd/mimi/var/.gitignore 2
5 files changed, 916 insertions, 0 deletions
diff --git a/cmd/mimi/main.go b/cmd/mimi/main.go
new file mode 100644
index 0000000..fc72fa7
--- /dev/null
+++ b/cmd/mimi/main.go
@@ -0,0 +1,198 @@
+package main
+
+import (
+ "context"
+ "encoding/json"
+ "flag"
+ "fmt"
+ "log"
+ "log/slog"
+ "os"
+ "os/signal"
+ "strings"
+ "sync"
+ "syscall"
+
+ "github.com/bwmarrin/discordgo"
+ "within.website/x/cmd/mimi/ollama"
+ "within.website/x/internal"
+ "within.website/x/llm"
+)
+
+// Command-line flags. They configure the bot's data directory, the Discord
+// token and the single guild/channel the bot listens in, the Ollama model tag
+// and host, and OpenAI settings.
+//
+// NOTE(review): dataDir, openAIKey and openAITTSModel are not referenced
+// anywhere else in this file; confirm whether they are used elsewhere.
+var (
+ dataDir = flag.String("data-dir", "./var", "data directory for the bot")
+ discordToken = flag.String("discord-token", "", "discord token")
+ discordGuild = flag.String("discord-guild", "192289762302754817", "discord guild")
+ discordChannel = flag.String("discord-channel", "217096701771513856", "discord channel")
+ ollamaModel = flag.String("ollama-model", "xe/mimi:f16", "ollama model tag")
+ ollamaHost = flag.String("ollama-host", "http://kaine:11434", "ollama host")
+ openAIKey = flag.String("openai-api-key", "", "openai key")
+ openAITTSModel = flag.String("openai-tts-model", "nova", "openai tts model")
+)
+
+// p returns a pointer to a fresh copy of t. It exists so optional pointer
+// fields (for example a *bool) can be filled in from a literal value.
+func p[T any](t T) *T {
+	ptr := new(T)
+	*ptr = t
+	return ptr
+}
+
+// main connects the Ollama client to a Discord session and blocks until the
+// process receives an interrupt or termination signal.
+//
+// On startup it pulls the configured model and then registers a message
+// handler. The handler ignores the bot's own messages, other bots, and
+// anything outside the configured guild and channel. "!mimi" is answered with
+// a fixed reply, "!mimi clear" drops the channel's saved state, and every other
+// message is JSON-encoded and sent to the model together with the channel's
+// saved context. The streamed reply is collected and posted to the channel.
+func main() {
+	internal.HandleStartup()
+
+	os.Setenv("OLLAMA_HOST", *ollamaHost)
+
+	cli, err := ollama.ClientFromEnvironment()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	if err := cli.Pull(ctx,
+		&ollama.PullRequest{
+			Name:   *ollamaModel,
+			Stream: p(true),
+		},
+		func(pr ollama.ProgressResponse) error {
+			slog.Debug("pull progress", "progress", pr.Total-pr.Completed, "total", pr.Total)
+			return nil
+		},
+	); err != nil {
+		log.Fatal(err)
+	}
+
+	dg, err := discordgo.New("Bot " + *discordToken)
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer dg.Close()
+
+	dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) {
+		if m.Author.ID == s.State.User.ID {
+			return
+		}
+
+		if m.GuildID != *discordGuild {
+			return
+		}
+
+		if m.ChannelID != *discordChannel {
+			return
+		}
+
+		if m.Author.Bot {
+			return
+		}
+
+		if m.Content == "!mimi" {
+			if _, err := s.ChannelMessageSend(m.ChannelID, "mimi!"); err != nil {
+				slog.Error("message send error", "err", err, "message", "mimi!")
+			}
+			return
+		}
+
+		if m.Content == "!mimi clear" {
+			lock.Lock()
+			delete(stateMap, m.ChannelID)
+			lock.Unlock()
+			if _, err := s.ChannelMessageSend(m.ChannelID, "mimi state cleared"); err != nil {
+				slog.Error("message send error", "err", err, "message", "mimi state cleared")
+			}
+			return
+		}
+
+		// A "name\message" payload overrides the Discord username and content.
+		user, message := m.Author.Username, m.Content
+		if ns, ok := ParseNameslash(m.Content); ok {
+			user, message = ns.Name, ns.Message
+		}
+
+		var prompt strings.Builder
+		if err := json.NewEncoder(&prompt).Encode(map[string]any{
+			"message":  message,
+			"user":     user,
+			"is_admin": m.Author.Username == "xeiaso",
+		}); err != nil {
+			// Without an encoded prompt there is nothing useful to send.
+			slog.Error("json encode error", "error", err)
+			return
+		}
+
+		lock.Lock()
+		defer lock.Unlock()
+
+		st, ok := stateMap[m.ChannelID]
+		if !ok {
+			st = &State{}
+			stateMap[m.ChannelID] = st
+		}
+
+		// Record every user turn, not only the one that created the state.
+		st.Messages = append(st.Messages, llm.Message{
+			Role:    "user",
+			Content: prompt.String(),
+		})
+
+		fmt.Println(Prompt(st.Messages))
+
+		var sb strings.Builder
+
+		// Use a handler-local error: handlers must not write to main's err.
+		if err := cli.Generate(ctx,
+			&ollama.GenerateRequest{
+				Model:   *ollamaModel,
+				Context: st.Context,
+				Prompt:  prompt.String(),
+				Stream:  p(true),
+				System:  "Your name is Mimi. You will answer questions from users when asked. You are an expert in programming and philosophy. You are a catgirl. You are relaxed, terse, and casual. Twilight Sparkle is best pony.",
+			}, func(gr ollama.GenerateResponse) error {
+				fmt.Fprint(&sb, gr.Response)
+
+				if gr.Done {
+					st.Context = gr.Context
+					// Store the whole accumulated reply; gr.Response here is
+					// only the last streamed chunk.
+					st.Messages = append(st.Messages, llm.Message{
+						Role:    "assistant",
+						Content: sb.String(),
+					})
+				}
+				return nil
+			},
+		); err != nil {
+			slog.Error("generate error", "error", err, "channel", m.ChannelID)
+			return
+		}
+
+		if _, err := s.ChannelMessageSend(m.ChannelID, sb.String()); err != nil {
+			slog.Error("message send error", "err", err, "message", sb.String())
+		}
+		slog.Debug("context length", "len", len(st.Context))
+	})
+
+	if err := dg.Open(); err != nil {
+		log.Fatal(err)
+	}
+
+	sc := make(chan os.Signal, 1)
+	// os.Interrupt is SIGINT, so it only needs to be registered once.
+	signal.Notify(sc, os.Interrupt, syscall.SIGTERM)
+	<-sc
+	cancel()
+}
+
+// lock guards stateMap; the message handler in main takes it before reading
+// or changing any entry.
+var lock sync.Mutex
+
+// stateMap holds the conversation state for each Discord channel, keyed by
+// channel ID.
+var stateMap = map[string]*State{}
+
+// State is the conversation state kept for one channel.
+type State struct {
+ // Context is the context value taken from the final generate response and
+ // passed back to the model on the next request.
+ Context []int
+ // Messages is the message history recorded for the channel.
+ Messages []llm.Message
+}
+
+// Nameslash is a message written as "name\message": an explicit speaker name
+// followed by the text attributed to that speaker.
+type Nameslash struct {
+	Name    string `json:"name"`
+	Message string `json:"message"`
+}
+
+// ParseNameslash splits msg at its backslash into a name and a message.
+//
+// It reports false unless msg contains exactly one backslash. Either side of
+// the backslash may be empty.
+func ParseNameslash(msg string) (Nameslash, bool) {
+	name, rest, found := strings.Cut(msg, "\\")
+	if !found || strings.Contains(rest, "\\") {
+		return Nameslash{}, false
+	}
+	return Nameslash{Name: name, Message: rest}, true
+}
diff --git a/cmd/mimi/ollama/client.go b/cmd/mimi/ollama/client.go
new file mode 100644
index 0000000..439d89e
--- /dev/null
+++ b/cmd/mimi/ollama/client.go
@@ -0,0 +1,311 @@
+package ollama
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ "os"
+ "strings"
+
+ "within.website/x/web/useragent"
+)
+
+// Client talks to an Ollama server over HTTP. Values are built by
+// ClientFromEnvironment.
+type Client struct {
+ // base is the server's scheme and host:port; request paths are joined onto it.
+ base *url.URL
+ // http is the HTTP client used for every request.
+ http http.Client
+}
+
+// checkError converts an HTTP response with a status of 400 or above into a
+// StatusError and returns nil for any lower status.
+//
+// The body is decoded as JSON into the error; when that fails, the raw body
+// text becomes the error message.
+func checkError(resp *http.Response, body []byte) error {
+	if resp.StatusCode >= http.StatusBadRequest {
+		statusErr := StatusError{StatusCode: resp.StatusCode}
+		if decodeErr := json.Unmarshal(body, &statusErr); decodeErr != nil {
+			// Not JSON: surface whatever the server sent.
+			statusErr.ErrorMessage = string(body)
+		}
+		return statusErr
+	}
+
+	return nil
+}
+
+// ClientFromEnvironment builds a Client from the OLLAMA_HOST environment
+// variable.
+//
+// The value may be empty, a bare host, host:port, or a URL with a scheme.
+// Without a scheme, "http" and port 11434 are assumed. With an "http" or
+// "https" scheme and no port, the port defaults to 80 or 443. Trailing slashes
+// are dropped, and an empty host falls back to 127.0.0.1 when no port was
+// given. The client's proxy is resolved once, from the environment, for the
+// resulting base URL.
+func ClientFromEnvironment() (*Client, error) {
+	raw := os.Getenv("OLLAMA_HOST")
+
+	fallbackPort := "11434"
+	scheme, hostport, hasScheme := strings.Cut(raw, "://")
+	if !hasScheme {
+		scheme, hostport = "http", raw
+	} else if scheme == "http" {
+		fallbackPort = "80"
+	} else if scheme == "https" {
+		fallbackPort = "443"
+	}
+
+	// trim trailing slashes
+	hostport = strings.TrimRight(hostport, "/")
+
+	host, port, err := net.SplitHostPort(hostport)
+	if err != nil {
+		// No usable host:port pair: treat the whole value as the host.
+		port = fallbackPort
+		switch ip := net.ParseIP(strings.Trim(hostport, "[]")); {
+		case ip != nil:
+			host = ip.String()
+		case hostport != "":
+			host = hostport
+		default:
+			host = "127.0.0.1"
+		}
+	}
+
+	base := &url.URL{
+		Scheme: scheme,
+		Host:   net.JoinHostPort(host, port),
+	}
+
+	// A throwaway request lets the standard proxy lookup see the target URL.
+	probe, err := http.NewRequest(http.MethodHead, base.String(), nil)
+	if err != nil {
+		return nil, err
+	}
+
+	proxyURL, err := http.ProxyFromEnvironment(probe)
+	if err != nil {
+		return nil, err
+	}
+
+	return &Client{
+		base: base,
+		http: http.Client{
+			Transport: &http.Transport{
+				Proxy: http.ProxyURL(proxyURL),
+			},
+		},
+	}, nil
+}
+
+// do sends one non-streaming request to path and decodes the JSON reply.
+//
+// reqData may be nil (no body), an io.Reader (sent as-is), or any other value
+// (sent as JSON). The whole response body is read. A status of 400 or above is
+// returned as an error from checkError. Otherwise a non-empty body is decoded
+// into respData when respData is non-nil.
+func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
+	var body io.Reader
+	switch v := reqData.(type) {
+	case nil:
+		// no request body
+	case io.Reader:
+		body = v
+	default:
+		encoded, err := json.Marshal(v)
+		if err != nil {
+			return err
+		}
+		body = bytes.NewReader(encoded)
+	}
+
+	target := c.base.JoinPath(path).String()
+	req, err := http.NewRequestWithContext(ctx, method, target, body)
+	if err != nil {
+		return err
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("User-Agent", useragent.GenUserAgent("ollama", "https://xeiaso.net"))
+
+	resp, err := c.http.Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	payload, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return err
+	}
+
+	if err := checkError(resp, payload); err != nil {
+		return err
+	}
+
+	if respData == nil || len(payload) == 0 {
+		return nil
+	}
+	return json.Unmarshal(payload, respData)
+}
+
+// maxBufferSize is the largest single response line stream will accept.
+const maxBufferSize = 512 * 1024
+
+// stream sends a request to path and passes each line of the newline-delimited
+// JSON response to fn.
+//
+// data, when non-nil, is sent as the JSON request body. For each line stream
+// first looks for an "error" field and returns it as an error, then returns a
+// StatusError if the HTTP status is 400 or above, and otherwise calls fn with
+// the raw line. The first error from fn stops the stream and is returned.
+// A failure while reading the response (including a line longer than
+// maxBufferSize) is returned after the lines read so far were delivered.
+func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
+	// Keep the body as an io.Reader so that "no body" stays a true nil
+	// interface; a nil *bytes.Buffer wrapped in io.Reader would be non-nil.
+	var body io.Reader
+	if data != nil {
+		bts, err := json.Marshal(data)
+		if err != nil {
+			return err
+		}
+
+		body = bytes.NewBuffer(bts)
+	}
+
+	requestURL := c.base.JoinPath(path)
+	request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
+	if err != nil {
+		return err
+	}
+
+	request.Header.Set("Content-Type", "application/json")
+	request.Header.Set("Accept", "application/x-ndjson")
+	request.Header.Set("User-Agent", useragent.GenUserAgent("ollama", "https://xeiaso.net"))
+
+	response, err := c.http.Do(request)
+	if err != nil {
+		return err
+	}
+	defer response.Body.Close()
+
+	scanner := bufio.NewScanner(response.Body)
+	// increase the buffer size to avoid running out of space
+	scanBuf := make([]byte, 0, maxBufferSize)
+	scanner.Buffer(scanBuf, maxBufferSize)
+	for scanner.Scan() {
+		var errorResponse struct {
+			Error string `json:"error,omitempty"`
+		}
+
+		bts := scanner.Bytes()
+		if err := json.Unmarshal(bts, &errorResponse); err != nil {
+			return fmt.Errorf("unmarshal: %w", err)
+		}
+
+		if errorResponse.Error != "" {
+			// The server's text is data, not a format string.
+			return errors.New(errorResponse.Error)
+		}
+
+		if response.StatusCode >= http.StatusBadRequest {
+			return StatusError{
+				StatusCode:   response.StatusCode,
+				Status:       response.Status,
+				ErrorMessage: errorResponse.Error,
+			}
+		}
+
+		if err := fn(bts); err != nil {
+			return err
+		}
+	}
+
+	// Scan returns false on both EOF and failure; only Err tells them apart.
+	if err := scanner.Err(); err != nil {
+		return fmt.Errorf("reading response stream: %w", err)
+	}
+
+	return nil
+}
+
+// GenerateResponseFunc receives each decoded chunk of a generate stream.
+type GenerateResponseFunc func(GenerateResponse) error
+
+// Generate posts req to /api/generate and calls fn with every streamed
+// response chunk. A decode failure or an error from fn ends the stream.
+func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
+	onLine := func(line []byte) error {
+		var chunk GenerateResponse
+		if err := json.Unmarshal(line, &chunk); err != nil {
+			return err
+		}
+		return fn(chunk)
+	}
+
+	return c.stream(ctx, http.MethodPost, "/api/generate", req, onLine)
+}
+
+// PullProgressFunc receives each progress update of a pull.
+type PullProgressFunc func(ProgressResponse) error
+
+// Pull posts req to /api/pull and calls fn with every streamed progress
+// update. A decode failure or an error from fn ends the stream.
+func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
+	onLine := func(line []byte) error {
+		var progress ProgressResponse
+		if err := json.Unmarshal(line, &progress); err != nil {
+			return err
+		}
+		return fn(progress)
+	}
+
+	return c.stream(ctx, http.MethodPost, "/api/pull", req, onLine)
+}
+
+// PushProgressFunc receives each progress update of a push.
+type PushProgressFunc func(ProgressResponse) error
+
+// Push posts req to /api/push and calls fn with every streamed progress
+// update. A decode failure or an error from fn ends the stream.
+func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
+	onLine := func(line []byte) error {
+		var progress ProgressResponse
+		if err := json.Unmarshal(line, &progress); err != nil {
+			return err
+		}
+		return fn(progress)
+	}
+
+	return c.stream(ctx, http.MethodPost, "/api/push", req, onLine)
+}
+
+// CreateProgressFunc receives each progress update of a create.
+type CreateProgressFunc func(ProgressResponse) error
+
+// Create posts req to /api/create and calls fn with every streamed progress
+// update. A decode failure or an error from fn ends the stream.
+func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
+	onLine := func(line []byte) error {
+		var progress ProgressResponse
+		if err := json.Unmarshal(line, &progress); err != nil {
+			return err
+		}
+		return fn(progress)
+	}
+
+	return c.stream(ctx, http.MethodPost, "/api/create", req, onLine)
+}
+
+// List fetches /api/tags and returns the decoded model list.
+func (c *Client) List(ctx context.Context) (*ListResponse, error) {
+	models := new(ListResponse)
+	err := c.do(ctx, http.MethodGet, "/api/tags", nil, models)
+	if err != nil {
+		return nil, err
+	}
+	return models, nil
+}
+
+// Copy posts req to /api/copy and reports any error from the request.
+func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
+	return c.do(ctx, http.MethodPost, "/api/copy", req, nil)
+}
+
+// Delete sends req to /api/delete with the DELETE method and reports any
+// error from the request.
+func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error {
+	return c.do(ctx, http.MethodDelete, "/api/delete", req, nil)
+}
+
+// Show posts req to /api/show and returns the decoded response.
+func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) {
+	details := new(ShowResponse)
+	err := c.do(ctx, http.MethodPost, "/api/show", req, details)
+	if err != nil {
+		return nil, err
+	}
+	return details, nil
+}
+
+// Heartbeat sends a HEAD request to the server root and reports any error
+// from the request.
+func (c *Client) Heartbeat(ctx context.Context) error {
+	return c.do(ctx, http.MethodHead, "/", nil, nil)
+}
+
+// CreateBlob uploads r as the blob named digest unless the server already has
+// it.
+//
+// It first sends a HEAD request for the blob. Success means nothing to do. A
+// 404 StatusError triggers a POST of r to the same path. Any other error is
+// returned unchanged.
+func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
+	headErr := c.do(ctx, http.MethodHead, fmt.Sprintf("/api/blobs/%s", digest), nil, nil)
+	if headErr == nil {
+		return nil
+	}
+
+	var notFound StatusError
+	if errors.As(headErr, &notFound) && notFound.StatusCode == http.StatusNotFound {
+		return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
+	}
+
+	return headErr
+}
diff --git a/cmd/mimi/ollama/client_test.go b/cmd/mimi/ollama/client_test.go
new file mode 100644
index 0000000..9c06d3f
--- /dev/null
+++ b/cmd/mimi/ollama/client_test.go
@@ -0,0 +1,43 @@
+package ollama
+
+import "testing"
+
+// TestClientFromEnvironment checks the base URL that ClientFromEnvironment
+// derives from a range of OLLAMA_HOST values.
+func TestClientFromEnvironment(t *testing.T) {
+	type testCase struct {
+		value  string
+		expect string
+		err    error
+	}
+
+	testCases := map[string]*testCase{
+		"empty":                      {value: "", expect: "http://127.0.0.1:11434"},
+		"only address":               {value: "1.2.3.4", expect: "http://1.2.3.4:11434"},
+		"only port":                  {value: ":1234", expect: "http://:1234"},
+		"address and port":           {value: "1.2.3.4:1234", expect: "http://1.2.3.4:1234"},
+		"scheme http and address":    {value: "http://1.2.3.4", expect: "http://1.2.3.4:80"},
+		"scheme https and address":   {value: "https://1.2.3.4", expect: "https://1.2.3.4:443"},
+		"scheme, address, and port":  {value: "https://1.2.3.4:1234", expect: "https://1.2.3.4:1234"},
+		"hostname":                   {value: "example.com", expect: "http://example.com:11434"},
+		"hostname and port":          {value: "example.com:1234", expect: "http://example.com:1234"},
+		"scheme http and hostname":   {value: "http://example.com", expect: "http://example.com:80"},
+		"scheme https and hostname":  {value: "https://example.com", expect: "https://example.com:443"},
+		"scheme, hostname, and port": {value: "https://example.com:1234", expect: "https://example.com:1234"},
+		"trailing slash":             {value: "example.com/", expect: "http://example.com:11434"},
+		"trailing slash port":        {value: "example.com:1234/", expect: "http://example.com:1234"},
+	}
+
+	for k, v := range testCases {
+		t.Run(k, func(t *testing.T) {
+			t.Setenv("OLLAMA_HOST", v.value)
+
+			client, err := ClientFromEnvironment()
+			if err != v.err {
+				// %v, not %s: one side of a mismatch is usually a nil error,
+				// which %s renders as "%!s(<nil>)".
+				t.Fatalf("expected %v, got %v", v.err, err)
+			}
+
+			if client.base.String() != v.expect {
+				t.Fatalf("expected %s, got %s", v.expect, client.base.String())
+			}
+		})
+	}
+}
diff --git a/cmd/mimi/ollama/types.go b/cmd/mimi/ollama/types.go
new file mode 100644
index 0000000..9c5b16b
--- /dev/null
+++ b/cmd/mimi/ollama/types.go
@@ -0,0 +1,362 @@
+package ollama
+
+import (
+ "encoding/json"
+ "fmt"
+ "math"
+ "os"
+ "reflect"
+ "strings"
+ "time"
+)
+
+// StatusError describes a failed HTTP exchange: the numeric status code, the
+// status line, and the message decoded from the "error" field of the body.
+type StatusError struct {
+	StatusCode   int
+	Status       string
+	ErrorMessage string `json:"error"`
+}
+
+// Error renders "status: message" when both parts are set, whichever part is
+// set when only one is, and a generic hint when neither is.
+func (e StatusError) Error() string {
+	if e.Status == "" && e.ErrorMessage == "" {
+		// this should not happen
+		return "something went wrong, please see the ollama server logs for details"
+	}
+	if e.Status == "" {
+		return e.ErrorMessage
+	}
+	if e.ErrorMessage == "" {
+		return e.Status
+	}
+	return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
+}
+
+// GenerateRequest is the JSON body Client.Generate posts to /api/generate.
+//
+// Context, Stream and Raw are omitted from the JSON when empty; every other
+// field is always serialized.
+type GenerateRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt"`
+ System string `json:"system"`
+ Template string `json:"template"`
+ Context []int `json:"context,omitempty"`
+ Stream *bool `json:"stream,omitempty"`
+ Raw bool `json:"raw,omitempty"`
+ Format string `json:"format"`
+
+ // Options holds free-form options; presumably the keys match the JSON
+ // tags of Options — verify against the server.
+ Options map[string]interface{} `json:"options"`
+}
+
+// Options specified in GenerateRequest. If you add a new option here, add it
+// to the API docs as well.
+//
+// Runner is embedded, so its fields are promoted and share this struct's JSON
+// namespace. FromMap fills a value from a generic map and DefaultOptions
+// returns the default values.
+type Options struct {
+ Runner
+
+ // Predict options used at runtime
+ NumKeep int `json:"num_keep,omitempty"`
+ Seed int `json:"seed,omitempty"`
+ NumPredict int `json:"num_predict,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ TopP float32 `json:"top_p,omitempty"`
+ TFSZ float32 `json:"tfs_z,omitempty"`
+ TypicalP float32 `json:"typical_p,omitempty"`
+ RepeatLastN int `json:"repeat_last_n,omitempty"`
+ Temperature float32 `json:"temperature,omitempty"`
+ RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
+ PresencePenalty float32 `json:"presence_penalty,omitempty"`
+ FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
+ Mirostat int `json:"mirostat,omitempty"`
+ MirostatTau float32 `json:"mirostat_tau,omitempty"`
+ MirostatEta float32 `json:"mirostat_eta,omitempty"`
+ PenalizeNewline bool `json:"penalize_newline,omitempty"`
+ Stop []string `json:"stop,omitempty"`
+}
+
+// Runner options which must be set when the model is loaded into memory.
+// Every field is omitted from the JSON when it holds its zero value.
+type Runner struct {
+ UseNUMA bool `json:"numa,omitempty"`
+ NumCtx int `json:"num_ctx,omitempty"`
+ NumBatch int `json:"num_batch,omitempty"`
+ NumGQA int `json:"num_gqa,omitempty"`
+ NumGPU int `json:"num_gpu,omitempty"`
+ MainGPU int `json:"main_gpu,omitempty"`
+ LowVRAM bool `json:"low_vram,omitempty"`
+ F16KV bool `json:"f16_kv,omitempty"`
+ LogitsAll bool `json:"logits_all,omitempty"`
+ VocabOnly bool `json:"vocab_only,omitempty"`
+ UseMMap bool `json:"use_mmap,omitempty"`
+ UseMLock bool `json:"use_mlock,omitempty"`
+ EmbeddingOnly bool `json:"embedding_only,omitempty"`
+ RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
+ RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
+ NumThread int `json:"num_thread,omitempty"`
+}
+
+// EmbeddingRequest carries a model name, a prompt and free-form options as
+// JSON.
+//
+// NOTE(review): no client method in this package sends it yet; presumably it
+// is the request body of an embeddings endpoint — confirm.
+type EmbeddingRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt"`
+
+ Options map[string]interface{} `json:"options"`
+}
+
+// EmbeddingResponse holds a vector of floats decoded from the "embedding"
+// JSON field.
+//
+// NOTE(review): presumably the reply to an EmbeddingRequest — confirm.
+type EmbeddingResponse struct {
+ Embedding []float64 `json:"embedding"`
+}
+
+// CreateRequest is the JSON body Client.Create posts to /api/create.
+// Stream is omitted from the JSON when nil.
+type CreateRequest struct {
+ Name string `json:"name"`
+ Path string `json:"path"`
+ Modelfile string `json:"modelfile"`
+ Stream *bool `json:"stream,omitempty"`
+}
+
+// DeleteRequest is the JSON body Client.Delete sends to /api/delete.
+type DeleteRequest struct {
+ Name string `json:"name"`
+}
+
+// ShowRequest is the JSON body Client.Show posts to /api/show.
+type ShowRequest struct {
+ Name string `json:"name"`
+}
+
+// ShowResponse is what Client.Show decodes from the /api/show reply. All
+// fields are optional strings.
+type ShowResponse struct {
+ License string `json:"license,omitempty"`
+ Modelfile string `json:"modelfile,omitempty"`
+ Parameters string `json:"parameters,omitempty"`
+ Template string `json:"template,omitempty"`
+ System string `json:"system,omitempty"`
+}
+
+// CopyRequest is the JSON body Client.Copy posts to /api/copy.
+type CopyRequest struct {
+ Source string `json:"source"`
+ Destination string `json:"destination"`
+}
+
+// PullRequest is the JSON body Client.Pull posts to /api/pull.
+// Insecure and Stream are omitted from the JSON when unset; Username and
+// Password are always serialized, even when empty.
+type PullRequest struct {
+ Name string `json:"name"`
+ Insecure bool `json:"insecure,omitempty"`
+ Username string `json:"username"`
+ Password string `json:"password"`
+ Stream *bool `json:"stream,omitempty"`
+}
+
+// ProgressResponse is one streamed progress update, decoded by Client.Pull,
+// Client.Push and Client.Create and handed to their callbacks.
+type ProgressResponse struct {
+ Status string `json:"status"`
+ Digest string `json:"digest,omitempty"`
+ // Total and Completed are counters reported by the server; presumably
+ // byte counts for the item named by Digest — confirm.
+ Total int64 `json:"total,omitempty"`
+ Completed int64 `json:"completed,omitempty"`
+}
+
+// PushRequest is the JSON body Client.Push posts to /api/push.
+// Insecure and Stream are omitted from the JSON when unset; Username and
+// Password are always serialized, even when empty.
+type PushRequest struct {
+ Name string `json:"name"`
+ Insecure bool `json:"insecure,omitempty"`
+ Username string `json:"username"`
+ Password string `json:"password"`
+ Stream *bool `json:"stream,omitempty"`
+}
+
+// ListResponse is what Client.List decodes from the /api/tags reply.
+type ListResponse struct {
+ Models []ModelResponse `json:"models"`
+}
+
+// ModelResponse is one entry of ListResponse.Models.
+type ModelResponse struct {
+ Name string `json:"name"`
+ ModifiedAt time.Time `json:"modified_at"`
+ Size int64 `json:"size"`
+ Digest string `json:"digest"`
+}
+
+// TokenResponse wraps a single "token" string.
+//
+// NOTE(review): nothing in this package uses it yet; confirm where it is
+// needed.
+type TokenResponse struct {
+ Token string `json:"token"`
+}
+
+// GenerateResponse is one chunk of the stream that Client.Generate decodes and
+// passes to its callback.
+type GenerateResponse struct {
+ Model string `json:"model"`
+ CreatedAt time.Time `json:"created_at"`
+ // Response is the text carried by this chunk.
+ Response string `json:"response"`
+
+ // Done marks the final chunk; callers read Context from that chunk.
+ Done bool `json:"done"`
+ Context []int `json:"context,omitempty"`
+
+ // Counters and timings; Summary prints the ones that are greater than zero.
+ TotalDuration time.Duration `json:"total_duration,omitempty"`
+ LoadDuration time.Duration `json:"load_duration,omitempty"`
+ PromptEvalCount int `json:"prompt_eval_count,omitempty"`
+ PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
+ EvalCount int `json:"eval_count,omitempty"`
+ EvalDuration time.Duration `json:"eval_duration,omitempty"`
+}
+
+// Summary writes the response's timing statistics to standard error.
+//
+// Only values greater than zero are printed. For the prompt-eval and eval
+// phases a tokens-per-second rate follows the duration.
+func (r *GenerateResponse) Summary() {
+	out := os.Stderr
+
+	if d := r.TotalDuration; d > 0 {
+		fmt.Fprintf(out, "total duration: %v\n", d)
+	}
+
+	if d := r.LoadDuration; d > 0 {
+		fmt.Fprintf(out, "load duration: %v\n", d)
+	}
+
+	if n := r.PromptEvalCount; n > 0 {
+		fmt.Fprintf(out, "prompt eval count: %d token(s)\n", n)
+	}
+
+	if d := r.PromptEvalDuration; d > 0 {
+		rate := float64(r.PromptEvalCount) / d.Seconds()
+		fmt.Fprintf(out, "prompt eval duration: %s\n", d)
+		fmt.Fprintf(out, "prompt eval rate: %.2f tokens/s\n", rate)
+	}
+
+	if n := r.EvalCount; n > 0 {
+		fmt.Fprintf(out, "eval count: %d token(s)\n", n)
+	}
+
+	if d := r.EvalDuration; d > 0 {
+		rate := float64(r.EvalCount) / d.Seconds()
+		fmt.Fprintf(out, "eval duration: %s\n", d)
+		fmt.Fprintf(out, "eval rate: %.2f tokens/s\n", rate)
+	}
+}
+
+// ErrInvalidOpts is wrapped by the error FromMap returns when the map contains
+// keys that do not name any option.
+var ErrInvalidOpts = fmt.Errorf("invalid options")
+
+// FromMap copies the entries of m into opts, matching each key against the
+// JSON tag of a field of Options (including the promoted Runner fields).
+//
+// Values may be in the shapes produced by encoding/json (float64 numbers,
+// []interface{} arrays) or the native Go equivalents (int, int64, float32,
+// []string), so maps built in Go code work as well as decoded ones. Nil values
+// are skipped. A value of the wrong type returns an error immediately. Unknown
+// keys are collected and reported together in an error wrapping
+// ErrInvalidOpts after all known keys were applied.
+func (opts *Options) FromMap(m map[string]interface{}) error {
+	valueOpts := reflect.ValueOf(opts).Elem() // values of the fields in the options struct
+	typeOpts := reflect.TypeOf(opts).Elem()   // types of the fields in the options struct
+
+	// build map of json struct tags to their fields
+	jsonOpts := make(map[string]reflect.StructField)
+	for _, field := range reflect.VisibleFields(typeOpts) {
+		jsonTag, _, _ := strings.Cut(field.Tag.Get("json"), ",")
+		if jsonTag != "" {
+			jsonOpts[jsonTag] = field
+		}
+	}
+
+	var invalidOpts []string
+	for key, val := range m {
+		opt, ok := jsonOpts[key]
+		if !ok {
+			invalidOpts = append(invalidOpts, key)
+			continue
+		}
+
+		field := valueOpts.FieldByName(opt.Name)
+		if !field.IsValid() || !field.CanSet() || val == nil {
+			continue
+		}
+
+		switch field.Kind() {
+		case reflect.Int:
+			switch t := val.(type) {
+			case int:
+				field.SetInt(int64(t))
+			case int64:
+				field.SetInt(t)
+			case float64:
+				// when JSON unmarshals numbers, it uses float64, not int
+				field.SetInt(int64(t))
+			default:
+				return fmt.Errorf("option %q must be of type integer", key)
+			}
+		case reflect.Bool:
+			b, ok := val.(bool)
+			if !ok {
+				return fmt.Errorf("option %q must be of type boolean", key)
+			}
+			field.SetBool(b)
+		case reflect.Float32:
+			switch t := val.(type) {
+			case float64:
+				// JSON unmarshals to float64
+				field.SetFloat(t)
+			case float32:
+				field.SetFloat(float64(t))
+			default:
+				return fmt.Errorf("option %q must be of type float32", key)
+			}
+		case reflect.String:
+			s, ok := val.(string)
+			if !ok {
+				return fmt.Errorf("option %q must be of type string", key)
+			}
+			field.SetString(s)
+		case reflect.Slice:
+			switch t := val.(type) {
+			case []string:
+				// copy so later changes to the caller's slice do not leak in
+				field.Set(reflect.ValueOf(append([]string(nil), t...)))
+			case []interface{}:
+				// JSON unmarshals to []interface{}, not []string
+				slice := make([]string, len(t))
+				for i, item := range t {
+					str, ok := item.(string)
+					if !ok {
+						return fmt.Errorf("option %q must be of an array of strings", key)
+					}
+					slice[i] = str
+				}
+				field.Set(reflect.ValueOf(slice))
+			default:
+				return fmt.Errorf("option %q must be of type array", key)
+			}
+		default:
+			return fmt.Errorf("unknown type loading config params: %v", field.Kind())
+		}
+	}
+
+	if len(invalidOpts) > 0 {
+		return fmt.Errorf("%w: %v", ErrInvalidOpts, strings.Join(invalidOpts, ", "))
+	}
+	return nil
+}
+
+// DefaultOptions returns the option values used when a request does not
+// override them.
+func DefaultOptions() Options {
+	// options set when the model is loaded
+	runner := Runner{
+		UseNUMA:            false,
+		NumCtx:             2048,
+		NumBatch:           512,
+		NumGQA:             1,
+		NumGPU:             -1, // -1 here indicates that NumGPU should be set dynamically
+		LowVRAM:            false,
+		F16KV:              true,
+		UseMMap:            true,
+		UseMLock:           false,
+		EmbeddingOnly:      true,
+		RopeFrequencyBase:  10000.0,
+		RopeFrequencyScale: 1.0,
+		NumThread:          0, // let the runtime decide
+	}
+
+	// options set on request to runner
+	return Options{
+		Runner: runner,
+
+		NumKeep:          0,
+		Seed:             -1,
+		NumPredict:       -1,
+		TopK:             40,
+		TopP:             0.9,
+		TFSZ:             1.0,
+		TypicalP:         1.0,
+		RepeatLastN:      64,
+		Temperature:      0.8,
+		RepeatPenalty:    1.1,
+		PresencePenalty:  0.0,
+		FrequencyPenalty: 0.0,
+		Mirostat:         0,
+		MirostatTau:      5.0,
+		MirostatEta:      0.1,
+		PenalizeNewline:  true,
+	}
+}
+
+// Duration is a time.Duration that can be decoded from either a JSON number
+// or a JSON duration string.
+type Duration struct {
+	time.Duration
+}
+
+// UnmarshalJSON decodes b into d.
+//
+// A JSON number is taken as a time.Duration count; a negative number, or one
+// too large for a Duration, becomes the largest representable Duration. A JSON
+// string is parsed with time.ParseDuration and a parse failure is returned.
+// Any other valid JSON value leaves d at five minutes. Invalid JSON returns
+// the decode error.
+func (d *Duration) UnmarshalJSON(b []byte) (err error) {
+	var v any
+	if err := json.Unmarshal(b, &v); err != nil {
+		return err
+	}
+
+	d.Duration = 5 * time.Minute
+
+	switch t := v.(type) {
+	case float64:
+		// Converting an out-of-range float to an integer has an
+		// implementation-dependent result (it can come out negative), so
+		// saturate explicitly instead of converting math.MaxFloat64.
+		if t < 0 || t >= math.MaxInt64 {
+			d.Duration = time.Duration(math.MaxInt64)
+		} else {
+			d.Duration = time.Duration(t)
+		}
+	case string:
+		d.Duration, err = time.ParseDuration(t)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
diff --git a/cmd/mimi/var/.gitignore b/cmd/mimi/var/.gitignore
new file mode 100644
index 0000000..c96a04f
--- /dev/null
+++ b/cmd/mimi/var/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore \ No newline at end of file