diff options
| author | Xe Iaso <me@xeiaso.net> | 2023-11-16 20:28:57 -0500 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2024-01-30 18:26:31 -0500 |
| commit | bfc08e1e932552d1014fd1045cf074b547244854 (patch) | |
| tree | bd7e3fa3ceddce5b0136c36d6283811326e8fb5d /web | |
| parent | 846f8c3ab7884a0d4585e8ee8580e95695d6b5f4 (diff) | |
| download | x-bfc08e1e932552d1014fd1045cf074b547244854.tar.xz x-bfc08e1e932552d1014fd1045cf074b547244854.zip | |
web/openai: add dall-e API bindings
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'web')
| -rw-r--r-- | web/openai/dalle/dalle.go | 113 |
1 files changed, 113 insertions, 0 deletions
diff --git a/web/openai/dalle/dalle.go b/web/openai/dalle/dalle.go new file mode 100644 index 0000000..dc57b7a --- /dev/null +++ b/web/openai/dalle/dalle.go @@ -0,0 +1,113 @@ +package dalle + +import ( + "bytes" + "context" + "encoding/json" + "log/slog" + "net/http" + + "within.website/x/web" +) + +type Client struct { + apiKey string +} + +func New(apiKey string) Client { + return Client{apiKey: apiKey} +} + +type Model string + +const ( + DALLE2 = Model("dall-e-2") + DALLE3 = Model("dall-e-3") +) + +type Size string + +const ( + Size256 = Size("256x256") + Size512 = Size("512x512") + Size1024 = Size("1024x1024") + + // These are only supported with the dall-e-3 model. + SizeHDWide = Size("1792x1024") + SizeHDTall = Size("1024x1792") +) + +type Style string + +const ( + StyleVivid = Style("vivid") + StyleNatural = Style("natural") +) + +type Options struct { + Model Model `json:"model"` + Prompt string `json:"prompt"` + N *int `json:"n,omitempty"` + Quality *string `json:"quality,omitempty"` + Size *Size `json:"size"` + Style *Style `json:"style,omitempty"` + User *string `json:"user,omitempty"` +} + +type Image struct { + URL string `json:"url"` +} + +func (i Image) LogValue() slog.Value { + return slog.GroupValue( + slog.String("url", i.URL), + ) +} + +type Response struct { + Created int `json:"created"` + Data []Image `json:"data"` +} + +func (r Response) LogValue() slog.Value { + return slog.GroupValue( + slog.Int("created", r.Created), + slog.Any("data", r.Data), + ) +} + +func (c Client) Do(req *http.Request) (*http.Response, error) { + req.Header.Set("Authorization", "Bearer "+c.apiKey) + return http.DefaultClient.Do(req) +} + +func (c Client) GenerateImage(ctx context.Context, opts Options) (*Response, error) { + buf := bytes.NewBuffer(nil) + if err := json.NewEncoder(buf).Encode(opts); err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, "POST", "https://api.openai.com/v1/images/generations", buf) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := c.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, web.NewError(http.StatusOK, resp) + } + + var result Response + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, err + } + + return &result, nil +} |
