diff options
| author | Xe Iaso <me@xeiaso.net> | 2023-03-04 15:53:43 -0500 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2023-03-04 15:53:43 -0500 |
| commit | ea95496b1d9c9637bad0159b04e781daab359f1d (patch) | |
| tree | 42ea127740f0584315f6325d81891166dfc497ac /cmd | |
| parent | f9323c9d33ac55d2a4fcd1f6dbee350cc3416c95 (diff) | |
| download | x-ea95496b1d9c9637bad0159b04e781daab359f1d.tar.xz x-ea95496b1d9c9637bad0159b04e781daab359f1d.zip | |
making mistakes
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'cmd')
| -rw-r--r-- | cmd/xedn/.gitignore | 1 | ||||
| -rw-r--r-- | cmd/xedn/main.go | 24 | ||||
| -rw-r--r-- | cmd/xedn/stablediffusion.go | 305 |
3 files changed, 323 insertions, 7 deletions
diff --git a/cmd/xedn/.gitignore b/cmd/xedn/.gitignore index 17b2f35..4a13387 100644 --- a/cmd/xedn/.gitignore +++ b/cmd/xedn/.gitignore @@ -1 +1,2 @@ slug.tar.gz +xedn diff --git a/cmd/xedn/main.go b/cmd/xedn/main.go index 7591610..eff055a 100644 --- a/cmd/xedn/main.go +++ b/cmd/xedn/main.go @@ -32,6 +32,7 @@ import ( "within.website/ln/ex" "within.website/ln/opname" "within.website/x/internal" + "within.website/x/internal/stablediffusion" "within.website/x/web" ) @@ -376,14 +377,22 @@ func main() { group: &singleflight.Group{}, } - go func() { - srv := &tsnet.Server{ - Hostname: "xedn-" + os.Getenv("FLY_REGION"), - Logf: log.New(io.Discard, "", 0).Printf, - AuthKey: os.Getenv("TS_AUTHKEY"), - Dir: filepath.Join(*dir, "tsnet"), - } + srv := &tsnet.Server{ + Hostname: "xedn-" + os.Getenv("FLY_REGION"), + Logf: log.New(io.Discard, "", 0).Printf, + AuthKey: os.Getenv("TS_AUTHKEY"), + Dir: filepath.Join(*dir, "tsnet"), + } + + cli := srv.HTTPClient() + sd := &StableDiffusion{ + db: db, + client: &stablediffusion.Client{HTTP: cli}, + group: &singleflight.Group{}, + } + + go func() { lis, err := srv.Listen("tcp", ":80") if err != nil { ln.FatalErr(ctx, err, ln.Action("tsnet listening")) @@ -423,6 +432,7 @@ func main() { }) mux.Handle("/sticker/", ois) + mux.Handle("/avatar/", sd) hdlr := func(w http.ResponseWriter, r *http.Request) { etagLock.RLock() diff --git a/cmd/xedn/stablediffusion.go b/cmd/xedn/stablediffusion.go index 06ab7d0..6bc18c0 100644 --- a/cmd/xedn/stablediffusion.go +++ b/cmd/xedn/stablediffusion.go @@ -1 +1,306 @@ package main + +import ( + "bytes" + "context" + "fmt" + "image" + "image/jpeg" + "math/rand" + "net/http" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "go.etcd.io/bbolt" + "golang.org/x/sync/singleflight" + "within.website/ln" + "within.website/x/internal/stablediffusion" +) + +type StableDiffusion struct { + client *stablediffusion.Client + db *bbolt.DB + group *singleflight.Group +} + +// RenderImage renders a stable diffusion image based on the hash given. +// +// It assumes that the image does not exist, if it does, you will need +// to check elsewhere. +func (sd *StableDiffusion) RenderImage(ctx context.Context, w http.ResponseWriter, hash string) error { + prompt, seed := hallucinatePrompt(hash) + + ln.Log(ctx, ln.Info("generating new image"), ln.F{"prompt": prompt}) + + imgsVal, err, _ := sd.group.Do(hash, func() (interface{}, error) { + imgs, err := sd.client.Generate(ctx, stablediffusion.SimpleImageRequest{ + Prompt: "masterpiece, best quality, " + prompt, + NegativePrompt: "person in distance, worst quality, low quality, medium quality, deleted, lowres, comic, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry", + Seed: seed, + SamplerName: "DPM++ 2M Karras", + BatchSize: 1, + NIter: 1, + Steps: 20, + CfgScale: 7, + Width: 256, + Height: 256, + SNoise: 1, + + OverrideSettingsRestoreAfterwards: true, + }) + if err != nil { + return nil, err + } + + img, _, err := image.Decode(bytes.NewBuffer(imgs.Images[0])) + if err != nil { + return nil, err + } + + buf := &bytes.Buffer{} + + if err := jpeg.Encode(buf, img, &jpeg.Options{Quality: 75}); err != nil { + return nil, err + } + + imgs.Images[0] = buf.Bytes() + + return imgs, nil + }) + if err != nil { + return err + } + imgs := imgsVal.(*stablediffusion.ImageResponse) + + ln.Log(ctx, ln.Info("done generating image"), ln.F{"prompt": prompt}) + + if err := sd.db.Update(func(tx *bbolt.Tx) error { + bkt := tx.Bucket([]byte("avatars")) + + if err := bkt.Put([]byte(hash), []byte(imgs.Images[0])); err != nil { + return err + } + + return nil + }); err != nil { + return err + } + + w.Header().Set("content-type", "image/jpeg") + w.Header().Set("content-length", fmt.Sprint(len(imgs.Images[0]))) + w.Header().Set("expires", time.Now().Add(30*24*time.Hour).Format(http.TimeFormat)) + w.Header().Set("Cache-Control", "max-age:2630000") // one month + w.WriteHeader(http.StatusOK) + w.Write(imgs.Images[0]) + + return nil +} + +var isHexRegex = regexp.MustCompile(`[-fA-F0-9]+$`) + +func (sd *StableDiffusion) ServeHTTP(w http.ResponseWriter, r *http.Request) { + hash := filepath.Base(r.URL.Path) + + if !isHexRegex.MatchString(hash) { + http.Error(w, "the input must be a hexadecimal string", http.StatusBadRequest) + return + } + + if len(hash) != 32 { + http.Error(w, "this must be 32 characters", http.StatusBadRequest) + return + } + + if err := sd.db.Update(func(tx *bbolt.Tx) error { + if _, err := tx.CreateBucketIfNotExists([]byte("avatars")); err != nil { + return err + } + return nil + }); err != nil { + http.Error(w, "can't access database", http.StatusInternalServerError) + return + } + + found := false + + sd.db.View(func(tx *bbolt.Tx) error { + bkt := tx.Bucket([]byte("avatars")) + data := bkt.Get([]byte(hash)) + found = data != nil + + if found { + w.Header().Set("content-type", "image/png") + w.Header().Set("content-length", fmt.Sprint(len(data))) + w.Header().Set("expires", time.Now().Add(30*24*time.Hour).Format(http.TimeFormat)) + w.Header().Set("Cache-Control", "max-age:2630000") // one month + w.WriteHeader(http.StatusOK) + w.Write(data) + } + + return nil + }) + + if !found { + if err := sd.RenderImage(r.Context(), w, hash); err != nil { + http.Error(w, "cannot render image, sorry", http.StatusInternalServerError) + ln.Error(r.Context(), err) + return + } + } +} + +func hallucinatePrompt(hash string) (string, int) { + var sb strings.Builder + if hash[0] > '0' && hash[0] <= '5' { + fmt.Fprint(&sb, "1girl, ") + } else { + fmt.Fprint(&sb, "1guy, ") + } + + switch hash[1] { + case '0', '1', '2', '3': + fmt.Fprint(&sb, "blonde, ") + case '4', '5', '6', '7': + fmt.Fprint(&sb, "brown hair, ") + case '8', '9', 'a', 'b': + fmt.Fprint(&sb, "red hair, ") + case 'c', 'd', 'e', 'f': + fmt.Fprint(&sb, "black hair, ") + default: + } + + if hash[2] > '0' && hash[2] <= '5' { + fmt.Fprint(&sb, "coffee shop, ") + } else { + fmt.Fprint(&sb, "landscape, outdoors, ") + } + + if hash[3] > '0' && hash[3] <= '5' { + fmt.Fprint(&sb, "hoodie, ") + } else { + fmt.Fprint(&sb, "sweatsuit, ") + } + + switch hash[4] { + case '0', '1', '2', '3': + fmt.Fprint(&sb, "<lora:cdi:1>, ") + case '4', '5', '6', '7': + fmt.Fprint(&sb, "breath of the wild, ") + case '8', '9', 'a', 'b': + fmt.Fprint(&sb, "genshin impact, ") + case 'c', 'd', 'e', 'f': + fmt.Fprint(&sb, "arknights, ") + default: + } + + if hash[5] > '0' && hash[5] <= '5' { + fmt.Fprint(&sb, "watercolor, ") + } else { + fmt.Fprint(&sb, "matte painting, ") + } + + switch hash[6] { + case '0', '1', '2', '3': + fmt.Fprint(&sb, "highly detailed, ") + case '4', '5', '6', '7': + fmt.Fprint(&sb, "ornate, ") + case '8', '9', 'a', 'b': + fmt.Fprint(&sb, "thick lines, ") + case 'c', 'd', 'e', 'f': + fmt.Fprint(&sb, "3d render, ") + default: + } + + switch hash[7] { + case '0', '1', '2', '3': + fmt.Fprint(&sb, "short hair, ") + case '4', '5', '6', '7': + fmt.Fprint(&sb, "long hair, ") + case '8', '9', 'a', 'b': + fmt.Fprint(&sb, "ponytail, ") + case 'c', 'd', 'e', 'f': + fmt.Fprint(&sb, "pigtails, ") + default: + } + + switch hash[8] { + case '0', '1', '2', '3': + fmt.Fprint(&sb, "smile, ") + case '4', '5', '6', '7': + fmt.Fprint(&sb, "frown, ") + case '8', '9', 'a', 'b': + fmt.Fprint(&sb, "laughing, ") + case 'c', 'd', 'e', 'f': + fmt.Fprint(&sb, "angry, ") + default: + } + + switch hash[9] { + case '0', '1', '2', '3': + fmt.Fprint(&sb, "sweater, ") + case '4', '5', '6', '7': + fmt.Fprint(&sb, "tshirt, ") + case '8', '9', 'a', 'b': + fmt.Fprint(&sb, "suitjacket, ") + case 'c', 'd', 'e', 'f': + fmt.Fprint(&sb, "armor, ") + default: + } + + switch hash[10] { + case '0', '1', '2', '3': + fmt.Fprint(&sb, "blue eyes, ") + case '4', '5', '6', '7': + fmt.Fprint(&sb, "red eyes, ") + case '8', '9', 'a', 'b': + fmt.Fprint(&sb, "brown eyes, ") + case 'c', 'd', 'e', 'f': + fmt.Fprint(&sb, "hazel eyes, ") + default: + } + + if hash[11] == '0' { + fmt.Fprint(&sb, "heterochromia, ") + } + + switch hash[12] { + case '0', '1', '2', '3': + fmt.Fprint(&sb, "morning, ") + case '4', '5', '6', '7': + fmt.Fprint(&sb, "afternoon, ") + case '8', '9', 'a', 'b': + fmt.Fprint(&sb, "evening, ") + case 'c', 'd', 'e', 'f': + fmt.Fprint(&sb, "nighttime, ") + default: + } + + if hash[13] == '0' { + fmt.Fprint(&sb, "<lora:genshin:1>, genshin, ") + } + + switch hash[14] { + case '0', '1', '2', '3': + fmt.Fprint(&sb, "vtuber, ") + case '4', '5', '6', '7': + fmt.Fprint(&sb, "anime, ") + case '8', '9', 'a', 'b': + fmt.Fprint(&sb, "studio ghibli, ") + case 'c', 'd', 'e', 'f': + fmt.Fprint(&sb, "cloverworks, ") + default: + } + + seedPortion := hash[len(hash)-9 : len(hash)-1] + seed, err := strconv.ParseInt(seedPortion, 16, 32) + if err != nil { + seed = int64(rand.Int()) + } + + fmt.Fprint(&sb, "pants") + + return sb.String(), int(seed) +} |
