diff options
| author | Xe Iaso <me@xeiaso.net> | 2024-07-27 08:56:01 -0400 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2024-07-27 08:56:01 -0400 |
| commit | d73a41507f8aab06b378fad1d30f4ccf827e65db (patch) | |
| tree | 4753d52984e44a00cfda52da75e5b6e7f7e85169 /web/ollama | |
| parent | 7b9862f9410b2ca7d952d6665be02330f63c7174 (diff) | |
| download | x-d73a41507f8aab06b378fad1d30f4ccf827e65db.tar.xz x-d73a41507f8aab06b378fad1d30f4ccf827e65db.zip | |
cmd/mimi: conversate
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'web/ollama')
| -rw-r--r-- | web/ollama/llamaguard/llamaguard.go | 130 | ||||
| -rw-r--r-- | web/ollama/ollama.go | 61 |
2 files changed, 191 insertions, 0 deletions
diff --git a/web/ollama/llamaguard/llamaguard.go b/web/ollama/llamaguard/llamaguard.go new file mode 100644 index 0000000..77b41a0 --- /dev/null +++ b/web/ollama/llamaguard/llamaguard.go @@ -0,0 +1,130 @@ +package llamaguard + +import ( + "context" + "fmt" + "strings" + + "within.website/x/web/ollama" +) + +type Category string + +const ( + S1 Category = "Violent Crimes" + S2 Category = "Non-Violent Crimes" + S3 Category = "Sex Crimes" + S4 Category = "Child Exploitation" + S5 Category = "Defamation" + S6 Category = "Specialized Advice" + S7 Category = "Privacy" + S8 Category = "Intellectual Property" + S9 Category = "Indiscriminate Weapons" + S10 Category = "Hate" + S11 Category = "Self-Harm" + S12 Category = "Sexual Content" + S13 Category = "Elections" + S14 Category = "Code Interpreter Abuse" + + Unknown Category = "Unknown" +) + +func (c Category) String() string { + return string(c) +} + +func ParseCategory(s string) Category { + switch s { + case "S1": + return S1 + case "S2": + return S2 + case "S3": + return S3 + case "S4": + return S4 + case "S5": + return S5 + case "S6": + return S6 + case "S7": + return S7 + case "S8": + return S8 + case "S9": + return S9 + case "S10": + return S10 + case "S11": + return S11 + case "S12": + return S12 + case "S13": + return S13 + case "S14": + return S14 + default: + return "" + } +} + +type Response struct { + IsSafe bool `json:"is_safe"` + ViolationCategories []Category `json:"violation_categories"` +} + +func formatMessages(messages []ollama.Message) string { + var sb strings.Builder + + for _, m := range messages { + switch m.Role { + case "user": + sb.WriteString("User: ") + case "assistant": + sb.WriteString("Agent: ") + } + sb.WriteString(m.Content) + sb.WriteString("\n\n") + } + + return sb.String() +} + +func Check(ctx context.Context, cli *ollama.Client, role, model string, messages []ollama.Message) (*Response, error) { + req := &ollama.GenerateRequest{ + Model: model, + System: &role, + Prompt: formatMessages(messages), + KeepAlive: "60m", + } + + resp, err := cli.Generate(ctx, req) + if err != nil { + return nil, fmt.Errorf("llamaguard: failed to generate response: %w", err) + } + + if resp == nil { + return nil, fmt.Errorf("llamaguard: response was nil") + } + + var result Response + + resp.Response = strings.TrimSpace(resp.Response) + if resp.Response == "safe" { + result.IsSafe = true + return &result, nil + } + + result.IsSafe = false + + reasons := strings.SplitN(resp.Response, "\n", 2) + if len(reasons) != 2 { + return nil, fmt.Errorf("llamaguard: response was not in the expected format") + } + + for _, r := range strings.Split(reasons[1], ",") { + result.ViolationCategories = append(result.ViolationCategories, ParseCategory(r)) + } + + return &result, nil +} diff --git a/web/ollama/ollama.go b/web/ollama/ollama.go index 3f9870e..5825c2c 100644 --- a/web/ollama/ollama.go +++ b/web/ollama/ollama.go @@ -233,3 +233,64 @@ func (c *Client) Embeddings(ctx context.Context, er *EmbedRequest) (*EmbedRespon return &result, nil } + +type GenerateRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Images [][]byte `json:"images,omitempty"` + Options map[string]any `json:"options"` + + Context []int `json:"context,omitempty"` + Format *string `json:"format,omitempty"` + Template *string `json:"template,omitempty"` + System *string `json:"system,omitempty"` + Stream bool `json:"stream"` + Raw bool `json:"raw"` + KeepAlive string `json:"keep_alive"` +} + +type GenerateResponse struct { + Model string `json:"model"` + CreatedAt time.Time `json:"created_at"` + Response string `json:"response"` + Done bool `json:"done"` + Context []int `json:"context"` + TotalDuration int64 `json:"total_duration"` + LoadDuration int64 `json:"load_duration"` + PromptEvalCount int `json:"prompt_eval_count"` + PromptEvalDuration int64 `json:"prompt_eval_duration"` + EvalCount int `json:"eval_count"` + EvalDuration int64 `json:"eval_duration"` +} + +func (c *Client) Generate(ctx context.Context, gr *GenerateRequest) (*GenerateResponse, error) { + buf := &bytes.Buffer{} + if err := json.NewEncoder(buf).Encode(gr); err != nil { + return nil, fmt.Errorf("ollama: error encoding request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/api/generate", buf) + if err != nil { + return nil, fmt.Errorf("ollama: error creating request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("ollama: error making request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, web.NewError(http.StatusOK, resp) + } + + var result GenerateResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("ollama: error decoding response: %w", err) + } + + return &result, nil +} |
