author    Xe Iaso <me@xeiaso.net>  2023-10-25 16:48:52 -0400
committer Xe Iaso <me@xeiaso.net>  2023-10-25 16:48:52 -0400
commit    29334130dbd6da82c7fcbdfafce38509714b3ac3 (patch)
tree      ac53960bdc06dc5d057f3f68e072c6eed2c04062
parent    6a79f04abca6ea56bdba6bdf6229d05e0fe9532a (diff)
download  x-29334130dbd6da82c7fcbdfafce38509714b3ac3.tar.xz
          x-29334130dbd6da82c7fcbdfafce38509714b3ac3.zip
llm: add new package for dealing with large language model formats
Signed-off-by: Xe Iaso <me@xeiaso.net>
-rw-r--r--  llm/chatml.go          36
-rw-r--r--  llm/chatml_test.go     28
-rw-r--r--  llm/doc.go              4
-rw-r--r--  llm/functions.go       58
-rw-r--r--  llm/llama_instruct.go  11
-rw-r--r--  llm/llamacpp.go       152
6 files changed, 289 insertions, 0 deletions
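
As a rough sketch of how this package is meant to be used (the import path within.website/x/llm is assumed from this repository's module path, and the prompt text is illustrative only):

    package main

    import (
        "fmt"

        "within.website/x/llm"
    )

    func main() {
        session := llm.Session{
            Messages: []llm.ChatMLer{
                llm.Message{Role: "system", Content: "You are a helpful assistant."},
                llm.Message{Role: "user", Content: "hello"},
                // An empty Content leaves the assistant turn open so the model
                // continues generating from here.
                llm.Message{Role: "assistant"},
            },
        }

        // Render the whole conversation as a ChatML prompt.
        fmt.Print(session.ChatML())
    }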
diff --git a/llm/chatml.go b/llm/chatml.go
new file mode 100644
index 0000000..fdc54cc
--- /dev/null
+++ b/llm/chatml.go
@@ -0,0 +1,36 @@
+package llm
+
+import (
+ "fmt"
+ "strings"
+)
+
+type Session struct {
+ Messages []ChatMLer `json:"messages"`
+}
+
+type ChatMLer interface {
+ ChatML() string
+}
+
+func (s Session) ChatML() string {
+ var sb strings.Builder
+
+ for _, message := range s.Messages {
+ fmt.Fprintf(&sb, "%s\n", message.ChatML())
+ }
+
+ return sb.String()
+}
+
+type Message struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+}
+
+func (m Message) ChatML() string {
+ if m.Content == "" {
+ return fmt.Sprintf("<|im_start|>%s\n", m.Role)
+ }
+ return fmt.Sprintf("<|im_start|>%s\n%s<|im_end|>", m.Role, m.Content)
+}
diff --git a/llm/chatml_test.go b/llm/chatml_test.go
new file mode 100644
index 0000000..2a1ff80
--- /dev/null
+++ b/llm/chatml_test.go
@@ -0,0 +1,28 @@
+package llm
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestChatML(t *testing.T) {
+ session := Session{
+ Messages: []ChatMLer{
+ Message{
+ Role: "user",
+ Content: "hello",
+ },
+ Message{
+ Role: "assistant",
+ },
+ },
+ }
+
+ expected := `<|im_start|>user
+hello<|im_end|>
+<|im_start|>assistant`
+
+ if strings.TrimSpace(session.ChatML()) != strings.TrimSpace(expected) {
+ t.Errorf("Expected\n\n%s\n\ngot\n\n%s", expected, session.ChatML())
+ }
+}
diff --git a/llm/doc.go b/llm/doc.go
new file mode 100644
index 0000000..222926b
--- /dev/null
+++ b/llm/doc.go
@@ -0,0 +1,4 @@
+/*
+Package llm is a collection of tools to automatically format prompts using variants of ChatML and other prompt formatting metasyntaxes.
+*/
+package llm
\ No newline at end of file
diff --git a/llm/functions.go b/llm/functions.go
new file mode 100644
index 0000000..dbe919c
--- /dev/null
+++ b/llm/functions.go
@@ -0,0 +1,58 @@
+package llm
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+)
+
+type FunctionMessage struct {
+ Role string `json:"role"`
+ SystemPrompt string `json:"content"`
+ UserQuestion string `json:"user_question"`
+ Functions []Function `json:"functions"`
+}
+
+type Function struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Arguments []Argument `json:"arguments"`
+}
+
+type FunctionResponse struct {
+ Function string `json:"function"`
+ Arguments map[string]string `json:"arguments"`
+}
+
+type Argument struct {
+ Name string `json:"name"`
+ Type string `json:"type"`
+ Description string `json:"description"`
+}
+
+func (m FunctionMessage) ChatML() string {
+ var sb strings.Builder
+
+ fmt.Fprintf(&sb, "<s>[INST] <<SYS>>\n%s The following functions are available for you to fetch further data to answer user questions, if relevant:\n\n", m.SystemPrompt)
+ enc := json.NewEncoder(&sb)
+
+ for _, function := range m.Functions {
+ enc.Encode(function)
+ }
+
+ fmt.Fprintf(&sb, `
+ To call a function, respond - immediately and only - with a JSON object of the following format:
+ {
+ "function": "function_name",
+ "arguments": {
+ "argument1": "argument_value",
+ "argument2": "argument_value"
+ }
+ }
+ <</SYS>>
+
+ `)
+ fmt.Fprintf(&sb, "%s [/INST]", m.UserQuestion)
+
+ return sb.String()
+}
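
FunctionResponse is declared above but not exercised anywhere in this change; a sketch of the intended round trip might look like the following (get_weather, the model reply, and the import path within.website/x/llm are all made up for illustration):

    package main

    import (
        "encoding/json"
        "fmt"
        "log"

        "within.website/x/llm"
    )

    func main() {
        msg := llm.FunctionMessage{
            SystemPrompt: "You are a helpful assistant.",
            UserQuestion: "What is the weather in Ottawa right now?",
            Functions: []llm.Function{
                {
                    Name:        "get_weather",
                    Description: "Look up the current weather for a city",
                    Arguments: []llm.Argument{
                        {Name: "city", Type: "string", Description: "Name of the city to look up"},
                    },
                },
            },
        }

        // msg.ChatML() yields a Llama 2 style [INST] <<SYS>> prompt with the
        // function schemas inlined; it would be sent to the model as-is.
        _ = msg.ChatML()

        // reply stands in for whatever JSON the model actually returns.
        reply := `{"function": "get_weather", "arguments": {"city": "Ottawa"}}`

        var call llm.FunctionResponse
        if err := json.Unmarshal([]byte(reply), &call); err != nil {
            log.Fatal(err)
        }
        fmt.Println(call.Function, call.Arguments["city"])
    }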
diff --git a/llm/llama_instruct.go b/llm/llama_instruct.go
new file mode 100644
index 0000000..64224f0
--- /dev/null
+++ b/llm/llama_instruct.go
@@ -0,0 +1,11 @@
+package llm
+
+import "fmt"
+
+type LlamaInstruct struct {
+ Content string `json:"content"`
+}
+
+func (m LlamaInstruct) ChatML() string {
+ return fmt.Sprintf("[INST]\n%s\n[/INST]", m.Content)
+}
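
For reference, LlamaInstruct{Content: "Summarize this document."}.ChatML() renders as "[INST]\nSummarize this document.\n[/INST]", and because it satisfies ChatMLer it can be dropped into a Session alongside the other message types.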
diff --git a/llm/llamacpp.go b/llm/llamacpp.go
new file mode 100644
index 0000000..956a235
--- /dev/null
+++ b/llm/llamacpp.go
@@ -0,0 +1,152 @@
+package llm
+
+import (
+ "bytes"
+ "encoding/json"
+ "io"
+ "log/slog"
+ "net/http"
+
+ "within.website/x/web"
+)
+
+type Client struct {
+ cli *http.Client
+ serverURL string
+}
+
+func NewClient(cli *http.Client, serverURL string) *Client {
+ return &Client{
+ cli: cli,
+ serverURL: serverURL,
+ }
+}
+
+func (c *Client) ExecSession(session Session) (*LLAMAResponse, error) {
+ opts := DefaultLLAMAOpts()
+ opts.Prompt = session.ChatML()
+ return c.Predict(opts)
+}
+
+func (c *Client) Predict(opts *LLAMAOpts) (*LLAMAResponse, error) {
+ jsonData, err := json.Marshal(opts)
+ if err != nil {
+ return nil, err
+ }
+ // Make a POST request to the server
+ resp, err := c.cli.Post(c.serverURL, "application/json", bytes.NewBuffer(jsonData))
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+ // Check the response status code
+ if resp.StatusCode != http.StatusOK {
+ return nil, web.NewError(http.StatusOK, resp)
+ }
+
+ data, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ var result LLAMAResponse
+
+ if err := json.Unmarshal(data, &result); err != nil {
+ return nil, err
+ }
+
+ return &result, nil
+}
+
+type LLAMAOpts struct {
+ Temperature float64 `json:"temperature"`
+ TopK int `json:"top_k"`
+ TopP float64 `json:"top_p"`
+ Stream bool `json:"stream"`
+ Prompt string `json:"prompt"`
+ RepeatPenalty float64 `json:"repeat_penalty"`
+ RepeatLastN int `json:"repeat_last_n"`
+ Mirostat int `json:"mirostat"`
+ NPredict int `json:"n_predict"`
+}
+
+func DefaultLLAMAOpts() *LLAMAOpts {
+ return &LLAMAOpts{
+ Temperature: 0.8,
+ TopK: 40,
+ TopP: 0.9,
+ Stream: false,
+ RepeatPenalty: 1.15,
+ RepeatLastN: 512,
+ Mirostat: 2,
+ NPredict: 2048,
+ }
+}
+
+type LLAMAResponse struct {
+ Content string `json:"content"`
+ GenerationSettings GenerationSettings `json:"generation_settings"`
+ Model string `json:"model"`
+ Prompt string `json:"prompt"`
+ Stop bool `json:"stop"`
+ StoppedEos bool `json:"stopped_eos"`
+ StoppedLimit bool `json:"stopped_limit"`
+ StoppedWord bool `json:"stopped_word"`
+ StoppingWord string `json:"stopping_word"`
+ Timings Timings `json:"timings"`
+ TokensCached int `json:"tokens_cached"`
+ TokensEvaluated int `json:"tokens_evaluated"`
+ TokensPredicted int `json:"tokens_predicted"`
+ Truncated bool `json:"truncated"`
+}
+
+type GenerationSettings struct {
+ FrequencyPenalty float64 `json:"frequency_penalty"`
+ Grammar string `json:"grammar"`
+ IgnoreEos bool `json:"ignore_eos"`
+ LogitBias []any `json:"logit_bias"`
+ Mirostat int `json:"mirostat"`
+ MirostatEta float64 `json:"mirostat_eta"`
+ MirostatTau float64 `json:"mirostat_tau"`
+ Model string `json:"model"`
+ NCtx int `json:"n_ctx"`
+ NKeep int `json:"n_keep"`
+ NPredict int `json:"n_predict"`
+ NProbs int `json:"n_probs"`
+ PenalizeNl bool `json:"penalize_nl"`
+ PresencePenalty float64 `json:"presence_penalty"`
+ RepeatLastN int `json:"repeat_last_n"`
+ RepeatPenalty float64 `json:"repeat_penalty"`
+ Seed int64 `json:"seed"`
+ Stop []any `json:"stop"`
+ Stream bool `json:"stream"`
+ Temp float64 `json:"temp"`
+ TfsZ float64 `json:"tfs_z"`
+ TopK int `json:"top_k"`
+ TopP float64 `json:"top_p"`
+ TypicalP float64 `json:"typical_p"`
+}
+
+type Timings struct {
+ PredictedMs float64 `json:"predicted_ms"`
+ PredictedN int `json:"predicted_n"`
+ PredictedPerSecond float64 `json:"predicted_per_second"`
+ PredictedPerTokenMs float64 `json:"predicted_per_token_ms"`
+ PromptMs float64 `json:"prompt_ms"`
+ PromptN int `json:"prompt_n"`
+ PromptPerSecond float64 `json:"prompt_per_second"`
+ PromptPerTokenMs float64 `json:"prompt_per_token_ms"`
+}
+
+func (t Timings) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.Float64("predicted_ms", t.PredictedMs),
+ slog.Int("predicted_n", t.PredictedN),
+ slog.Float64("predicted_per_second", t.PredictedPerSecond),
+ slog.Float64("predicted_per_token_ms", t.PredictedPerTokenMs),
+ slog.Float64("prompt_ms", t.PromptMs),
+ slog.Int("prompt_n", t.PromptN),
+ slog.Float64("prompt_per_second", t.PromptPerSecond),
+ slog.Float64("prompt_per_token_ms", t.PromptPerTokenMs),
+ )
+}
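
Putting the pieces together, a minimal sketch of driving a llama.cpp server through this client might look like this (the URL is a placeholder and assumes it points at the server's completion endpoint; the import path within.website/x/llm is likewise assumed):

    package main

    import (
        "fmt"
        "log"
        "net/http"

        "within.website/x/llm"
    )

    func main() {
        // Placeholder URL; Client.Predict POSTs directly to whatever serverURL
        // is given, so it should point at the server's completion endpoint.
        cli := llm.NewClient(http.DefaultClient, "http://localhost:8080/completion")

        session := llm.Session{
            Messages: []llm.ChatMLer{
                llm.Message{Role: "user", Content: "Write a haiku about rivers."},
                llm.Message{Role: "assistant"},
            },
        }

        resp, err := cli.ExecSession(session)
        if err != nil {
            log.Fatal(err)
        }

        fmt.Println(resp.Content)
        fmt.Printf("predicted %d tokens in %.0f ms\n", resp.Timings.PredictedN, resp.Timings.PredictedMs)
    }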