| field | value | date |
|---|---|---|
| author | Xe Iaso <me@xeiaso.net> | 2023-12-14 22:54:32 -0500 |
| committer | Xe Iaso <me@xeiaso.net> | 2024-01-30 18:28:29 -0500 |
| commit | 15eb817e4ca36a6240b6beacbeff455fc7e78e3c (patch) | |
| tree | 83f963487e57658155026ddf9e522a19bda970ca | |
| parent | aa123ba1985912ec54211284c4ed4a569e1ae864 (diff) | |
| download | x-15eb817e4ca36a6240b6beacbeff455fc7e78e3c.tar.xz, x-15eb817e4ca36a6240b6beacbeff455fc7e78e3c.zip | |
llm: add multillm package
Signed-off-by: Xe Iaso <me@xeiaso.net>
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | llm/multillm/mistral.go | 42 |
| -rw-r--r-- | llm/multillm/multillm.go | 36 |
| -rw-r--r-- | llm/multillm/ollama.go | 44 |
| -rw-r--r-- | llm/multillm/openai.go | 45 |
| -rw-r--r-- | web/mistral/mistral.go | 14 |
| -rw-r--r-- | web/ollama/ollama.go | 2 |
| -rw-r--r-- | web/openai/chatgpt/chatgpt.go | 8 |

7 files changed, 187 insertions(+), 4 deletions(-)
```diff
diff --git a/llm/multillm/mistral.go b/llm/multillm/mistral.go
new file mode 100644
index 0000000..a950cf5
--- /dev/null
+++ b/llm/multillm/mistral.go
@@ -0,0 +1,42 @@
+package multillm
+
+import (
+	"context"
+
+	"within.website/x/llm"
+	"within.website/x/web/mistral"
+)
+
+type Mistral struct {
+	*mistral.Client
+}
+
+func (m *Mistral) Chat(ctx context.Context, req *Request) (*Response, error) {
+	cr := &mistral.CompleteRequest{
+		Model:       req.Model,
+		Messages:    make([]llm.Message, len(req.Messages)),
+		Temperature: req.Temperature,
+		RandomSeed:  req.RandomSeed,
+	}
+
+	for i, m := range req.Messages {
+		cr.Messages[i] = llm.Message{
+			Role:    m.Role,
+			Content: m.Content,
+		}
+	}
+
+	resp, err := m.Client.Chat(ctx, cr)
+	if err != nil {
+		return nil, err
+	}
+
+	return &Response{
+		Response: llm.Message{
+			Role:    resp.Choices[0].Message[0].Role,
+			Content: resp.Choices[0].Message[0].Content,
+		},
+		PromptTokens:     resp.Usage.PromptTokens,
+		CompletionTokens: resp.Usage.CompletionTokens,
+	}, nil
+}
diff --git a/llm/multillm/multillm.go b/llm/multillm/multillm.go
new file mode 100644
index 0000000..9a56db1
--- /dev/null
+++ b/llm/multillm/multillm.go
@@ -0,0 +1,36 @@
+// Package multillm is a common interface for doing multiple large
+// language model requests with common inputs and types.
+package multillm
+
+import (
+	"context"
+
+	"within.website/x/llm"
+)
+
+type Request struct {
+	Model       string        `json:"model"`
+	Messages    []llm.Message `json:"messages"`
+	Temperature *float64      `json:"temperature,omitempty"`
+	RandomSeed  *int          `json:"random_seed,omitempty"`
+}
+
+type Response struct {
+	Response         llm.Message `json:"response"`
+	PromptTokens     int         `json:"prompt_tokens"`
+	CompletionTokens int         `json:"completion_tokens"`
+}
+
+type Chatter interface {
+	Chat(ctx context.Context, req *Request) (*Response, error)
+}
+
+type MultiChatModel struct {
+	Provider string   `json:"provider"`
+	Models   []string `json:"models"`
+}
+
+type MultiChatRequest struct {
+	Models   []MultiChatModel `json:"models"`
+	Messages []llm.Message    `json:"messages"`
+}
diff --git a/llm/multillm/ollama.go b/llm/multillm/ollama.go
new file mode 100644
index 0000000..bc90033
--- /dev/null
+++ b/llm/multillm/ollama.go
@@ -0,0 +1,44 @@
+package multillm
+
+import (
+	"context"
+
+	"within.website/x/llm"
+	"within.website/x/web/ollama"
+)
+
+type Ollama struct {
+	*ollama.Client
+}
+
+func (o *Ollama) Chat(ctx context.Context, req *Request) (*Response, error) {
+	cr := &ollama.CompleteRequest{
+		Model:    req.Model,
+		Messages: make([]ollama.Message, len(req.Messages)),
+		Options: map[string]any{
+			"temperature": req.Temperature,
+			"seed":        req.RandomSeed,
+		},
+	}
+
+	for i, m := range req.Messages {
+		cr.Messages[i] = ollama.Message{
+			Role:    m.Role,
+			Content: m.Content,
+		}
+	}
+
+	resp, err := o.Client.Chat(ctx, cr)
+	if err != nil {
+		return nil, err
+	}
+
+	return &Response{
+		Response: llm.Message{
+			Role:    resp.Message.Role,
+			Content: resp.Message.Content,
+		},
+		PromptTokens:     int(resp.PromptEvalCount),
+		CompletionTokens: int(resp.EvalCount),
+	}, nil
+}
diff --git a/llm/multillm/openai.go b/llm/multillm/openai.go
new file mode 100644
index 0000000..a28f5d0
--- /dev/null
+++ b/llm/multillm/openai.go
@@ -0,0 +1,45 @@
+package multillm
+
+import (
+	"context"
+	"fmt"
+
+	"within.website/x/llm"
+	"within.website/x/web/openai/chatgpt"
+)
+
+type OpenAI struct {
+	*chatgpt.Client
+}
+
+func convertToChatGPTMessage(m llm.Message) chatgpt.Message {
+	return chatgpt.Message{
+		Role:    m.Role,
+		Content: m.Content,
+	}
+}
+
+func (oaic *OpenAI) Chat(ctx context.Context, req *Request) (*Response, error) {
+	chatReq := chatgpt.Request{
+		Model:       req.Model,
+		Temperature: req.Temperature,
+		Seed:        req.RandomSeed,
+		Messages:    make([]chatgpt.Message, len(req.Messages)),
+	}
+
+	for i, m := range req.Messages {
+		chatReq.Messages[i] = convertToChatGPTMessage(m)
+	}
+
+	chatResp, err := oaic.Client.Complete(ctx, chatReq)
+	if err != nil {
+		return nil, fmt.Errorf("multillm: error chatting: %w", err)
+	}
+
+	return &Response{
+		Response: llm.Message{
+			Role:    chatResp.Choices[0].Message.Role,
+			Content: chatResp.Choices[0].Message.Content,
+		},
+	}, nil
+}
diff --git a/web/mistral/mistral.go b/web/mistral/mistral.go
index 3f2d9e7..75cf9d3 100644
--- a/web/mistral/mistral.go
+++ b/web/mistral/mistral.go
@@ -4,13 +4,27 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"expvar"
 	"fmt"
 	"net/http"
 
+	"tailscale.com/metrics"
 	"within.website/x/llm"
 	"within.website/x/web"
 )
 
+var (
+	promptTokens     = metrics.LabelMap{Label: "model"}
+	completionTokens = metrics.LabelMap{Label: "model"}
+	totalTokens      = metrics.LabelMap{Label: "model"}
+)
+
+func init() {
+	expvar.Publish("gauge_x_web_mistral_prompt_tokens", &promptTokens)
+	expvar.Publish("gauge_x_web_mistral_completion_tokens", &completionTokens)
+	expvar.Publish("gauge_x_web_mistral_total_tokens", &totalTokens)
+}
+
 type Client struct {
 	*http.Client
 	apiKey string
diff --git a/web/ollama/ollama.go b/web/ollama/ollama.go
index 0f479cc..eca0439 100644
--- a/web/ollama/ollama.go
+++ b/web/ollama/ollama.go
@@ -31,7 +31,7 @@ type Message struct {
 
 type CompleteRequest struct {
 	Model    string    `json:"model"`
-	Messages Message   `json:"messages"`
+	Messages []Message `json:"messages"`
 	Format   *string   `json:"format,omitempty"`
 	Template *string   `json:"template,omitempty"`
 	Stream   bool      `json:"stream,omitempty"`
diff --git a/web/openai/chatgpt/chatgpt.go b/web/openai/chatgpt/chatgpt.go
index 9f37524..ea375df 100644
--- a/web/openai/chatgpt/chatgpt.go
+++ b/web/openai/chatgpt/chatgpt.go
@@ -13,9 +13,11 @@ import (
 )
 
 type Request struct {
-	Model     string     `json:"model"`
-	Messages  []Message  `json:"messages"`
-	Functions []Function `json:"functions,omitempty"`
+	Model       string     `json:"model"`
+	Messages    []Message  `json:"messages"`
+	Functions   []Function `json:"functions,omitempty"`
+	Seed        *int       `json:"seed,omitempty"`
+	Temperature *float64   `json:"temperature,omitempty"`
 }
 
 type Function struct {
```
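
For context, here is a minimal sketch of how the new `Chatter` interface might be used. The `multillm.Request`/`Response` types and the `llm.Message` fields come straight from the diff above; the backend wiring is left abstract because the provider clients' constructors are not part of this commit, and the `chatOnce` helper, model name, and sampling defaults are hypothetical.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"within.website/x/llm"
	"within.website/x/llm/multillm"
)

// chatOnce sends a single user message through any multillm backend
// and prints the reply plus token usage. It depends only on the
// Chatter interface, so the same code path works whether the backend
// is Mistral, Ollama, or OpenAI.
func chatOnce(ctx context.Context, c multillm.Chatter, model, prompt string) error {
	temperature := 0.7 // hypothetical sampling values for illustration
	seed := 42

	resp, err := c.Chat(ctx, &multillm.Request{
		Model:       model,
		Messages:    []llm.Message{{Role: "user", Content: prompt}},
		Temperature: &temperature,
		RandomSeed:  &seed,
	})
	if err != nil {
		return err
	}

	fmt.Printf("%s: %s (%d prompt / %d completion tokens)\n",
		resp.Response.Role, resp.Response.Content,
		resp.PromptTokens, resp.CompletionTokens)
	return nil
}

func main() {
	// Wiring up a concrete backend needs a provider client, e.g.
	// &multillm.Ollama{Client: ...}; the client constructors are not
	// shown in this diff, so this placeholder stays nil.
	var backend multillm.Chatter
	if backend == nil {
		log.Fatal("configure a multillm backend first")
	}
	if err := chatOnce(context.Background(), backend, "mistral-tiny", "Hello!"); err != nil {
		log.Fatal(err)
	}
}
```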

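The commit also defines `MultiChatModel` and `MultiChatRequest` but ships no dispatcher for them yet. Below is a hedged sketch of what such a fan-out could look like, assuming a caller-supplied map from provider name to `Chatter`; the `fanOut` helper, its package, and the map layout are hypothetical and not part of this commit.

```go
// Package multillmx is a hypothetical helper package, not part of the commit.
package multillmx

import (
	"context"
	"fmt"

	"within.website/x/llm/multillm"
)

// fanOut sends the same conversation to every (provider, model) pair
// named in a MultiChatRequest and collects the responses in order.
// The backends map is assumed to be built by the caller, e.g.
// {"ollama": &multillm.Ollama{...}, "openai": &multillm.OpenAI{...}}.
func fanOut(ctx context.Context, backends map[string]multillm.Chatter, mcr *multillm.MultiChatRequest) ([]*multillm.Response, error) {
	var out []*multillm.Response
	for _, mc := range mcr.Models {
		backend, ok := backends[mc.Provider]
		if !ok {
			return nil, fmt.Errorf("multillm: unknown provider %q", mc.Provider)
		}
		for _, model := range mc.Models {
			resp, err := backend.Chat(ctx, &multillm.Request{
				Model:    model,
				Messages: mcr.Messages,
			})
			if err != nil {
				return nil, fmt.Errorf("multillm: %s/%s: %w", mc.Provider, model, err)
			}
			out = append(out, resp)
		}
	}
	return out, nil
}
```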