diff options
| author | Xe Iaso <me@xeiaso.net> | 2024-08-11 13:30:50 -0400 |
|---|---|---|
| committer | Xe Iaso <me@xeiaso.net> | 2024-08-11 13:30:50 -0400 |
| commit | fdbc967546e8e100ff05c1b6690575693de5f5e6 (patch) | |
| tree | 5845311a22ec92525c8ad7516d2cb87aae9e1c45 /web | |
| parent | 0f32938f2cee6bf92bb8c29417d12883a311af26 (diff) | |
| download | x-fdbc967546e8e100ff05c1b6690575693de5f5e6.tar.xz x-fdbc967546e8e100ff05c1b6690575693de5f5e6.zip | |
add flux client
Signed-off-by: Xe Iaso <me@xeiaso.net>
Diffstat (limited to 'web')
| -rw-r--r-- | web/flux/example_test.go | 60 | ||||
| -rw-r--r-- | web/flux/flux.go | 192 | ||||
| -rw-r--r-- | web/flux/flux_test.go | 33 |
3 files changed, 285 insertions, 0 deletions
diff --git a/web/flux/example_test.go b/web/flux/example_test.go new file mode 100644 index 0000000..48535e4 --- /dev/null +++ b/web/flux/example_test.go @@ -0,0 +1,60 @@ +package flux + +import ( + "fmt" + "io/ioutil" + "time" +) + +func Example() { + client := NewClient("http://xe-flux.flycast") + + if _, err := client.HealthCheck(); err != nil { + fmt.Println("Error health checking:", err) + panic(err) + } + + // Example of using the Predict method + predictionReq := PredictionRequest{ + Input: Input{ + Prompt: "A beautiful sunrise over the mountains", + NumOutputs: 1, + GuidanceScale: 7.5, + MaxSequenceLength: 256, + NumInferenceSteps: 50, + PromptStrength: 0.8, + OutputFormat: "png", + OutputQuality: 90, + }, + ID: "example-prediction-id", + CreatedAt: time.Now().Format(time.RFC3339), + OutputFilePrefix: "output", + Webhook: "http://example.com/webhook", + WebhookEventsFilter: []string{"start", "completed"}, + } + + predictionResp, err := client.Predict(predictionReq) + if err != nil { + fmt.Println("Error predicting:", err) + } else { + fmt.Println("PredictionResponse:", predictionResp) + } + + // Example of using the PredictIdempotent method + predictionResp, err = client.PredictIdempotent("example-prediction-id", predictionReq) + if err != nil { + fmt.Println("Error predicting idempotent:", err) + } else { + fmt.Println("PredictIdempotentResponse:", predictionResp) + } + + // Example of using the CancelPrediction method + resp, err := client.CancelPrediction("example-prediction-id") + if err != nil { + fmt.Println("Error cancelling prediction:", err) + } else { + body, _ := ioutil.ReadAll(resp.Body) + fmt.Println("CancelPrediction response:", string(body)) + resp.Body.Close() + } +} diff --git a/web/flux/flux.go b/web/flux/flux.go new file mode 100644 index 0000000..e1fbed3 --- /dev/null +++ b/web/flux/flux.go @@ -0,0 +1,192 @@ +package flux + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "time" + + "within.website/x/web" +) + +// Struct definitions based on the OpenAPI schema + +type Input struct { + Prompt string `json:"prompt"` + Image string `json:"image,omitempty"` + AspectRatio string `json:"aspect_ratio,omitempty"` + NumOutputs int `json:"num_outputs"` + GuidanceScale float64 `json:"guidance_scale"` + MaxSequenceLength int `json:"max_sequence_length"` + NumInferenceSteps int `json:"num_inference_steps"` + PromptStrength float64 `json:"prompt_strength"` + Seed *int `json:"seed,omitempty"` + OutputFormat string `json:"output_format"` + OutputQuality int `json:"output_quality"` +} + +type Output []string + +type PredictionRequest struct { + Input Input `json:"input"` + ID string `json:"id"` + CreatedAt string `json:"created_at"` + OutputFilePrefix string `json:"output_file_prefix"` + Webhook string `json:"webhook"` + WebhookEventsFilter []string `json:"webhook_events_filter"` +} + +type PredictionResponse struct { + Input Input `json:"input"` + Output Output `json:"output"` + ID string `json:"id"` + Version string `json:"version"` + CreatedAt string `json:"created_at"` + StartedAt string `json:"started_at"` + CompletedAt string `json:"completed_at"` + Logs string `json:"logs"` + Error string `json:"error"` + Status string `json:"status"` + Metrics map[string]interface{} `json:"metrics"` +} + +type HTTPValidationError struct { + Detail []ValidationError `json:"detail"` +} + +type ValidationError struct { + Loc []interface{} `json:"loc"` + Msg string `json:"msg"` + Type string `json:"type"` +} + +// HealthCheckResponse represents the response structure for the health check endpoint. +type HealthCheckResponse struct { + Status string `json:"status"` +} + +// Client struct + +type Client struct { + BaseURL string + HTTPClient *http.Client +} + +// NewClient creates a new API client + +func NewClient(baseURL string) *Client { + return &Client{ + BaseURL: baseURL, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, + } +} + +// Methods to interact with the API endpoints + +func (c *Client) Predict(predictionReq PredictionRequest) (*PredictionResponse, error) { + body, err := json.Marshal(predictionReq) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", fmt.Sprintf("%s/predictions", c.BaseURL), bytes.NewBuffer(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, web.NewError(http.StatusOK, resp) + } + + var predictionResp PredictionResponse + err = json.NewDecoder(resp.Body).Decode(&predictionResp) + if err != nil { + return nil, err + } + + return &predictionResp, nil +} + +func (c *Client) PredictIdempotent(predictionID string, predictionReq PredictionRequest) (*PredictionResponse, error) { + body, err := json.Marshal(predictionReq) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("PUT", fmt.Sprintf("%s/predictions/%s", c.BaseURL, predictionID), bytes.NewBuffer(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, web.NewError(http.StatusOK, resp) + } + + var predictionResp PredictionResponse + err = json.NewDecoder(resp.Body).Decode(&predictionResp) + if err != nil { + return nil, err + } + + return &predictionResp, nil +} + +func (c *Client) CancelPrediction(predictionID string) (*http.Response, error) { + req, err := http.NewRequest("POST", fmt.Sprintf("%s/predictions/%s/cancel", c.BaseURL, predictionID), nil) + if err != nil { + return nil, err + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, web.NewError(http.StatusOK, resp) + } + + return resp, nil +} + +// HealthCheck checks the health of the service +func (c *Client) HealthCheck() (*HealthCheckResponse, error) { + req, err := http.NewRequest("GET", fmt.Sprintf("%s/health-check", c.BaseURL), nil) + if err != nil { + return nil, err + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, web.NewError(resp.StatusCode, resp) + } + + var healthResp HealthCheckResponse + err = json.NewDecoder(resp.Body).Decode(&healthResp) + if err != nil { + return nil, err + } + + return &healthResp, nil +} diff --git a/web/flux/flux_test.go b/web/flux/flux_test.go new file mode 100644 index 0000000..2be7212 --- /dev/null +++ b/web/flux/flux_test.go @@ -0,0 +1,33 @@ +package flux + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +// Mock server for testing +func mockServer() *httptest.Server { + handler := http.NewServeMux() + handler.HandleFunc("/health-check", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok"}`)) + }) + return httptest.NewServer(handler) +} + +func TestHealthCheck(t *testing.T) { + server := mockServer() + defer server.Close() + + client := NewClient(server.URL) + healthResp, err := client.HealthCheck() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if healthResp.Status != "ok" { + t.Fatalf("expected status 'ok', got %v", healthResp.Status) + } +} |
