aboutsummaryrefslogtreecommitdiff
path: root/web
diff options
context:
space:
mode:
authorXe Iaso <me@xeiaso.net>2024-03-06 20:40:21 -0500
committerXe Iaso <me@xeiaso.net>2024-03-06 20:40:21 -0500
commiteede36ca2465a062b22415400d3066dd35e4a752 (patch)
tree084531532b2e581fb551af4b8f97f655bdee873a /web
parenta56536566235adccde5778f153d19263a46c4c18 (diff)
downloadx-eede36ca2465a062b22415400d3066dd35e4a752.tar.xz
x-eede36ca2465a062b22415400d3066dd35e4a752.zip
web/ollama: add embedding support
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'web')
-rw-r--r--web/ollama/ollama.go67
1 files changed, 53 insertions, 14 deletions
diff --git a/web/ollama/ollama.go b/web/ollama/ollama.go
index cecf304..3f9870e 100644
--- a/web/ollama/ollama.go
+++ b/web/ollama/ollama.go
@@ -36,13 +36,12 @@ type Message struct {
}
type CompleteRequest struct {
- Model string `json:"model"`
- Messages []Message `json:"messages"`
- Format *string `json:"format,omitempty"`
- Template *string `json:"template,omitempty"`
- Stream bool `json:"stream"`
- Options map[string]any `json:"options"`
- KeepAlive time.Duration `json:"keep_alive"`
+ Model string `json:"model"`
+ Messages []Message `json:"messages"`
+ Format *string `json:"format,omitempty"`
+ Template *string `json:"template,omitempty"`
+ Stream bool `json:"stream"`
+ Options map[string]any `json:"options"`
}
type CompleteResponse struct {
@@ -59,8 +58,6 @@ type CompleteResponse struct {
}
func (c *Client) Chat(ctx context.Context, inp *CompleteRequest) (*CompleteResponse, error) {
- inp.KeepAlive = 24 * time.Hour
-
buf := &bytes.Buffer{}
if err := json.NewEncoder(buf).Encode(inp); err != nil {
return nil, fmt.Errorf("ollama: error encoding request: %w", err)
@@ -105,11 +102,10 @@ func p[T any](v T) *T {
// Hallucinate prompts the model to hallucinate a "valid" JSON response to the given input.
func Hallucinate[T valid.Interface](ctx context.Context, c *Client, opts HallucinateOpts) (*T, error) {
inp := &CompleteRequest{
- Model: opts.Model,
- Messages: opts.Messages,
- KeepAlive: 24 * time.Hour,
- Format: p("json"),
- Stream: true,
+ Model: opts.Model,
+ Messages: opts.Messages,
+ Format: p("json"),
+ Stream: true,
}
tries := 0
for tries <= 5 {
@@ -194,3 +190,46 @@ func Hallucinate[T valid.Interface](ctx context.Context, c *Client, opts Halluci
return nil, fmt.Errorf("ollama: failed to hallucinate a valid response after 5 tries")
}
+
+type EmbedRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt"`
+
+ Options map[string]any `json:"options"`
+}
+
+type EmbedResponse struct {
+ Embedding []float64 `json:"embedding"`
+}
+
+func (c *Client) Embeddings(ctx context.Context, er *EmbedRequest) (*EmbedResponse, error) {
+ buf := &bytes.Buffer{}
+ if err := json.NewEncoder(buf).Encode(er); err != nil {
+ return nil, fmt.Errorf("ollama: error encoding request: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/embeddings", buf)
+ if err != nil {
+ return nil, fmt.Errorf("ollama: error creating request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("ollama: error making request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, web.NewError(http.StatusOK, resp)
+ }
+
+ var result EmbedResponse
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return nil, fmt.Errorf("ollama: error decoding response: %w", err)
+ }
+
+ return &result, nil
+}