aboutsummaryrefslogtreecommitdiff
path: root/web
diff options
context:
space:
mode:
authorXe Iaso <me@xeiaso.net>2023-11-16 20:28:57 -0500
committerXe Iaso <me@xeiaso.net>2024-01-30 18:26:31 -0500
commitbfc08e1e932552d1014fd1045cf074b547244854 (patch)
treebd7e3fa3ceddce5b0136c36d6283811326e8fb5d /web
parent846f8c3ab7884a0d4585e8ee8580e95695d6b5f4 (diff)
downloadx-bfc08e1e932552d1014fd1045cf074b547244854.tar.xz
x-bfc08e1e932552d1014fd1045cf074b547244854.zip
web/openai: add dall-e API bindings
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'web')
-rw-r--r--web/openai/dalle/dalle.go113
1 files changed, 113 insertions, 0 deletions
diff --git a/web/openai/dalle/dalle.go b/web/openai/dalle/dalle.go
new file mode 100644
index 0000000..dc57b7a
--- /dev/null
+++ b/web/openai/dalle/dalle.go
@@ -0,0 +1,113 @@
+package dalle
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "log/slog"
+ "net/http"
+
+ "within.website/x/web"
+)
+
+type Client struct {
+ apiKey string
+}
+
+func New(apiKey string) Client {
+ return Client{apiKey: apiKey}
+}
+
+type Model string
+
+const (
+ DALLE2 = Model("dall-e-2")
+ DALLE3 = Model("dall-e-3")
+)
+
+type Size string
+
+const (
+ Size256 = Size("256x256")
+ Size512 = Size("512x512")
+ Size1024 = Size("1024x1024")
+
+ // These are only supported with the dall-e-3 model.
+ SizeHDWide = Size("1792x1024")
+ SizeHDTall = Size("1024x1792")
+)
+
+type Style string
+
+const (
+ StyleVivid = Style("vivid")
+ StyleNatural = Style("natural")
+)
+
+type Options struct {
+ Model Model `json:"model"`
+ Prompt string `json:"prompt"`
+ N *int `json:"n,omitempty"`
+ Quality *string `json:"quality,omitempty"`
+ Size *Size `json:"size"`
+ Style *Style `json:"style,omitempty"`
+ User *string `json:"user,omitempty"`
+}
+
+type Image struct {
+ URL string `json:"url"`
+}
+
+func (i Image) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("url", i.URL),
+ )
+}
+
+type Response struct {
+ Created int `json:"created"`
+ Data []Image `json:"data"`
+}
+
+func (r Response) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.Int("created", r.Created),
+ slog.Any("data", r.Data),
+ )
+}
+
+func (c Client) Do(req *http.Request) (*http.Response, error) {
+ req.Header.Set("Authorization", "Bearer "+c.apiKey)
+ return http.DefaultClient.Do(req)
+}
+
+func (c Client) GenerateImage(ctx context.Context, opts Options) (*Response, error) {
+ buf := bytes.NewBuffer(nil)
+ if err := json.NewEncoder(buf).Encode(opts); err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, "POST", "https://api.openai.com/v1/images/generations", buf)
+ if err != nil {
+ return nil, err
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+
+ resp, err := c.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, web.NewError(http.StatusOK, resp)
+ }
+
+ var result Response
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return nil, err
+ }
+
+ return &result, nil
+}