| author | Xe Iaso <me@xeiaso.net> | 2023-10-25 16:48:52 -0400 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2023-10-25 16:48:52 -0400 |
| commit | 29334130dbd6da82c7fcbdfafce38509714b3ac3 | (patch) |
| tree | ac53960bdc06dc5d057f3f68e072c6eed2c04062 | |
| parent | 6a79f04abca6ea56bdba6bdf6229d05e0fe9532a | (diff) |
| download | x-29334130dbd6da82c7fcbdfafce38509714b3ac3.tar.xz, x-29334130dbd6da82c7fcbdfafce38509714b3ac3.zip | |
llm: add new package for dealing with large language model formats
Signed-off-by: Xe Iaso <me@xeiaso.net>
| -rw-r--r-- | llm/chatml.go | 36 |
| -rw-r--r-- | llm/chatml_test.go | 28 |
| -rw-r--r-- | llm/doc.go | 4 |
| -rw-r--r-- | llm/functions.go | 58 |
| -rw-r--r-- | llm/llama_instruct.go | 11 |
| -rw-r--r-- | llm/llamacpp.go | 152 |
6 files changed, 289 insertions, 0 deletions
```diff
diff --git a/llm/chatml.go b/llm/chatml.go
new file mode 100644
index 0000000..fdc54cc
--- /dev/null
+++ b/llm/chatml.go
@@ -0,0 +1,36 @@
+package llm
+
+import (
+	"fmt"
+	"strings"
+)
+
+type Session struct {
+	Messages []ChatMLer `json:"messages"`
+}
+
+type ChatMLer interface {
+	ChatML() string
+}
+
+func (s Session) ChatML() string {
+	var sb strings.Builder
+
+	for _, message := range s.Messages {
+		fmt.Fprintf(&sb, "%s\n", message.ChatML())
+	}
+
+	return sb.String()
+}
+
+type Message struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+func (m Message) ChatML() string {
+	if m.Content == "" {
+		return fmt.Sprintf("<|im_start|>%s\n", m.Role)
+	}
+	return fmt.Sprintf("<|im_start|>%s\n%s<|im_end|>", m.Role, m.Content)
+}
```

```diff
diff --git a/llm/chatml_test.go b/llm/chatml_test.go
new file mode 100644
index 0000000..2a1ff80
--- /dev/null
+++ b/llm/chatml_test.go
@@ -0,0 +1,28 @@
+package llm
+
+import (
+	"strings"
+	"testing"
+)
+
+func TestChatML(t *testing.T) {
+	session := Session{
+		Messages: []ChatMLer{
+			Message{
+				Role:    "user",
+				Content: "hello",
+			},
+			Message{
+				Role: "assistant",
+			},
+		},
+	}
+
+	expected := `<|im_start|>user
+hello<|im_end|>
+<|im_start|>assistant`
+
+	if strings.TrimSpace(session.ChatML()) != strings.TrimSpace(expected) {
+		t.Errorf("Expected\n\n%s\n\ngot\n\n%s", expected, session.ChatML())
+	}
+}
```

```diff
diff --git a/llm/doc.go b/llm/doc.go
new file mode 100644
index 0000000..222926b
--- /dev/null
+++ b/llm/doc.go
@@ -0,0 +1,4 @@
+/*
+Package llm is a collection of tools to automatically format prompts using variants of ChatML and other prompt formatting metasyntaxes.
+*/
+package llm
\ No newline at end of file
```
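For orientation, here is a minimal sketch of how the ChatML pieces above fit together. It is not part of the commit, and the import path `within.website/x/llm` is an assumption based on the repository layout:

```go
package main

import (
	"fmt"

	"within.website/x/llm" // assumed import path for this package
)

func main() {
	// A Session is an ordered list of anything implementing ChatMLer.
	session := llm.Session{
		Messages: []llm.ChatMLer{
			llm.Message{Role: "system", Content: "You are a helpful assistant."},
			llm.Message{Role: "user", Content: "hello"},
			// An empty Content hits the branch in Message.ChatML that
			// emits only "<|im_start|>assistant\n", leaving the turn
			// open for the model to complete.
			llm.Message{Role: "assistant"},
		},
	}

	fmt.Print(session.ChatML())
}
```

This renders the same `<|im_start|>`/`<|im_end|>`-delimited transcript that the test in `llm/chatml_test.go` asserts against.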
```diff
diff --git a/llm/functions.go b/llm/functions.go
new file mode 100644
index 0000000..dbe919c
--- /dev/null
+++ b/llm/functions.go
@@ -0,0 +1,58 @@
+package llm
+
+import (
+	"encoding/json"
+	"fmt"
+	"strings"
+)
+
+type FunctionMessage struct {
+	Role         string     `json:"role"`
+	SystemPrompt string     `json:"content"`
+	UserQuestion string     `json:"user_question"`
+	Functions    []Function `json:"functions"`
+}
+
+type Function struct {
+	Name        string     `json:"name"`
+	Description string     `json:"description"`
+	Arguments   []Argument `json:"arguments"`
+}
+
+type FunctionResponse struct {
+	Function  string            `json:"function"`
+	Arguments map[string]string `json:"arguments"`
+}
+
+type Argument struct {
+	Name        string `json:"name"`
+	Type        string `json:"type"`
+	Description string `json:"description"`
+}
+
+func (m FunctionMessage) ChatML() string {
+	var sb strings.Builder
+
+	fmt.Fprintf(&sb, "<s>[INST] <<SYS>>\n%s The following functions are available for you to fetch further data to answer user questions, if relevant:\n\n", m.SystemPrompt)
+	enc := json.NewEncoder(&sb)
+
+	for _, function := range m.Functions {
+		enc.Encode(function)
+	}
+
+	fmt.Fprintf(&sb, `
+To call a function, respond - immediately and only - with a JSON object of the following format:
+{
+	"function": "function_name",
+	"arguments": {
+		"argument1": "argument_value",
+		"argument2": "argument_value"
+	}
+}
+<</SYS>>
+
+`)
+	fmt.Fprintf(&sb, "%s [/INST]", m.UserQuestion)
+
+	return sb.String()
+}
```

```diff
diff --git a/llm/llama_instruct.go b/llm/llama_instruct.go
new file mode 100644
index 0000000..64224f0
--- /dev/null
+++ b/llm/llama_instruct.go
@@ -0,0 +1,11 @@
+package llm
+
+import "fmt"
+
+type LlamaInstruct struct {
+	Content string `json:"content"`
+}
+
+func (m LlamaInstruct) ChatML() string {
+	return fmt.Sprintf("[INST]\n%s\n[/INST]", m.Content)
+}
```
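A sketch of how `FunctionMessage` might be used, again assuming the `within.website/x/llm` import path. The `get_weather` function is hypothetical, purely for illustration:

```go
package main

import (
	"encoding/json"
	"fmt"
	"log"

	"within.website/x/llm" // assumed import path for this package
)

func main() {
	msg := llm.FunctionMessage{
		Role:         "user",
		SystemPrompt: "You are a helpful assistant.",
		UserQuestion: "What is the weather in Ottawa?",
		Functions: []llm.Function{
			{
				// Hypothetical function, not part of this commit.
				Name:        "get_weather",
				Description: "Fetch the current weather for a city",
				Arguments: []llm.Argument{
					{Name: "city", Type: "string", Description: "The city to look up"},
				},
			},
		},
	}

	// Renders a Llama-2-style prompt: the <<SYS>> block lists each
	// function as one JSON object per line, then the user question
	// closes the [INST] span.
	fmt.Print(msg.ChatML())

	// A model that decides to call a function is expected to answer
	// with JSON that unmarshals into FunctionResponse.
	reply := `{"function": "get_weather", "arguments": {"city": "Ottawa"}}`
	var call llm.FunctionResponse
	if err := json.Unmarshal([]byte(reply), &call); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("call %s with %v\n", call.Function, call.Arguments)
}
```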
```diff
diff --git a/llm/llamacpp.go b/llm/llamacpp.go
new file mode 100644
index 0000000..956a235
--- /dev/null
+++ b/llm/llamacpp.go
@@ -0,0 +1,152 @@
+package llm
+
+import (
+	"bytes"
+	"encoding/json"
+	"io"
+	"log/slog"
+	"net/http"
+
+	"within.website/x/web"
+)
+
+type Client struct {
+	cli       *http.Client
+	serverURL string
+}
+
+func NewClient(cli *http.Client, serverURL string) *Client {
+	return &Client{
+		cli:       cli,
+		serverURL: serverURL,
+	}
+}
+
+func (c *Client) ExecSession(session Session) (*LLAMAResponse, error) {
+	opts := DefaultLLAMAOpts()
+	opts.Prompt = session.ChatML()
+	return c.Predict(opts)
+}
+
+func (c *Client) Predict(opts *LLAMAOpts) (*LLAMAResponse, error) {
+	jsonData, err := json.Marshal(opts)
+	if err != nil {
+		return nil, err
+	}
+
+	// Make a POST request to the server
+	resp, err := c.cli.Post(c.serverURL, "application/json", bytes.NewBuffer(jsonData))
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	// Check the response status code
+	if resp.StatusCode != http.StatusOK {
+		return nil, web.NewError(http.StatusOK, resp)
+	}
+
+	data, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+
+	var result LLAMAResponse
+
+	if err := json.Unmarshal(data, &result); err != nil {
+		return nil, err
+	}
+
+	return &result, nil
+}
+
+type LLAMAOpts struct {
+	Temperature   float64 `json:"temperature"`
+	TopK          int     `json:"top_k"`
+	TopP          float64 `json:"top_p"`
+	Stream        bool    `json:"stream"`
+	Prompt        string  `json:"prompt"`
+	RepeatPenalty float64 `json:"repeat_penalty"`
+	RepeatLastN   int     `json:"repeat_last_n"`
+	Mirostat      int     `json:"mirostat"`
+	NPredict      int     `json:"n_predict"`
+}
+
+func DefaultLLAMAOpts() *LLAMAOpts {
+	return &LLAMAOpts{
+		Temperature:   0.8,
+		TopK:          40,
+		TopP:          0.9,
+		Stream:        false,
+		RepeatPenalty: 1.15,
+		RepeatLastN:   512,
+		Mirostat:      2,
+		NPredict:      2048,
+	}
+}
+
+type LLAMAResponse struct {
+	Content            string             `json:"content"`
+	GenerationSettings GenerationSettings `json:"generation_settings"`
+	Model              string             `json:"model"`
+	Prompt             string             `json:"prompt"`
+	Stop               bool               `json:"stop"`
+	StoppedEos         bool               `json:"stopped_eos"`
+	StoppedLimit       bool               `json:"stopped_limit"`
+	StoppedWord        bool               `json:"stopped_word"`
+	StoppingWord       string             `json:"stopping_word"`
+	Timings            Timings            `json:"timings"`
+	TokensCached       int                `json:"tokens_cached"`
+	TokensEvaluated    int                `json:"tokens_evaluated"`
+	TokensPredicted    int                `json:"tokens_predicted"`
+	Truncated          bool               `json:"truncated"`
+}
+
+type GenerationSettings struct {
+	FrequencyPenalty float64 `json:"frequency_penalty"`
+	Grammar          string  `json:"grammar"`
+	IgnoreEos        bool    `json:"ignore_eos"`
+	LogitBias        []any   `json:"logit_bias"`
+	Mirostat         int     `json:"mirostat"`
+	MirostatEta      float64 `json:"mirostat_eta"`
+	MirostatTau      float64 `json:"mirostat_tau"`
+	Model            string  `json:"model"`
+	NCtx             int     `json:"n_ctx"`
+	NKeep            int     `json:"n_keep"`
+	NPredict         int     `json:"n_predict"`
+	NProbs           int     `json:"n_probs"`
+	PenalizeNl       bool    `json:"penalize_nl"`
+	PresencePenalty  float64 `json:"presence_penalty"`
+	RepeatLastN      int     `json:"repeat_last_n"`
+	RepeatPenalty    float64 `json:"repeat_penalty"`
+	Seed             int64   `json:"seed"`
+	Stop             []any   `json:"stop"`
+	Stream           bool    `json:"stream"`
+	Temp             float64 `json:"temp"`
+	TfsZ             float64 `json:"tfs_z"`
+	TopK             int     `json:"top_k"`
+	TopP             float64 `json:"top_p"`
+	TypicalP         float64 `json:"typical_p"`
+}
+
+type Timings struct {
+	PredictedMs         float64 `json:"predicted_ms"`
+	PredictedN          int     `json:"predicted_n"`
+	PredictedPerSecond  float64 `json:"predicted_per_second"`
+	PredictedPerTokenMs float64 `json:"predicted_per_token_ms"`
+	PromptMs            float64 `json:"prompt_ms"`
+	PromptN             int     `json:"prompt_n"`
+	PromptPerSecond     float64 `json:"prompt_per_second"`
+	PromptPerTokenMs    float64 `json:"prompt_per_token_ms"`
+}
+
+func (t Timings) LogValue() slog.Value {
+	return slog.GroupValue(
+		slog.Float64("predicted_ms", t.PredictedMs),
+		slog.Int("predicted_n", t.PredictedN),
+		slog.Float64("predicted_per_second", t.PredictedPerSecond),
+		slog.Float64("predicted_per_token_ms", t.PredictedPerTokenMs),
+		slog.Float64("prompt_ms", t.PromptMs),
+		slog.Int("prompt_n", t.PromptN),
+		slog.Float64("prompt_per_second", t.PromptPerSecond),
+		slog.Float64("prompt_per_token_ms", t.PromptPerTokenMs),
+	)
+}
```
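Finally, a sketch of driving the client end to end. `Client.Predict` posts to `serverURL` as-is, so the URL must point at the completion endpoint itself; `http://localhost:8080/completion` is an assumption based on the defaults of llama.cpp's bundled HTTP server, and the import path is assumed as before:

```go
package main

import (
	"fmt"
	"log"
	"net/http"

	"within.website/x/llm" // assumed import path for this package
)

func main() {
	// Assumed server URL; Client posts to it verbatim.
	cli := llm.NewClient(http.DefaultClient, "http://localhost:8080/completion")

	session := llm.Session{
		Messages: []llm.ChatMLer{
			llm.Message{Role: "user", Content: "Say hi in five words."},
			llm.Message{Role: "assistant"}, // open turn for the model
		},
	}

	// ExecSession renders the session to ChatML and sends it with
	// DefaultLLAMAOpts (temperature 0.8, mirostat 2, n_predict 2048, ...).
	resp, err := cli.ExecSession(session)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(resp.Content)
	fmt.Printf("%d tokens in %.0fms\n", resp.Timings.PredictedN, resp.Timings.PredictedMs)
}
```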
