aboutsummaryrefslogtreecommitdiff
path: root/web/ollama
diff options
context:
space:
mode:
author Xe Iaso <me@xeiaso.net> 2024-07-27 08:56:01 -0400
committer Xe Iaso <me@xeiaso.net> 2024-07-27 08:56:01 -0400
commit d73a41507f8aab06b378fad1d30f4ccf827e65db (patch)
tree 4753d52984e44a00cfda52da75e5b6e7f7e85169 /web/ollama
parent 7b9862f9410b2ca7d952d6665be02330f63c7174 (diff)
download x-d73a41507f8aab06b378fad1d30f4ccf827e65db.tar.xz
x-d73a41507f8aab06b378fad1d30f4ccf827e65db.zip
cmd/mimi: conversate
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'web/ollama')
-rw-r--r-- web/ollama/llamaguard/llamaguard.go 130
-rw-r--r-- web/ollama/ollama.go 61
2 files changed, 191 insertions, 0 deletions
diff --git a/web/ollama/llamaguard/llamaguard.go b/web/ollama/llamaguard/llamaguard.go
new file mode 100644
index 0000000..77b41a0
--- /dev/null
+++ b/web/ollama/llamaguard/llamaguard.go
@@ -0,0 +1,130 @@
+package llamaguard
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "within.website/x/web/ollama"
+)
+
+// Category is the human-readable name of a hazard category. The constants
+// S1 through S14 map each category code to its name.
+type Category string
+
+const (
+ S1 Category = "Violent Crimes"
+ S2 Category = "Non-Violent Crimes"
+ S3 Category = "Sex Crimes"
+ S4 Category = "Child Exploitation"
+ S5 Category = "Defamation"
+ S6 Category = "Specialized Advice"
+ S7 Category = "Privacy"
+ S8 Category = "Intellectual Property"
+ S9 Category = "Indiscriminate Weapons"
+ S10 Category = "Hate"
+ S11 Category = "Self-Harm"
+ S12 Category = "Sexual Content"
+ S13 Category = "Elections"
+ S14 Category = "Code Interpreter Abuse"
+
+ // Unknown labels a category that is not one of the codes above —
+ // presumably for unrecognized model output; confirm against callers.
+ Unknown Category = "Unknown"
+)
+
+// String returns the category's name as a plain string, which makes Category
+// satisfy fmt.Stringer.
+func (c Category) String() string {
+ return string(c)
+}
+
+// ParseCategory converts a category code such as "S1" into its Category.
+//
+// Leading and trailing whitespace in s is ignored, so entries taken from a
+// comma-separated list like "S1, S2" resolve correctly. A code that is not
+// one of S1 through S14 yields Unknown instead of an empty Category, so the
+// result is always a printable, non-empty value.
+func ParseCategory(s string) Category {
+	switch strings.TrimSpace(s) {
+	case "S1":
+		return S1
+	case "S2":
+		return S2
+	case "S3":
+		return S3
+	case "S4":
+		return S4
+	case "S5":
+		return S5
+	case "S6":
+		return S6
+	case "S7":
+		return S7
+	case "S8":
+		return S8
+	case "S9":
+		return S9
+	case "S10":
+		return S10
+	case "S11":
+		return S11
+	case "S12":
+		return S12
+	case "S13":
+		return S13
+	case "S14":
+		return S14
+	default:
+		// Anything else is not a known code; report it explicitly.
+		return Unknown
+	}
+}
+
+// Response is the parsed verdict that Check returns for a conversation.
+type Response struct {
+ // IsSafe is true only when the model's trimmed output is exactly "safe".
+ IsSafe bool `json:"is_safe"`
+ // ViolationCategories holds the categories parsed from the model's
+ // output. Check leaves it nil when IsSafe is true.
+ ViolationCategories []Category `json:"violation_categories"`
+}
+
+// formatMessages flattens a conversation into the plain-text transcript that
+// Check sends as the generation prompt.
+//
+// Every message contributes its content followed by a blank line. Messages
+// whose role is "user" are prefixed with "User: ", those whose role is
+// "assistant" with "Agent: ", and any other role is written with no prefix.
+func formatMessages(messages []ollama.Message) string {
+	var transcript strings.Builder
+
+	for i := range messages {
+		msg := messages[i]
+
+		// Pick the speaker label; unrecognized roles get none.
+		speaker := ""
+		if msg.Role == "user" {
+			speaker = "User: "
+		} else if msg.Role == "assistant" {
+			speaker = "Agent: "
+		}
+
+		transcript.WriteString(speaker)
+		transcript.WriteString(msg.Content)
+		transcript.WriteString("\n\n")
+	}
+
+	return transcript.String()
+}
+
+// Check asks a Llama Guard style model whether a conversation is safe.
+//
+// role is sent as the system prompt, model names the model to run, and
+// messages is the conversation to classify; it is flattened into a transcript
+// by formatMessages and sent as the prompt through cli.Generate.
+//
+// If the model's trimmed output is exactly "safe", the returned Response has
+// IsSafe set to true. Otherwise the output must contain a first line followed
+// by a line of comma-separated category codes (for example "unsafe\nS1,S2");
+// each code is trimmed and converted with ParseCategory, and blank entries
+// such as those left by a trailing comma are skipped.
+//
+// An error is returned when generation fails, when the client yields a nil
+// response, or when a non-"safe" output has no second line.
+func Check(ctx context.Context, cli *ollama.Client, role, model string, messages []ollama.Message) (*Response, error) {
+	req := &ollama.GenerateRequest{
+		Model:     model,
+		System:    &role,
+		Prompt:    formatMessages(messages),
+		KeepAlive: "60m",
+	}
+
+	resp, err := cli.Generate(ctx, req)
+	if err != nil {
+		return nil, fmt.Errorf("llamaguard: failed to generate response: %w", err)
+	}
+
+	if resp == nil {
+		return nil, fmt.Errorf("llamaguard: response was nil")
+	}
+
+	verdict := strings.TrimSpace(resp.Response)
+	if verdict == "safe" {
+		return &Response{IsSafe: true}, nil
+	}
+
+	// Everything after the first newline is the category list.
+	_, codes, found := strings.Cut(verdict, "\n")
+	if !found {
+		return nil, fmt.Errorf("llamaguard: response was not in the expected format")
+	}
+
+	result := &Response{IsSafe: false}
+	for _, code := range strings.Split(codes, ",") {
+		// Models may emit "S1, S2" or a trailing comma; normalize both.
+		code = strings.TrimSpace(code)
+		if code == "" {
+			continue
+		}
+		result.ViolationCategories = append(result.ViolationCategories, ParseCategory(code))
+	}
+
+	return result, nil
+}
diff --git a/web/ollama/ollama.go b/web/ollama/ollama.go
index 3f9870e..5825c2c 100644
--- a/web/ollama/ollama.go
+++ b/web/ollama/ollama.go
@@ -233,3 +233,64 @@ func (c *Client) Embeddings(ctx context.Context, er *EmbedRequest) (*EmbedRespon
return &result, nil
}
+
+// GenerateRequest is the JSON body that Client.Generate posts to
+// /api/generate. The notes below describe how each field is serialized;
+// consult the Ollama API documentation for the authoritative meaning of
+// each key.
+type GenerateRequest struct {
+ // Model and Prompt are always sent, even when empty.
+ Model string `json:"model"`
+ Prompt string `json:"prompt"`
+ // Images is omitted when empty; encoding/json writes each []byte as a
+ // base64 string.
+ Images [][]byte `json:"images,omitempty"`
+ // Options has no omitempty, so a nil map is sent as JSON null.
+ Options map[string]any `json:"options"`
+
+ // Context is omitted when empty.
+ Context []int `json:"context,omitempty"`
+ // Format, Template and System are omitted from the JSON when nil.
+ Format *string `json:"format,omitempty"`
+ Template *string `json:"template,omitempty"`
+ System *string `json:"system,omitempty"`
+ // Stream and Raw are always sent; their zero value is false.
+ // NOTE(review): Client.Generate decodes a single JSON object, so it
+ // presumably expects Stream to stay false — confirm.
+ Stream bool `json:"stream"`
+ Raw bool `json:"raw"`
+ // KeepAlive is always sent, even when empty. Presumably a duration
+ // string such as "60m" — confirm against the Ollama API.
+ KeepAlive string `json:"keep_alive"`
+}
+
+// GenerateResponse is the JSON object that Client.Generate decodes from the
+// body of an /api/generate reply.
+type GenerateResponse struct {
+ Model string `json:"model"`
+ CreatedAt time.Time `json:"created_at"`
+ // Response is the generated text.
+ Response string `json:"response"`
+ Done bool `json:"done"`
+ Context []int `json:"context"`
+ // The *Duration fields are raw integers as reported by the server —
+ // presumably nanoseconds; confirm against the Ollama API.
+ TotalDuration int64 `json:"total_duration"`
+ LoadDuration int64 `json:"load_duration"`
+ PromptEvalCount int `json:"prompt_eval_count"`
+ PromptEvalDuration int64 `json:"prompt_eval_duration"`
+ EvalCount int `json:"eval_count"`
+ EvalDuration int64 `json:"eval_duration"`
+}
+
+// Generate JSON-encodes gr, posts it to the /api/generate endpoint under the
+// client's base URL, and decodes the reply into a GenerateResponse.
+//
+// The request is bound to ctx and is sent through http.DefaultClient. Any
+// status other than 200 OK is reported via web.NewError; encoding, request
+// construction, transport, and decoding failures are returned wrapped with an
+// "ollama:" prefix. Only one JSON object is decoded from the reply body.
+func (c *Client) Generate(ctx context.Context, gr *GenerateRequest) (*GenerateResponse, error) {
+	var payload bytes.Buffer
+	if err := json.NewEncoder(&payload).Encode(gr); err != nil {
+		return nil, fmt.Errorf("ollama: error encoding request: %w", err)
+	}
+
+	endpoint := c.baseURL + "/api/generate"
+	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &payload)
+	if err != nil {
+		return nil, fmt.Errorf("ollama: error creating request: %w", err)
+	}
+
+	httpReq.Header.Set("Accept", "application/json")
+	httpReq.Header.Set("Content-Type", "application/json")
+
+	httpResp, err := http.DefaultClient.Do(httpReq)
+	if err != nil {
+		return nil, fmt.Errorf("ollama: error making request: %w", err)
+	}
+	defer httpResp.Body.Close()
+
+	if httpResp.StatusCode != http.StatusOK {
+		return nil, web.NewError(http.StatusOK, httpResp)
+	}
+
+	out := new(GenerateResponse)
+	if err := json.NewDecoder(httpResp.Body).Decode(out); err != nil {
+		return nil, fmt.Errorf("ollama: error decoding response: %w", err)
+	}
+
+	return out, nil
+}