diff options
| author | Xe Iaso <me@xeiaso.net> | 2024-04-05 09:36:01 -0400 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2024-04-05 11:52:49 -0400 |
| commit | 0b484d1f236dadde16d84563103e1ba53879e0ee (patch) | |
| tree | 3f088a7b06ffa7087fdd354ab69b0ae3cd24ad11 /cmd | |
| parent | 22cb91291121474af71ffc5fbb7dfacafb13578c (diff) | |
| download | x-0b484d1f236dadde16d84563103e1ba53879e0ee.tar.xz x-0b484d1f236dadde16d84563103e1ba53879e0ee.zip | |
cmd: add arsene-analysis command for seeing if the shitpost was right
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/arsene-analysis/.gitignore | 2 | ||||
| -rw-r--r-- | cmd/arsene-analysis/main.go | 155 |
2 files changed, 157 insertions, 0 deletions
diff --git a/cmd/arsene-analysis/.gitignore b/cmd/arsene-analysis/.gitignore new file mode 100644 index 0000000..29a7166 --- /dev/null +++ b/cmd/arsene-analysis/.gitignore @@ -0,0 +1,2 @@ +*.csv + diff --git a/cmd/arsene-analysis/main.go b/cmd/arsene-analysis/main.go new file mode 100644 index 0000000..7cc90b9 --- /dev/null +++ b/cmd/arsene-analysis/main.go @@ -0,0 +1,155 @@ +package main + +import ( + "context" + "encoding/csv" + "flag" + "fmt" + "log" + "log/slog" + "os" + "strconv" + "time" + + "within.website/x/internal" + "within.website/x/web/ollama" +) + +var ( + foutName = flag.String("out", "enriched.csv", "output file name") + ollamaHost = flag.String("ollama-host", "http://xe-inference.flycast", "ollama host") + ollamaModel = flag.String("ollama-model", "nous-hermes2-mixtral:8x7b-dpo-q5_K_M", "ollama model") + subsetFile = flag.String("subset", "", "subset CSV file to use") +) + +type sentimentResponse struct { + Sentiment string `json:"sentiment"` +} + +func (sr sentimentResponse) Valid() error { + if sr.Sentiment != "positive" && sr.Sentiment != "negative" && sr.Sentiment != "neutral" { + return fmt.Errorf("invalid sentiment %q", sr.Sentiment) + } + + return nil +} + +func main() { + internal.HandleStartup() + + fin, err := os.Open(*subsetFile) + if err != nil { + log.Fatal(err) + } + defer fin.Close() + + fout, err := os.Create(*foutName) + if err != nil { + log.Fatal(err) + } + defer fout.Close() + + w := csv.NewWriter(fout) + w.Write([]string{"id", "price_change", "sentiment"}) + + cli := ollama.NewClient(*ollamaHost) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Hour) + defer cancel() + + r := csv.NewReader(fin) + for { + row, err := r.Read() + if err != nil { + break + } + + //slog.Debug("got", "row", row) + + sr, err := ParseSubsetRow(row) + if err != nil { + slog.Error("failed to parse row", "err", err) + continue + } + + sen, err := ollama.Hallucinate[sentimentResponse](ctx, cli, ollama.HallucinateOpts{ + Model: *ollamaModel, + Messages: []ollama.Message{ + { + Role: "system", + Content: `Rate the sentiment of the following text. If the sentiment is positive, return this JSON object: +{"sentiment":"positive"} +If the sentiment is negative, return this JSON object: +{"sentiment":"negative"} +If there is neither a positive nor a negative sentiment, return this JSON object: +{"sentiment":"neutral"} +DO NOT send any whitespace or newlines in the JSON object.`, + }, + { + Role: "user", + Content: sr.Body, + }, + }, + }) + if err != nil { + slog.Error("failed to chat", "err", err) + continue + } + + priceChange := "" + + if sr.PrevPrice > sr.AfterPrice { + priceChange = "negative" + } else if sr.PrevPrice < sr.AfterPrice { + priceChange = "positive" + } else { + priceChange = "neutral" + } + + w.Write([]string{ + strconv.Itoa(sr.ID), + priceChange, + sen.Sentiment, + }) + w.Flush() + } + + w.Flush() +} + +type SubsetRow struct { + ID int `json:"id"` + Title string `json:"title"` + Body string `json:"body"` + PrevPrice float64 `json:"prev_price"` + AfterPrice float64 `json:"after_price"` +} + +func ParseSubsetRow(data []string) (*SubsetRow, error) { + if len(data) != 5 { + return nil, fmt.Errorf("expected 5 fields, got %d", len(data)) + } + + id, err := strconv.Atoi(data[0]) + if err != nil { + return nil, fmt.Errorf("id: %w", err) + } + + prevPrice, err := strconv.ParseFloat(data[3], 64) + if err != nil { + return nil, fmt.Errorf("prev_price: %w", err) + } + + afterPrice, err := strconv.ParseFloat(data[4], 64) + if err != nil { + return nil, fmt.Errorf("after_price: %w", err) + } + + return &SubsetRow{ + ID: id, + Title: data[1], + Body: data[2], + PrevPrice: prevPrice, + AfterPrice: afterPrice, + }, nil +} |
