diff options
| author | Xe Iaso <me@xeiaso.net> | 2023-06-22 07:46:14 -0400 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2023-06-22 07:46:14 -0400 |
| commit | 0f1f7935ae53639c0716f76a637aa0a07be0fed3 (patch) | |
| tree | 723b446ee13574bebe3c80e28ae90fcec4e58e84 /web | |
| parent | f8251f93d4a500e6a5b82bdb045cd5e0c51a7140 (diff) | |
| download | x-0f1f7935ae53639c0716f76a637aa0a07be0fed3.tar.xz x-0f1f7935ae53639c0716f76a637aa0a07be0fed3.zip | |
internal/stablediffusion: promote to public package
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'web')
| -rw-r--r-- | web/stablediffusion/doc.go | 2 | ||||
| -rw-r--r-- | web/stablediffusion/stablediffusion.go | 138 |
2 files changed, 140 insertions, 0 deletions
diff --git a/web/stablediffusion/doc.go b/web/stablediffusion/doc.go new file mode 100644 index 0000000..9bbd3a3 --- /dev/null +++ b/web/stablediffusion/doc.go @@ -0,0 +1,2 @@ +// Package stablediffusion provides a simple API client for the Automatic1111 Stable Diffusion web UI. +package stablediffusion diff --git a/web/stablediffusion/stablediffusion.go b/web/stablediffusion/stablediffusion.go new file mode 100644 index 0000000..a41f315 --- /dev/null +++ b/web/stablediffusion/stablediffusion.go @@ -0,0 +1,138 @@ +package stablediffusion + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "net/http" + "net/url" + "sync" + + "within.website/x/web" +) + +var ( + sdServerURL = flag.String("within.website/x/web/stablediffusion-server-url", "http://logos:7860", "URL for the Stable Diffusion API used with the default client") +) + +func buildURL(base, path string) (*url.URL, error) { + u, err := url.Parse(base) + if err != nil { + return nil, err + } + + u.Path = path + + return u, nil +} + +// SimpleImageRequest is all of the parameters needed to generate an image. +type SimpleImageRequest struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt"` + Styles []string `json:"styles"` + Seed int `json:"seed"` + SamplerName string `json:"sampler_name"` + BatchSize int `json:"batch_size"` + NIter int `json:"n_iter"` + Steps int `json:"steps"` + CfgScale int `json:"cfg_scale"` + Width int `json:"width"` + Height int `json:"height"` + SNoise int `json:"s_noise"` + OverrideSettings struct { + } `json:"override_settings"` + OverrideSettingsRestoreAfterwards bool `json:"override_settings_restore_afterwards"` +} + +type ImageResponse struct { + Images [][]byte `json:"images"` + Info string `json:"info"` +} + +type ImageInfo struct { + Prompt string `json:"prompt"` + AllPrompts []string `json:"all_prompts"` + NegativePrompt string `json:"negative_prompt"` + AllNegativePrompts []string `json:"all_negative_prompts"` + Seed int `json:"seed"` + AllSeeds []int `json:"all_seeds"` + Subseed int `json:"subseed"` + AllSubseeds []int `json:"all_subseeds"` + SubseedStrength int `json:"subseed_strength"` + Width int `json:"width"` + Height int `json:"height"` + SamplerName string `json:"sampler_name"` + CfgScale float64 `json:"cfg_scale"` + Steps int `json:"steps"` + BatchSize int `json:"batch_size"` + RestoreFaces bool `json:"restore_faces"` + FaceRestorationModel interface{} `json:"face_restoration_model"` + SdModelHash string `json:"sd_model_hash"` + SeedResizeFromW int `json:"seed_resize_from_w"` + SeedResizeFromH int `json:"seed_resize_from_h"` + DenoisingStrength int `json:"denoising_strength"` + ExtraGenerationParams struct { + } `json:"extra_generation_params"` + IndexOfFirstImage int `json:"index_of_first_image"` + Infotexts []string `json:"infotexts"` + Styles []interface{} `json:"styles"` + JobTimestamp string `json:"job_timestamp"` + ClipSkip int `json:"clip_skip"` + IsUsingInpaintingConditioning bool `json:"is_using_inpainting_conditioning"` +} + +var ( + Default *Client = &Client{ + HTTP: http.DefaultClient, + } + lock sync.Mutex +) + +func Generate(ctx context.Context, inp SimpleImageRequest) (*ImageResponse, error) { + lock.Lock() + Default.APIServer = *sdServerURL + lock.Unlock() + return Default.Generate(ctx, inp) +} + +type Client struct { + HTTP *http.Client + APIServer string +} + +func (c *Client) Generate(ctx context.Context, inp SimpleImageRequest) (*ImageResponse, error) { + u, err := buildURL(c.APIServer, "/sdapi/v1/txt2img") + if err != nil { + return nil, fmt.Errorf("error building URL: %w", err) + } + + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(inp); err != nil { + return nil, fmt.Errorf("error encoding json: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), &buf) + if err != nil { + return nil, fmt.Errorf("error making request: %w", err) + } + + resp, err := c.HTTP.Do(req) + if err != nil { + return nil, fmt.Errorf("error fetching response: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, web.NewError(http.StatusOK, resp) + } + + var result ImageResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("error parsing ImageResponse: %w", err) + } + + return &result, nil +} |
