author     Xe Iaso <me@xeiaso.net>  2023-12-14 22:54:32 -0500
committer  Xe Iaso <me@xeiaso.net>  2024-01-30 18:28:29 -0500
commit     15eb817e4ca36a6240b6beacbeff455fc7e78e3c
tree       83f963487e57658155026ddf9e522a19bda970ca
parent     aa123ba1985912ec54211284c4ed4a569e1ae864
llm: add multillm package
Signed-off-by: Xe Iaso <me@xeiaso.net>
-rw-r--r--  llm/multillm/mistral.go        42
-rw-r--r--  llm/multillm/multillm.go       36
-rw-r--r--  llm/multillm/ollama.go         44
-rw-r--r--  llm/multillm/openai.go         45
-rw-r--r--  web/mistral/mistral.go         14
-rw-r--r--  web/ollama/ollama.go            2
-rw-r--r--  web/openai/chatgpt/chatgpt.go   8
7 files changed, 187 insertions, 4 deletions
diff --git a/llm/multillm/mistral.go b/llm/multillm/mistral.go
new file mode 100644
index 0000000..a950cf5
--- /dev/null
+++ b/llm/multillm/mistral.go
@@ -0,0 +1,42 @@
+package multillm
+
+import (
+ "context"
+
+ "within.website/x/llm"
+ "within.website/x/web/mistral"
+)
+
+type Mistral struct {
+ *mistral.Client
+}
+
+func (m *Mistral) Chat(ctx context.Context, req *Request) (*Response, error) {
+ cr := &mistral.CompleteRequest{
+ Model: req.Model,
+ Messages: make([]llm.Message, len(req.Messages)),
+ Temperature: req.Temperature,
+ RandomSeed: req.RandomSeed,
+ }
+
+ for i, m := range req.Messages {
+ cr.Messages[i] = llm.Message{
+ Role: m.Role,
+ Content: m.Content,
+ }
+ }
+
+ resp, err := m.Client.Chat(ctx, cr)
+ if err != nil {
+ return nil, err
+ }
+
+ return &Response{
+ Response: llm.Message{
+ Role: resp.Choices[0].Message[0].Role,
+ Content: resp.Choices[0].Message[0].Content,
+ },
+ PromptTokens: resp.Usage.PromptTokens,
+ CompletionTokens: resp.Usage.CompletionTokens,
+ }, nil
+}
diff --git a/llm/multillm/multillm.go b/llm/multillm/multillm.go
new file mode 100644
index 0000000..9a56db1
--- /dev/null
+++ b/llm/multillm/multillm.go
@@ -0,0 +1,36 @@
+// Package multillm provides a common interface for making requests to
+// multiple large language model providers with shared input and output types.
+package multillm
+
+import (
+ "context"
+
+ "within.website/x/llm"
+)
+
+type Request struct {
+ Model string `json:"model"`
+ Messages []llm.Message `json:"messages"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ RandomSeed *int `json:"random_seed,omitempty"`
+}
+
+type Response struct {
+ Response llm.Message `json:"response"`
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+}
+
+type Chatter interface {
+ Chat(ctx context.Context, req *Request) (*Response, error)
+}
+
+type MultiChatModel struct {
+ Provider string `json:"provider"`
+ Models []string `json:"models"`
+}
+
+type MultiChatRequest struct {
+ Models []MultiChatModel `json:"models"`
+ Messages []llm.Message `json:"messages"`
+}
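
The point of the Chatter interface is that calling code never has to name a provider. As a minimal sketch, assuming an already-constructed client (the ask helper below is illustrative, not part of this commit):

// Illustrative only: any of Mistral, Ollama, or OpenAI can back this.
package main

import (
	"context"
	"fmt"
	"log"

	"within.website/x/llm"
	"within.website/x/llm/multillm"
)

// ask sends a single user prompt through whichever backend c wraps.
func ask(ctx context.Context, c multillm.Chatter, model, prompt string) (string, error) {
	resp, err := c.Chat(ctx, &multillm.Request{
		Model:    model,
		Messages: []llm.Message{{Role: "user", Content: prompt}},
	})
	if err != nil {
		return "", err
	}
	return resp.Response.Content, nil
}

func main() {
	// Constructing the underlying client is elided here; use whatever
	// constructor web/ollama (or web/mistral, etc.) provides.
	var c multillm.Chatter = &multillm.Ollama{}
	answer, err := ask(context.Background(), c, "llama2", "Why is the sky blue?")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(answer)
}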
diff --git a/llm/multillm/ollama.go b/llm/multillm/ollama.go
new file mode 100644
index 0000000..bc90033
--- /dev/null
+++ b/llm/multillm/ollama.go
@@ -0,0 +1,44 @@
+package multillm
+
+import (
+ "context"
+
+ "within.website/x/llm"
+ "within.website/x/web/ollama"
+)
+
+type Ollama struct {
+ *ollama.Client
+}
+
+func (o *Ollama) Chat(ctx context.Context, req *Request) (*Response, error) {
+ cr := &ollama.CompleteRequest{
+ Model: req.Model,
+ Messages: make([]ollama.Message, len(req.Messages)),
+ Options: map[string]any{
+ "temperature": req.Temperature,
+ "seed": req.RandomSeed,
+ },
+ }
+
+ for i, m := range req.Messages {
+ cr.Messages[i] = ollama.Message{
+ Role: m.Role,
+ Content: m.Content,
+ }
+ }
+
+ resp, err := o.Client.Chat(ctx, cr)
+ if err != nil {
+ return nil, err
+ }
+
+ return &Response{
+ Response: llm.Message{
+ Role: resp.Message.Role,
+ Content: resp.Message.Content,
+ },
+ PromptTokens: int(resp.PromptEvalCount),
+ CompletionTokens: int(resp.EvalCount),
+ }, nil
+}
diff --git a/llm/multillm/openai.go b/llm/multillm/openai.go
new file mode 100644
index 0000000..a28f5d0
--- /dev/null
+++ b/llm/multillm/openai.go
@@ -0,0 +1,45 @@
+package multillm
+
+import (
+ "context"
+ "fmt"
+
+ "within.website/x/llm"
+ "within.website/x/web/openai/chatgpt"
+)
+
+type OpenAI struct {
+ *chatgpt.Client
+}
+
+func convertToChatGPTMessage(m llm.Message) chatgpt.Message {
+ return chatgpt.Message{
+ Role: m.Role,
+ Content: m.Content,
+ }
+}
+
+func (oaic *OpenAI) Chat(ctx context.Context, req *Request) (*Response, error) {
+ chatReq := chatgpt.Request{
+ Model: req.Model,
+ Temperature: req.Temperature,
+ Seed: req.RandomSeed,
+ Messages: make([]chatgpt.Message, len(req.Messages)),
+ }
+
+ for i, m := range req.Messages {
+ chatReq.Messages[i] = convertToChatGPTMessage(m)
+ }
+
+ chatResp, err := oaic.Client.Complete(ctx, chatReq)
+ if err != nil {
+ return nil, fmt.Errorf("multillm: error chatting: %w", err)
+ }
+
+ return &Response{
+ Response: llm.Message{
+ Role: chatResp.Choices[0].Message.Role,
+ Content: chatResp.Choices[0].Message.Content,
+ },
+ }, nil
+}
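
MultiChatModel and MultiChatRequest are defined in multillm.go above but not yet consumed anywhere in this commit. A speculative sketch of the dispatcher they suggest; fanOut does not exist in the package, and the provider registry is an assumption:

package multillm

import (
	"context"
	"fmt"
)

// fanOut is speculative: it routes each MultiChatModel entry to a
// registered Chatter and runs the shared message history against
// every requested model, collecting the responses in order.
func fanOut(ctx context.Context, providers map[string]Chatter, mcr *MultiChatRequest) ([]*Response, error) {
	var out []*Response
	for _, mcm := range mcr.Models {
		c, ok := providers[mcm.Provider]
		if !ok {
			return nil, fmt.Errorf("multillm: unknown provider %q", mcm.Provider)
		}
		for _, model := range mcm.Models {
			resp, err := c.Chat(ctx, &Request{Model: model, Messages: mcr.Messages})
			if err != nil {
				return nil, err
			}
			out = append(out, resp)
		}
	}
	return out, nil
}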
diff --git a/web/mistral/mistral.go b/web/mistral/mistral.go
index 3f2d9e7..75cf9d3 100644
--- a/web/mistral/mistral.go
+++ b/web/mistral/mistral.go
@@ -4,13 +4,27 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"expvar"
 	"fmt"
 	"net/http"
 
+	"tailscale.com/metrics"
 	"within.website/x/llm"
 	"within.website/x/web"
 )
 
+var (
+	promptTokens     = metrics.LabelMap{Label: "model"}
+	completionTokens = metrics.LabelMap{Label: "model"}
+	totalTokens      = metrics.LabelMap{Label: "model"}
+)
+
+func init() {
+	expvar.Publish("gauge_x_web_mistral_prompt_tokens", &promptTokens)
+	expvar.Publish("gauge_x_web_mistral_completion_tokens", &completionTokens)
+	expvar.Publish("gauge_x_web_mistral_total_tokens", &totalTokens)
+}
+
 type Client struct {
 	*http.Client
 	apiKey string
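
This hunk only shows the gauges being declared and published; the increments happen elsewhere in the file, outside the diff context. The bookkeeping presumably looks something like the sketch below, where recordUsage is hypothetical and Usage stands in for whatever token-count struct the package decodes from the API response (the field names mirror the ones read in llm/multillm/mistral.go above):

// Sketch only: bump the per-model gauges after a successful completion.
func recordUsage(model string, u Usage) {
	promptTokens.Add(model, int64(u.PromptTokens))
	completionTokens.Add(model, int64(u.CompletionTokens))
	totalTokens.Add(model, int64(u.TotalTokens))
}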
diff --git a/web/ollama/ollama.go b/web/ollama/ollama.go
index 0f479cc..eca0439 100644
--- a/web/ollama/ollama.go
+++ b/web/ollama/ollama.go
@@ -31,7 +31,7 @@ type Message struct {
 
 type CompleteRequest struct {
 	Model    string    `json:"model"`
-	Messages Message   `json:"messages"`
+	Messages []Message `json:"messages"`
 	Format   *string   `json:"format,omitempty"`
 	Template *string   `json:"template,omitempty"`
 	Stream   bool      `json:"stream,omitempty"`
diff --git a/web/openai/chatgpt/chatgpt.go b/web/openai/chatgpt/chatgpt.go
index 9f37524..ea375df 100644
--- a/web/openai/chatgpt/chatgpt.go
+++ b/web/openai/chatgpt/chatgpt.go
@@ -13,9 +13,11 @@ import (
 )
 
 type Request struct {
-	Model     string     `json:"model"`
-	Messages  []Message  `json:"messages"`
-	Functions []Function `json:"functions,omitempty"`
+	Model       string     `json:"model"`
+	Messages    []Message  `json:"messages"`
+	Functions   []Function `json:"functions,omitempty"`
+	Seed        *int       `json:"seed,omitempty"`
+	Temperature *float64   `json:"temperature,omitempty"`
 }
 
 type Function struct {
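
Seed and Temperature are pointers so that omitempty can tell "unset" apart from an explicit zero: a nil pointer drops the key from the request body entirely, while a pointer to 0 still serializes it. A small sketch, where the ptr helper is hypothetical:

package main

import (
	"encoding/json"
	"fmt"

	"within.website/x/web/openai/chatgpt"
)

// ptr is a hypothetical helper, not part of this commit.
func ptr[T any](v T) *T { return &v }

func main() {
	req := chatgpt.Request{
		Model:       "gpt-4",
		Temperature: ptr(0.0), // explicit zero survives: "temperature":0
		// Seed stays nil, so "seed" is omitted from the body entirely.
	}
	b, _ := json.Marshal(req)
	fmt.Println(string(b)) // {"model":"gpt-4","messages":null,"temperature":0}
}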