aboutsummaryrefslogtreecommitdiff
path: root/llm/multillm/openai.go
blob: a28f5d0d3e8336a842c93618ad1adfae780a9fbf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
package multillm

import (
	"context"
	"fmt"

	"within.website/x/llm"
	"within.website/x/web/openai/chatgpt"
)

// OpenAI is a chat backend built on a chatgpt.Client.
//
// The client is embedded, so its exported methods are promoted onto OpenAI.
// The Chat method in this file calls Client.Complete, which means Client must
// be non-nil before Chat is used.
type OpenAI struct {
	*chatgpt.Client
}

// convertToChatGPTMessage translates a generic llm.Message into the
// chatgpt.Message shape by copying its Role and Content fields. All other
// fields of the result are left at their zero values.
func convertToChatGPTMessage(m llm.Message) chatgpt.Message {
	var out chatgpt.Message
	out.Role = m.Role
	out.Content = m.Content
	return out
}

// Chat sends req to the embedded chatgpt client and returns the first choice
// of the reply as an llm.Message.
//
// The request's Model, Temperature, RandomSeed and Messages are copied into a
// chatgpt.Request before the call. An error from the underlying client is
// wrapped with %w so callers can still inspect it with errors.Is / errors.As.
// If the reply contains no choices, an error is returned rather than indexing
// into an empty slice.
func (oaic *OpenAI) Chat(ctx context.Context, req *Request) (*Response, error) {
	chatReq := chatgpt.Request{
		Model:       req.Model,
		Temperature: req.Temperature,
		Seed:        req.RandomSeed,
		Messages:    make([]chatgpt.Message, len(req.Messages)),
	}

	for i, m := range req.Messages {
		chatReq.Messages[i] = convertToChatGPTMessage(m)
	}

	chatResp, err := oaic.Client.Complete(ctx, chatReq)
	if err != nil {
		return nil, fmt.Errorf("multillm: error chatting: %w", err)
	}

	// A reply with zero choices would otherwise panic on Choices[0].
	if len(chatResp.Choices) == 0 {
		return nil, fmt.Errorf("multillm: error chatting: response contained no choices")
	}

	first := chatResp.Choices[0].Message

	return &Response{
		Response: llm.Message{
			Role:    first.Role,
			Content: first.Content,
		},
	}, nil
}