aboutsummaryrefslogtreecommitdiff
path: root/llm/llamacpp.go
blob: 956a2354fe8f43825d72f3033d5a91ce24403684 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package llm

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"

	"within.website/x/web"
)

type Client struct {
	cli       *http.Client
	serverURL string
}

func NewClient(cli *http.Client, serverURL string) *Client {
	return &Client{
		cli:       cli,
		serverURL: serverURL,
	}
}

// ExecSession renders the session to ChatML, plugs it into the default
// generation options, and runs a single prediction against the server.
func (c *Client) ExecSession(session Session) (*LLAMAResponse, error) {
	req := DefaultLLAMAOpts()
	req.Prompt = session.ChatML()
	return c.Predict(req)
}

// Predict sends opts to the llama.cpp server as JSON and decodes the
// completion response.
//
// Any non-200 status is surfaced as an error built by web.NewError
// (first argument is the status we wanted). Failures from marshaling,
// transport, reading, and decoding are wrapped with %w so callers can
// tell the stages apart and still use errors.Is/errors.As.
func (c *Client) Predict(opts *LLAMAOpts) (*LLAMAResponse, error) {
	jsonData, err := json.Marshal(opts)
	if err != nil {
		return nil, fmt.Errorf("marshaling prediction options: %w", err)
	}

	// POST the request body to the configured server URL.
	resp, err := c.cli.Post(c.serverURL, "application/json", bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("posting to %s: %w", c.serverURL, err)
	}
	defer resp.Body.Close()

	// Anything other than 200 OK is an error; web.NewError records the
	// expected status alongside the actual response.
	if resp.StatusCode != http.StatusOK {
		return nil, web.NewError(http.StatusOK, resp)
	}

	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("reading response body: %w", err)
	}

	var result LLAMAResponse
	if err := json.Unmarshal(data, &result); err != nil {
		return nil, fmt.Errorf("decoding response: %w", err)
	}

	return &result, nil
}

// LLAMAOpts holds the request parameters sent to a llama.cpp completion
// endpoint. Field names mirror the server's JSON API.
type LLAMAOpts struct {
	Temperature   float64 `json:"temperature"`    // sampling temperature
	TopK          int     `json:"top_k"`          // top-k sampling cutoff
	TopP          float64 `json:"top_p"`          // nucleus (top-p) sampling cutoff
	Stream        bool    `json:"stream"`         // stream tokens instead of one final response
	Prompt        string  `json:"prompt"`         // text for the model to complete
	RepeatPenalty float64 `json:"repeat_penalty"` // penalty applied to repeated tokens
	RepeatLastN   int     `json:"repeat_last_n"`  // window of recent tokens the penalty considers
	Mirostat      int     `json:"mirostat"`       // mirostat sampling mode selector
	NPredict      int     `json:"n_predict"`      // maximum number of tokens to generate
}

// DefaultLLAMAOpts returns the package's baseline generation settings.
// Prompt is left empty for the caller to fill in before predicting.
func DefaultLLAMAOpts() *LLAMAOpts {
	opts := new(LLAMAOpts)
	opts.Temperature = 0.8
	opts.TopK = 40
	opts.TopP = 0.9
	opts.Stream = false
	opts.RepeatPenalty = 1.15
	opts.RepeatLastN = 512
	opts.Mirostat = 2
	opts.NPredict = 2048
	return opts
}

