diff options
| author | Xe Iaso <me@xeiaso.net> | 2024-03-06 20:40:21 -0500 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2024-03-06 20:40:21 -0500 |
| commit | eede36ca2465a062b22415400d3066dd35e4a752 (patch) | |
| tree | 084531532b2e581fb551af4b8f97f655bdee873a /web | |
| parent | a56536566235adccde5778f153d19263a46c4c18 (diff) | |
| download | x-eede36ca2465a062b22415400d3066dd35e4a752.tar.xz x-eede36ca2465a062b22415400d3066dd35e4a752.zip | |
web/ollama: add embedding support
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'web')
| -rw-r--r-- | web/ollama/ollama.go | 67 |
1 files changed, 53 insertions, 14 deletions
diff --git a/web/ollama/ollama.go b/web/ollama/ollama.go index cecf304..3f9870e 100644 --- a/web/ollama/ollama.go +++ b/web/ollama/ollama.go @@ -36,13 +36,12 @@ type Message struct { } type CompleteRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Format *string `json:"format,omitempty"` - Template *string `json:"template,omitempty"` - Stream bool `json:"stream"` - Options map[string]any `json:"options"` - KeepAlive time.Duration `json:"keep_alive"` + Model string `json:"model"` + Messages []Message `json:"messages"` + Format *string `json:"format,omitempty"` + Template *string `json:"template,omitempty"` + Stream bool `json:"stream"` + Options map[string]any `json:"options"` } type CompleteResponse struct { @@ -59,8 +58,6 @@ type CompleteResponse struct { } func (c *Client) Chat(ctx context.Context, inp *CompleteRequest) (*CompleteResponse, error) { - inp.KeepAlive = 24 * time.Hour - buf := &bytes.Buffer{} if err := json.NewEncoder(buf).Encode(inp); err != nil { return nil, fmt.Errorf("ollama: error encoding request: %w", err) @@ -105,11 +102,10 @@ func p[T any](v T) *T { // Hallucinate prompts the model to hallucinate a "valid" JSON response to the given input. func Hallucinate[T valid.Interface](ctx context.Context, c *Client, opts HallucinateOpts) (*T, error) { inp := &CompleteRequest{ - Model: opts.Model, - Messages: opts.Messages, - KeepAlive: 24 * time.Hour, - Format: p("json"), - Stream: true, + Model: opts.Model, + Messages: opts.Messages, + Format: p("json"), + Stream: true, } tries := 0 for tries <= 5 { @@ -194,3 +190,46 @@ func Hallucinate[T valid.Interface](ctx context.Context, c *Client, opts Halluci return nil, fmt.Errorf("ollama: failed to hallucinate a valid response after 5 tries") } + +type EmbedRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + + Options map[string]any `json:"options"` +} + +type EmbedResponse struct { + Embedding []float64 `json:"embedding"` +} + +func (c *Client) Embeddings(ctx context.Context, er *EmbedRequest) (*EmbedResponse, error) { + buf := &bytes.Buffer{} + if err := json.NewEncoder(buf).Encode(er); err != nil { + return nil, fmt.Errorf("ollama: error encoding request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/embeddings", buf) + if err != nil { + return nil, fmt.Errorf("ollama: error creating request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("ollama: error making request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, web.NewError(http.StatusOK, resp) + } + + var result EmbedResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("ollama: error decoding response: %w", err) + } + + return &result, nil +} |
