aboutsummaryrefslogtreecommitdiff
path: root/web
diff options
context:
space:
mode:
authorXe Iaso <me@xeiaso.net>2024-08-11 13:30:50 -0400
committerXe Iaso <me@xeiaso.net>2024-08-11 13:30:50 -0400
commitfdbc967546e8e100ff05c1b6690575693de5f5e6 (patch)
tree5845311a22ec92525c8ad7516d2cb87aae9e1c45 /web
parent0f32938f2cee6bf92bb8c29417d12883a311af26 (diff)
downloadx-fdbc967546e8e100ff05c1b6690575693de5f5e6.tar.xz
x-fdbc967546e8e100ff05c1b6690575693de5f5e6.zip
add flux client
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'web')
-rw-r--r--web/flux/example_test.go60
-rw-r--r--web/flux/flux.go192
-rw-r--r--web/flux/flux_test.go33
3 files changed, 285 insertions, 0 deletions
diff --git a/web/flux/example_test.go b/web/flux/example_test.go
new file mode 100644
index 0000000..48535e4
--- /dev/null
+++ b/web/flux/example_test.go
@@ -0,0 +1,60 @@
+package flux
+
+import (
+ "fmt"
+ "io/ioutil"
+ "time"
+)
+
+func Example() {
+ client := NewClient("http://xe-flux.flycast")
+
+ if _, err := client.HealthCheck(); err != nil {
+ fmt.Println("Error health checking:", err)
+ panic(err)
+ }
+
+ // Example of using the Predict method
+ predictionReq := PredictionRequest{
+ Input: Input{
+ Prompt: "A beautiful sunrise over the mountains",
+ NumOutputs: 1,
+ GuidanceScale: 7.5,
+ MaxSequenceLength: 256,
+ NumInferenceSteps: 50,
+ PromptStrength: 0.8,
+ OutputFormat: "png",
+ OutputQuality: 90,
+ },
+ ID: "example-prediction-id",
+ CreatedAt: time.Now().Format(time.RFC3339),
+ OutputFilePrefix: "output",
+ Webhook: "http://example.com/webhook",
+ WebhookEventsFilter: []string{"start", "completed"},
+ }
+
+ predictionResp, err := client.Predict(predictionReq)
+ if err != nil {
+ fmt.Println("Error predicting:", err)
+ } else {
+ fmt.Println("PredictionResponse:", predictionResp)
+ }
+
+ // Example of using the PredictIdempotent method
+ predictionResp, err = client.PredictIdempotent("example-prediction-id", predictionReq)
+ if err != nil {
+ fmt.Println("Error predicting idempotent:", err)
+ } else {
+ fmt.Println("PredictIdempotentResponse:", predictionResp)
+ }
+
+ // Example of using the CancelPrediction method
+ resp, err := client.CancelPrediction("example-prediction-id")
+ if err != nil {
+ fmt.Println("Error cancelling prediction:", err)
+ } else {
+ body, _ := ioutil.ReadAll(resp.Body)
+ fmt.Println("CancelPrediction response:", string(body))
+ resp.Body.Close()
+ }
+}
diff --git a/web/flux/flux.go b/web/flux/flux.go
new file mode 100644
index 0000000..e1fbed3
--- /dev/null
+++ b/web/flux/flux.go
@@ -0,0 +1,192 @@
+package flux
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "time"
+
+ "within.website/x/web"
+)
+
+// Struct definitions based on the OpenAPI schema
+
+type Input struct {
+ Prompt string `json:"prompt"`
+ Image string `json:"image,omitempty"`
+ AspectRatio string `json:"aspect_ratio,omitempty"`
+ NumOutputs int `json:"num_outputs"`
+ GuidanceScale float64 `json:"guidance_scale"`
+ MaxSequenceLength int `json:"max_sequence_length"`
+ NumInferenceSteps int `json:"num_inference_steps"`
+ PromptStrength float64 `json:"prompt_strength"`
+ Seed *int `json:"seed,omitempty"`
+ OutputFormat string `json:"output_format"`
+ OutputQuality int `json:"output_quality"`
+}
+
+type Output []string
+
+type PredictionRequest struct {
+ Input Input `json:"input"`
+ ID string `json:"id"`
+ CreatedAt string `json:"created_at"`
+ OutputFilePrefix string `json:"output_file_prefix"`
+ Webhook string `json:"webhook"`
+ WebhookEventsFilter []string `json:"webhook_events_filter"`
+}
+
+type PredictionResponse struct {
+ Input Input `json:"input"`
+ Output Output `json:"output"`
+ ID string `json:"id"`
+ Version string `json:"version"`
+ CreatedAt string `json:"created_at"`
+ StartedAt string `json:"started_at"`
+ CompletedAt string `json:"completed_at"`
+ Logs string `json:"logs"`
+ Error string `json:"error"`
+ Status string `json:"status"`
+ Metrics map[string]interface{} `json:"metrics"`
+}
+
+type HTTPValidationError struct {
+ Detail []ValidationError `json:"detail"`
+}
+
+type ValidationError struct {
+ Loc []interface{} `json:"loc"`
+ Msg string `json:"msg"`
+ Type string `json:"type"`
+}
+
+// HealthCheckResponse represents the response structure for the health check endpoint.
+type HealthCheckResponse struct {
+ Status string `json:"status"`
+}
+
+// Client struct
+
+type Client struct {
+ BaseURL string
+ HTTPClient *http.Client
+}
+
+// NewClient creates a new API client
+
+func NewClient(baseURL string) *Client {
+ return &Client{
+ BaseURL: baseURL,
+ HTTPClient: &http.Client{Timeout: 10 * time.Second},
+ }
+}
+
+// Methods to interact with the API endpoints
+
+func (c *Client) Predict(predictionReq PredictionRequest) (*PredictionResponse, error) {
+ body, err := json.Marshal(predictionReq)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequest("POST", fmt.Sprintf("%s/predictions", c.BaseURL), bytes.NewBuffer(body))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ resp, err := c.HTTPClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, web.NewError(http.StatusOK, resp)
+ }
+
+ var predictionResp PredictionResponse
+ err = json.NewDecoder(resp.Body).Decode(&predictionResp)
+ if err != nil {
+ return nil, err
+ }
+
+ return &predictionResp, nil
+}
+
+func (c *Client) PredictIdempotent(predictionID string, predictionReq PredictionRequest) (*PredictionResponse, error) {
+ body, err := json.Marshal(predictionReq)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequest("PUT", fmt.Sprintf("%s/predictions/%s", c.BaseURL, predictionID), bytes.NewBuffer(body))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ resp, err := c.HTTPClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, web.NewError(http.StatusOK, resp)
+ }
+
+ var predictionResp PredictionResponse
+ err = json.NewDecoder(resp.Body).Decode(&predictionResp)
+ if err != nil {
+ return nil, err
+ }
+
+ return &predictionResp, nil
+}
+
+func (c *Client) CancelPrediction(predictionID string) (*http.Response, error) {
+ req, err := http.NewRequest("POST", fmt.Sprintf("%s/predictions/%s/cancel", c.BaseURL, predictionID), nil)
+ if err != nil {
+ return nil, err
+ }
+
+ resp, err := c.HTTPClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, web.NewError(http.StatusOK, resp)
+ }
+
+ return resp, nil
+}
+
+// HealthCheck checks the health of the service
+func (c *Client) HealthCheck() (*HealthCheckResponse, error) {
+ req, err := http.NewRequest("GET", fmt.Sprintf("%s/health-check", c.BaseURL), nil)
+ if err != nil {
+ return nil, err
+ }
+
+ resp, err := c.HTTPClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, web.NewError(resp.StatusCode, resp)
+ }
+
+ var healthResp HealthCheckResponse
+ err = json.NewDecoder(resp.Body).Decode(&healthResp)
+ if err != nil {
+ return nil, err
+ }
+
+ return &healthResp, nil
+}
diff --git a/web/flux/flux_test.go b/web/flux/flux_test.go
new file mode 100644
index 0000000..2be7212
--- /dev/null
+++ b/web/flux/flux_test.go
@@ -0,0 +1,33 @@
+package flux
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+// Mock server for testing
+func mockServer() *httptest.Server {
+ handler := http.NewServeMux()
+ handler.HandleFunc("/health-check", func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(`{"status":"ok"}`))
+ })
+ return httptest.NewServer(handler)
+}
+
+func TestHealthCheck(t *testing.T) {
+ server := mockServer()
+ defer server.Close()
+
+ client := NewClient(server.URL)
+ healthResp, err := client.HealthCheck()
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ if healthResp.Status != "ok" {
+ t.Fatalf("expected status 'ok', got %v", healthResp.Status)
+ }
+}