// LLAMAResponse is the JSON body a llama.cpp completion endpoint returns
// for a prediction request.
type LLAMAResponse struct {
	Content            string             `json:"content"`             // generated completion text
	GenerationSettings GenerationSettings `json:"generation_settings"` // sampling configuration echoed back by the server
	Model              string             `json:"model"`               // model identifier reported by the server
	Prompt             string             `json:"prompt"`              // prompt the server evaluated
	Stop               bool               `json:"stop"`                // whether generation has stopped
	StoppedEos         bool               `json:"stopped_eos"`         // presumably: stopped on end-of-sequence token — confirm against llama.cpp docs
	StoppedLimit       bool               `json:"stopped_limit"`       // presumably: stopped by the n_predict limit — confirm against llama.cpp docs
	StoppedWord        bool               `json:"stopped_word"`        // presumably: stopped on a stop word — confirm against llama.cpp docs
	StoppingWord       string             `json:"stopping_word"`       // the stop word that triggered the stop, if any
	Timings            Timings            `json:"timings"`             // server-reported performance measurements
	TokensCached       int                `json:"tokens_cached"`
	TokensEvaluated    int                `json:"tokens_evaluated"`
	TokensPredicted    int                `json:"tokens_predicted"`
	Truncated          bool               `json:"truncated"`
}

// GenerationSettings echoes the sampling configuration the server used
// for a request; it arrives inside LLAMAResponse. Field semantics are
// defined by llama.cpp's server API, not by this package.
type GenerationSettings struct {
	FrequencyPenalty float64 `json:"frequency_penalty"`
	Grammar          string  `json:"grammar"`
	IgnoreEos        bool    `json:"ignore_eos"`
	LogitBias        []any   `json:"logit_bias"`
	Mirostat         int     `json:"mirostat"`
	MirostatEta      float64 `json:"mirostat_eta"`
	MirostatTau      float64 `json:"mirostat_tau"`
	Model            string  `json:"model"`
	NCtx             int     `json:"n_ctx"` // presumably the context window size — confirm against llama.cpp docs
	NKeep            int     `json:"n_keep"`
	NPredict         int     `json:"n_predict"` // matches LLAMAOpts.NPredict
	NProbs           int     `json:"n_probs"`
	PenalizeNl       bool    `json:"penalize_nl"`
	PresencePenalty  float64 `json:"presence_penalty"`
	RepeatLastN      int     `json:"repeat_last_n"`  // matches LLAMAOpts.RepeatLastN
	RepeatPenalty    float64 `json:"repeat_penalty"` // matches LLAMAOpts.RepeatPenalty
	Seed             int64   `json:"seed"`
	Stop             []any   `json:"stop"`
	Stream           bool    `json:"stream"`
	Temp             float64 `json:"temp"` // matches LLAMAOpts.Temperature
	TfsZ             float64 `json:"tfs_z"`
	TopK             int     `json:"top_k"`
	TopP             float64 `json:"top_p"`
	TypicalP         float64 `json:"typical_p"`
}

// Timings carries the server-reported performance numbers for one
// completion, split into the prompt-evaluation and token-prediction
// phases (milliseconds, token counts, and derived rates).
type Timings struct {
	PredictedMs         float64 `json:"predicted_ms"`
	PredictedN          int     `json:"predicted_n"`
	PredictedPerSecond  float64 `json:"predicted_per_second"`
	PredictedPerTokenMs float64 `json:"predicted_per_token_ms"`
	PromptMs            float64 `json:"prompt_ms"`
	PromptN             int     `json:"prompt_n"`
	PromptPerSecond     float64 `json:"prompt_per_second"`
	PromptPerTokenMs    float64 `json:"prompt_per_token_ms"`
}

// LogValue implements slog.LogValuer so a Timings value logs as a group
// of individually keyed attributes instead of a flat struct dump.
func (t Timings) LogValue() slog.Value {
	attrs := []slog.Attr{
		slog.Float64("predicted_ms", t.PredictedMs),
		slog.Int("predicted_n", t.PredictedN),
		slog.Float64("predicted_per_second", t.PredictedPerSecond),
		slog.Float64("predicted_per_token_ms", t.PredictedPerTokenMs),
		slog.Float64("prompt_ms", t.PromptMs),
		slog.Int("prompt_n", t.PromptN),
		slog.Float64("prompt_per_second", t.PromptPerSecond),
		slog.Float64("prompt_per_token_ms", t.PromptPerTokenMs),
	}
	return slog.GroupValue(attrs...)
}