diff options
| author | Xe Iaso <me@xeiaso.net> | 2023-12-07 19:13:29 -0500 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2023-12-07 19:13:50 -0500 |
| commit | bfd9d18254891113b8ecd2d5fa86a32c0744711f (patch) | |
| tree | 949802f0b9a7dd0922393307414fc751bdb5e873 /cmd | |
| parent | 31b52b72c3c3ebbc1f04288715bfb5f58b6f2de1 (diff) | |
| download | x-bfd9d18254891113b8ecd2d5fa86a32c0744711f.tar.xz x-bfd9d18254891113b8ecd2d5fa86a32c0744711f.zip | |
cmd: add mimi
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/mimi/main.go | 198 | ||||
| -rw-r--r-- | cmd/mimi/ollama/client.go | 311 | ||||
| -rw-r--r-- | cmd/mimi/ollama/client_test.go | 43 | ||||
| -rw-r--r-- | cmd/mimi/ollama/types.go | 362 | ||||
| -rw-r--r-- | cmd/mimi/var/.gitignore | 2 |
5 files changed, 916 insertions, 0 deletions
diff --git a/cmd/mimi/main.go b/cmd/mimi/main.go new file mode 100644 index 0000000..fc72fa7 --- /dev/null +++ b/cmd/mimi/main.go @@ -0,0 +1,198 @@ +package main + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "log" + "log/slog" + "os" + "os/signal" + "strings" + "sync" + "syscall" + + "github.com/bwmarrin/discordgo" + "within.website/x/cmd/mimi/ollama" + "within.website/x/internal" + "within.website/x/llm" +) + +var ( + dataDir = flag.String("data-dir", "./var", "data directory for the bot") + discordToken = flag.String("discord-token", "", "discord token") + discordGuild = flag.String("discord-guild", "192289762302754817", "discord guild") + discordChannel = flag.String("discord-channel", "217096701771513856", "discord channel") + ollamaModel = flag.String("ollama-model", "xe/mimi:f16", "ollama model tag") + ollamaHost = flag.String("ollama-host", "http://kaine:11434", "ollama host") + openAIKey = flag.String("openai-api-key", "", "openai key") + openAITTSModel = flag.String("openai-tts-model", "nova", "openai tts model") +) + +func p[T any](t T) *T { + return &t +} + +func main() { + internal.HandleStartup() + + os.Setenv("OLLAMA_HOST", *ollamaHost) + + cli, err := ollama.ClientFromEnvironment() + if err != nil { + log.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if err := cli.Pull(ctx, + &ollama.PullRequest{ + Name: *ollamaModel, + Stream: p(true), + }, + func(pr ollama.ProgressResponse) error { + slog.Debug("pull progress", "progress", pr.Total-pr.Completed, "total", pr.Total) + return nil + }, + ); err != nil { + log.Fatal(err) + } + + dg, err := discordgo.New("Bot " + *discordToken) + if err != nil { + log.Fatal(err) + } + defer dg.Close() + + dg.AddHandler(func(s *discordgo.Session, m *discordgo.MessageCreate) { + if m.Author.ID == s.State.User.ID { + return + } + + if m.GuildID != *discordGuild { + return + } + + if m.ChannelID != *discordChannel { + return + } + + if m.Author.Bot { + return + } + + if m.Content == "!mimi" { + s.ChannelMessageSend(m.ChannelID, "mimi!") + return + } + + if m.Content == "!mimi clear" { + lock.Lock() + delete(stateMap, m.ChannelID) + lock.Unlock() + s.ChannelMessageSend(m.ChannelID, "mimi state cleared") + return + } + + var sb strings.Builder + var prompt strings.Builder + + if ns, ok := ParseNameslash(m.Content); ok { + if err := json.NewEncoder(&prompt).Encode(map[string]any{ + "message": ns.Message, + "user": ns.Name, + "is_admin": m.Author.Username == "xeiaso", + }); err != nil { + slog.Error("json encode error", "error", err) + } + } else { + if err := json.NewEncoder(&prompt).Encode(map[string]any{ + "message": m.Content, + "user": m.Author.Username, + "is_admin": m.Author.Username == "xeiaso", + }); err != nil { + slog.Error("json encode error", "error", err) + } + } + + lock.Lock() + defer lock.Unlock() + + st, ok := stateMap[m.ChannelID] + if !ok { + st = &State{ + Messages: []llm.Message{{ + Role: "user", + Content: prompt.String(), + }}, + } + + stateMap[m.ChannelID] = st + } + + fmt.Println(Prompt(st.Messages)) + + err = cli.Generate(ctx, + &ollama.GenerateRequest{ + Model: *ollamaModel, + Context: st.Context, + Prompt: prompt.String(), + Stream: p(true), + System: "Your name is Mimi. You will answer questions from users when asked. You are an expert in programming and philosophy. You are a catgirl. You are relaxed, terse, and casual. Twilight Sparkle is best pony.", + }, func(gr ollama.GenerateResponse) error { + fmt.Fprint(&sb, gr.Response) + + if gr.Done { + st.Context = gr.Context + st.Messages = append(st.Messages, llm.Message{ + Role: "assistant", + Content: gr.Response, + }) + } + return nil + }, + ) + + if err != nil { + slog.Error("generate error", "error", err, "channel", m.ChannelID) + return + } + + if _, err := s.ChannelMessageSend(m.ChannelID, sb.String()); err != nil { + slog.Error("message send error", "err", err, "message", sb.String()) + } + slog.Debug("context length", "len", len(st.Context)) + }) + + if err := dg.Open(); err != nil { + log.Fatal(err) + } + + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) + <-sc + cancel() +} + +var lock sync.Mutex +var stateMap = map[string]*State{} + +type State struct { + Context []int + Messages []llm.Message +} + +type Nameslash struct { + Name string `json:"name"` + Message string `json:"message"` +} + +func ParseNameslash(msg string) (Nameslash, bool) { + parts := strings.Split(msg, "\\") + if len(parts) != 2 { + return Nameslash{}, false + } + return Nameslash{parts[0], parts[1]}, true +} diff --git a/cmd/mimi/ollama/client.go b/cmd/mimi/ollama/client.go new file mode 100644 index 0000000..439d89e --- /dev/null +++ b/cmd/mimi/ollama/client.go @@ -0,0 +1,311 @@ +package ollama + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "strings" + + "within.website/x/web/useragent" +) + +type Client struct { + base *url.URL + http http.Client +} + +func checkError(resp *http.Response, body []byte) error { + if resp.StatusCode < http.StatusBadRequest { + return nil + } + + apiError := StatusError{StatusCode: resp.StatusCode} + + err := json.Unmarshal(body, &apiError) + if err != nil { + // Use the full body as the message if we fail to decode a response. + apiError.ErrorMessage = string(body) + } + + return apiError +} + +func ClientFromEnvironment() (*Client, error) { + defaultPort := "11434" + + scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://") + switch { + case !ok: + scheme, hostport = "http", os.Getenv("OLLAMA_HOST") + case scheme == "http": + defaultPort = "80" + case scheme == "https": + defaultPort = "443" + } + + // trim trailing slashes + hostport = strings.TrimRight(hostport, "/") + + host, port, err := net.SplitHostPort(hostport) + if err != nil { + host, port = "127.0.0.1", defaultPort + if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil { + host = ip.String() + } else if hostport != "" { + host = hostport + } + } + + client := Client{ + base: &url.URL{ + Scheme: scheme, + Host: net.JoinHostPort(host, port), + }, + } + + mockRequest, err := http.NewRequest(http.MethodHead, client.base.String(), nil) + if err != nil { + return nil, err + } + + proxyURL, err := http.ProxyFromEnvironment(mockRequest) + if err != nil { + return nil, err + } + + client.http = http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + } + + return &client, nil +} + +func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error { + var reqBody io.Reader + var data []byte + var err error + + switch reqData := reqData.(type) { + case io.Reader: + // reqData is already an io.Reader + reqBody = reqData + case nil: + // noop + default: + data, err = json.Marshal(reqData) + if err != nil { + return err + } + + reqBody = bytes.NewReader(data) + } + + requestURL := c.base.JoinPath(path) + request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody) + if err != nil { + return err + } + + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "application/json") + request.Header.Set("User-Agent", useragent.GenUserAgent("ollama", "https://xeiaso.net")) + + respObj, err := c.http.Do(request) + if err != nil { + return err + } + defer respObj.Body.Close() + + respBody, err := io.ReadAll(respObj.Body) + if err != nil { + return err + } + + if err := checkError(respObj, respBody); err != nil { + return err + } + + if len(respBody) > 0 && respData != nil { + if err := json.Unmarshal(respBody, respData); err != nil { + return err + } + } + return nil +} + +const maxBufferSize = 512 * 1024 + +func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error { + var buf *bytes.Buffer + if data != nil { + bts, err := json.Marshal(data) + if err != nil { + return err + } + + buf = bytes.NewBuffer(bts) + } + + requestURL := c.base.JoinPath(path) + request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf) + if err != nil { + return err + } + + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "application/x-ndjson") + request.Header.Set("User-Agent", useragent.GenUserAgent("ollama", "https://xeiaso.net")) + + response, err := c.http.Do(request) + if err != nil { + return err + } + defer response.Body.Close() + + scanner := bufio.NewScanner(response.Body) + // increase the buffer size to avoid running out of space + scanBuf := make([]byte, 0, maxBufferSize) + scanner.Buffer(scanBuf, maxBufferSize) + for scanner.Scan() { + var errorResponse struct { + Error string `json:"error,omitempty"` + } + + bts := scanner.Bytes() + if err := json.Unmarshal(bts, &errorResponse); err != nil { + return fmt.Errorf("unmarshal: %w", err) + } + + if errorResponse.Error != "" { + return fmt.Errorf(errorResponse.Error) + } + + if response.StatusCode >= http.StatusBadRequest { + return StatusError{ + StatusCode: response.StatusCode, + Status: response.Status, + ErrorMessage: errorResponse.Error, + } + } + + if err := fn(bts); err != nil { + return err + } + } + + return nil +} + +type GenerateResponseFunc func(GenerateResponse) error + +func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error { + return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error { + var resp GenerateResponse + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } + + return fn(resp) + }) +} + +type PullProgressFunc func(ProgressResponse) error + +func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error { + return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error { + var resp ProgressResponse + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } + + return fn(resp) + }) +} + +type PushProgressFunc func(ProgressResponse) error + +func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error { + return c.stream(ctx, http.MethodPost, "/api/push", req, func(bts []byte) error { + var resp ProgressResponse + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } + + return fn(resp) + }) +} + +type CreateProgressFunc func(ProgressResponse) error + +func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error { + return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error { + var resp ProgressResponse + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } + + return fn(resp) + }) +} + +func (c *Client) List(ctx context.Context) (*ListResponse, error) { + var lr ListResponse + if err := c.do(ctx, http.MethodGet, "/api/tags", nil, &lr); err != nil { + return nil, err + } + return &lr, nil +} + +func (c *Client) Copy(ctx context.Context, req *CopyRequest) error { + if err := c.do(ctx, http.MethodPost, "/api/copy", req, nil); err != nil { + return err + } + return nil +} + +func (c *Client) Delete(ctx context.Context, req *DeleteRequest) error { + if err := c.do(ctx, http.MethodDelete, "/api/delete", req, nil); err != nil { + return err + } + return nil +} + +func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, error) { + var resp ShowResponse + if err := c.do(ctx, http.MethodPost, "/api/show", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +func (c *Client) Heartbeat(ctx context.Context) error { + if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil { + return err + } + return nil +} + +func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error { + if err := c.do(ctx, http.MethodHead, fmt.Sprintf("/api/blobs/%s", digest), nil, nil); err != nil { + var statusError StatusError + if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound { + return err + } + + if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil); err != nil { + return err + } + } + + return nil +} diff --git a/cmd/mimi/ollama/client_test.go b/cmd/mimi/ollama/client_test.go new file mode 100644 index 0000000..9c06d3f --- /dev/null +++ b/cmd/mimi/ollama/client_test.go @@ -0,0 +1,43 @@ +package ollama + +import "testing" + +func TestClientFromEnvironment(t *testing.T) { + type testCase struct { + value string + expect string + err error + } + + testCases := map[string]*testCase{ + "empty": {value: "", expect: "http://127.0.0.1:11434"}, + "only address": {value: "1.2.3.4", expect: "http://1.2.3.4:11434"}, + "only port": {value: ":1234", expect: "http://:1234"}, + "address and port": {value: "1.2.3.4:1234", expect: "http://1.2.3.4:1234"}, + "scheme http and address": {value: "http://1.2.3.4", expect: "http://1.2.3.4:80"}, + "scheme https and address": {value: "https://1.2.3.4", expect: "https://1.2.3.4:443"}, + "scheme, address, and port": {value: "https://1.2.3.4:1234", expect: "https://1.2.3.4:1234"}, + "hostname": {value: "example.com", expect: "http://example.com:11434"}, + "hostname and port": {value: "example.com:1234", expect: "http://example.com:1234"}, + "scheme http and hostname": {value: "http://example.com", expect: "http://example.com:80"}, + "scheme https and hostname": {value: "https://example.com", expect: "https://example.com:443"}, + "scheme, hostname, and port": {value: "https://example.com:1234", expect: "https://example.com:1234"}, + "trailing slash": {value: "example.com/", expect: "http://example.com:11434"}, + "trailing slash port": {value: "example.com:1234/", expect: "http://example.com:1234"}, + } + + for k, v := range testCases { + t.Run(k, func(t *testing.T) { + t.Setenv("OLLAMA_HOST", v.value) + + client, err := ClientFromEnvironment() + if err != v.err { + t.Fatalf("expected %s, got %s", v.err, err) + } + + if client.base.String() != v.expect { + t.Fatalf("expected %s, got %s", v.expect, client.base.String()) + } + }) + } +} diff --git a/cmd/mimi/ollama/types.go b/cmd/mimi/ollama/types.go new file mode 100644 index 0000000..9c5b16b --- /dev/null +++ b/cmd/mimi/ollama/types.go @@ -0,0 +1,362 @@ +package ollama + +import ( + "encoding/json" + "fmt" + "math" + "os" + "reflect" + "strings" + "time" +) + +type StatusError struct { + StatusCode int + Status string + ErrorMessage string `json:"error"` +} + +func (e StatusError) Error() string { + switch { + case e.Status != "" && e.ErrorMessage != "": + return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage) + case e.Status != "": + return e.Status + case e.ErrorMessage != "": + return e.ErrorMessage + default: + // this should not happen + return "something went wrong, please see the ollama server logs for details" + } +} + +type GenerateRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + System string `json:"system"` + Template string `json:"template"` + Context []int `json:"context,omitempty"` + Stream *bool `json:"stream,omitempty"` + Raw bool `json:"raw,omitempty"` + Format string `json:"format"` + + Options map[string]interface{} `json:"options"` +} + +// Options specfied in GenerateRequest, if you add a new option here add it to the API docs also +type Options struct { + Runner + + // Predict options used at runtime + NumKeep int `json:"num_keep,omitempty"` + Seed int `json:"seed,omitempty"` + NumPredict int `json:"num_predict,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP float32 `json:"top_p,omitempty"` + TFSZ float32 `json:"tfs_z,omitempty"` + TypicalP float32 `json:"typical_p,omitempty"` + RepeatLastN int `json:"repeat_last_n,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + RepeatPenalty float32 `json:"repeat_penalty,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + Mirostat int `json:"mirostat,omitempty"` + MirostatTau float32 `json:"mirostat_tau,omitempty"` + MirostatEta float32 `json:"mirostat_eta,omitempty"` + PenalizeNewline bool `json:"penalize_newline,omitempty"` + Stop []string `json:"stop,omitempty"` +} + +// Runner options which must be set when the model is loaded into memory +type Runner struct { + UseNUMA bool `json:"numa,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` + NumBatch int `json:"num_batch,omitempty"` + NumGQA int `json:"num_gqa,omitempty"` + NumGPU int `json:"num_gpu,omitempty"` + MainGPU int `json:"main_gpu,omitempty"` + LowVRAM bool `json:"low_vram,omitempty"` + F16KV bool `json:"f16_kv,omitempty"` + LogitsAll bool `json:"logits_all,omitempty"` + VocabOnly bool `json:"vocab_only,omitempty"` + UseMMap bool `json:"use_mmap,omitempty"` + UseMLock bool `json:"use_mlock,omitempty"` + EmbeddingOnly bool `json:"embedding_only,omitempty"` + RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"` + RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"` + NumThread int `json:"num_thread,omitempty"` +} + +type EmbeddingRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + + Options map[string]interface{} `json:"options"` +} + +type EmbeddingResponse struct { + Embedding []float64 `json:"embedding"` +} + +type CreateRequest struct { + Name string `json:"name"` + Path string `json:"path"` + Modelfile string `json:"modelfile"` + Stream *bool `json:"stream,omitempty"` +} + +type DeleteRequest struct { + Name string `json:"name"` +} + +type ShowRequest struct { + Name string `json:"name"` +} + +type ShowResponse struct { + License string `json:"license,omitempty"` + Modelfile string `json:"modelfile,omitempty"` + Parameters string `json:"parameters,omitempty"` + Template string `json:"template,omitempty"` + System string `json:"system,omitempty"` +} + +type CopyRequest struct { + Source string `json:"source"` + Destination string `json:"destination"` +} + +type PullRequest struct { + Name string `json:"name"` + Insecure bool `json:"insecure,omitempty"` + Username string `json:"username"` + Password string `json:"password"` + Stream *bool `json:"stream,omitempty"` +} + +type ProgressResponse struct { + Status string `json:"status"` + Digest string `json:"digest,omitempty"` + Total int64 `json:"total,omitempty"` + Completed int64 `json:"completed,omitempty"` +} + +type PushRequest struct { + Name string `json:"name"` + Insecure bool `json:"insecure,omitempty"` + Username string `json:"username"` + Password string `json:"password"` + Stream *bool `json:"stream,omitempty"` +} + +type ListResponse struct { + Models []ModelResponse `json:"models"` +} + +type ModelResponse struct { + Name string `json:"name"` + ModifiedAt time.Time `json:"modified_at"` + Size int64 `json:"size"` + Digest string `json:"digest"` +} + +type TokenResponse struct { + Token string `json:"token"` +} + +type GenerateResponse struct { + Model string `json:"model"` + CreatedAt time.Time `json:"created_at"` + Response string `json:"response"` + + Done bool `json:"done"` + Context []int `json:"context,omitempty"` + + TotalDuration time.Duration `json:"total_duration,omitempty"` + LoadDuration time.Duration `json:"load_duration,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration time.Duration `json:"eval_duration,omitempty"` +} + +func (r *GenerateResponse) Summary() { + if r.TotalDuration > 0 { + fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration) + } + + if r.LoadDuration > 0 { + fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration) + } + + if r.PromptEvalCount > 0 { + fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount) + } + + if r.PromptEvalDuration > 0 { + fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration) + fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds()) + } + + if r.EvalCount > 0 { + fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount) + } + + if r.EvalDuration > 0 { + fmt.Fprintf(os.Stderr, "eval duration: %s\n", r.EvalDuration) + fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds()) + } +} + +var ErrInvalidOpts = fmt.Errorf("invalid options") + +func (opts *Options) FromMap(m map[string]interface{}) error { + valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct + typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct + + // build map of json struct tags to their types + jsonOpts := make(map[string]reflect.StructField) + for _, field := range reflect.VisibleFields(typeOpts) { + jsonTag := strings.Split(field.Tag.Get("json"), ",")[0] + if jsonTag != "" { + jsonOpts[jsonTag] = field + } + } + + invalidOpts := []string{} + for key, val := range m { + if opt, ok := jsonOpts[key]; ok { + field := valueOpts.FieldByName(opt.Name) + if field.IsValid() && field.CanSet() { + if val == nil { + continue + } + + switch field.Kind() { + case reflect.Int: + switch t := val.(type) { + case int64: + field.SetInt(t) + case float64: + // when JSON unmarshals numbers, it uses float64, not int + field.SetInt(int64(t)) + default: + return fmt.Errorf("option %q must be of type integer", key) + } + case reflect.Bool: + val, ok := val.(bool) + if !ok { + return fmt.Errorf("option %q must be of type boolean", key) + } + field.SetBool(val) + case reflect.Float32: + // JSON unmarshals to float64 + val, ok := val.(float64) + if !ok { + return fmt.Errorf("option %q must be of type float32", key) + } + field.SetFloat(val) + case reflect.String: + val, ok := val.(string) + if !ok { + return fmt.Errorf("option %q must be of type string", key) + } + field.SetString(val) + case reflect.Slice: + // JSON unmarshals to []interface{}, not []string + val, ok := val.([]interface{}) + if !ok { + return fmt.Errorf("option %q must be of type array", key) + } + // convert []interface{} to []string + slice := make([]string, len(val)) + for i, item := range val { + str, ok := item.(string) + if !ok { + return fmt.Errorf("option %q must be of an array of strings", key) + } + slice[i] = str + } + field.Set(reflect.ValueOf(slice)) + default: + return fmt.Errorf("unknown type loading config params: %v", field.Kind()) + } + } + } else { + invalidOpts = append(invalidOpts, key) + } + } + + if len(invalidOpts) > 0 { + return fmt.Errorf("%w: %v", ErrInvalidOpts, strings.Join(invalidOpts, ", ")) + } + return nil +} + +func DefaultOptions() Options { + return Options{ + // options set on request to runner + NumPredict: -1, + NumKeep: 0, + Temperature: 0.8, + TopK: 40, + TopP: 0.9, + TFSZ: 1.0, + TypicalP: 1.0, + RepeatLastN: 64, + RepeatPenalty: 1.1, + PresencePenalty: 0.0, + FrequencyPenalty: 0.0, + Mirostat: 0, + MirostatTau: 5.0, + MirostatEta: 0.1, + PenalizeNewline: true, + Seed: -1, + + Runner: Runner{ + // options set when the model is loaded + NumCtx: 2048, + RopeFrequencyBase: 10000.0, + RopeFrequencyScale: 1.0, + NumBatch: 512, + NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically + NumGQA: 1, + NumThread: 0, // let the runtime decide + LowVRAM: false, + F16KV: true, + UseMLock: false, + UseMMap: true, + UseNUMA: false, + EmbeddingOnly: true, + }, + } +} + +type Duration struct { + time.Duration +} + +func (d *Duration) UnmarshalJSON(b []byte) (err error) { + var v any + if err := json.Unmarshal(b, &v); err != nil { + return err + } + + d.Duration = 5 * time.Minute + + switch t := v.(type) { + case float64: + if t < 0 { + t = math.MaxFloat64 + } + + d.Duration = time.Duration(t) + case string: + d.Duration, err = time.ParseDuration(t) + if err != nil { + return err + } + } + + return nil +} diff --git a/cmd/mimi/var/.gitignore b/cmd/mimi/var/.gitignore new file mode 100644 index 0000000..c96a04f --- /dev/null +++ b/cmd/mimi/var/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore
\ No newline at end of file |
