diff --git a/apps/api/go.mod b/apps/api/go.mod index fe401f1..79cc835 100644 --- a/apps/api/go.mod +++ b/apps/api/go.mod @@ -13,6 +13,11 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dlclark/regexp2 v1.11.4 // indirect + github.com/dop251/goja v0.0.0-20260311135729-065cd970411c // indirect + github.com/dop251/goja_nodejs v0.0.0-20260212111938-1f56ff5bcf14 // indirect + github.com/go-sourcemap/sourcemap v2.1.4+incompatible // indirect + github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect diff --git a/apps/api/go.sum b/apps/api/go.sum index 26c9282..5ae9733 100644 --- a/apps/api/go.sum +++ b/apps/api/go.sum @@ -1,8 +1,18 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo= +github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dop251/goja v0.0.0-20260311135729-065cd970411c h1:OcLmPfx1T1RmZVHHFwWMPaZDdRf0DBMZOFMVWJa7Pdk= +github.com/dop251/goja v0.0.0-20260311135729-065cd970411c/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4= +github.com/dop251/goja_nodejs v0.0.0-20260212111938-1f56ff5bcf14 h1:3U8dTgyNBhEQ/GVw0jZW5q+93Zw2gAZPRWhJ9TwV3rM= +github.com/dop251/goja_nodejs v0.0.0-20260212111938-1f56ff5bcf14/go.mod h1:Tb7Xxye4LX7cT3i8YLvmPMGCV92IOi4CDZvm/V8ylc0= +github.com/go-sourcemap/sourcemap v2.1.4+incompatible h1:a+iTbH5auLKxaNwQFg0B+TCYl6lbukKPc7b5x0n1s6Q= +github.com/go-sourcemap/sourcemap v2.1.4+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k= +github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0= github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= diff --git a/apps/api/internal/clients/media_clients.go b/apps/api/internal/clients/media_clients.go new file mode 100644 index 0000000..f93a0da --- /dev/null +++ b/apps/api/internal/clients/media_clients.go @@ -0,0 +1,232 @@ +package clients + +import ( + "context" + "net/http" + "strings" +) + +type JimengClient struct{ HTTPClient *http.Client } +type BlackforestClient struct{ HTTPClient *http.Client } +type HunyuanImageClient struct{ HTTPClient *http.Client } +type HunyuanVideoClient struct{ HTTPClient *http.Client } +type MinimaxClient struct{ HTTPClient *http.Client } +type MidjourneyClient struct{ HTTPClient *http.Client } +type ViduClient struct{ HTTPClient *http.Client } +type AliyunBailianClient struct{ HTTPClient *http.Client } +type NewAPIClient struct{ HTTPClient *http.Client } + +func (c JimengClient) Run(ctx context.Context, request Request) (Response, error) { + return providerTaskClient{HTTPClient: c.HTTPClient, Spec: jimengSpec()}.Run(ctx, request) +} + +func (c BlackforestClient) Run(ctx context.Context, request Request) (Response, error) { + return providerTaskClient{HTTPClient: c.HTTPClient, Spec: blackforestSpec()}.Run(ctx, request) +} + +func (c HunyuanImageClient) Run(ctx context.Context, request Request) (Response, error) { + return providerTaskClient{HTTPClient: c.HTTPClient, Spec: hunyuanImageSpec()}.Run(ctx, request) +} + +func (c HunyuanVideoClient) Run(ctx context.Context, request Request) (Response, error) { + return providerTaskClient{HTTPClient: c.HTTPClient, Spec: hunyuanVideoSpec()}.Run(ctx, request) +} + +func (c MinimaxClient) Run(ctx context.Context, request Request) (Response, error) { + return providerTaskClient{HTTPClient: c.HTTPClient, Spec: minimaxSpec()}.Run(ctx, request) +} + +func (c MidjourneyClient) Run(ctx context.Context, request Request) (Response, error) { + return providerTaskClient{HTTPClient: c.HTTPClient, Spec: midjourneySpec()}.Run(ctx, request) +} + +func (c ViduClient) Run(ctx context.Context, request Request) (Response, error) { + return providerTaskClient{HTTPClient: c.HTTPClient, Spec: viduSpec()}.Run(ctx, request) +} + +func (c AliyunBailianClient) Run(ctx context.Context, request Request) (Response, error) { + return providerTaskClient{HTTPClient: c.HTTPClient, Spec: aliyunBailianSpec()}.Run(ctx, request) +} + +func (c NewAPIClient) Run(ctx context.Context, request Request) (Response, error) { + return providerTaskClient{HTTPClient: c.HTTPClient, Spec: newAPISpec()}.Run(ctx, request) +} + +func jimengSpec() providerTaskSpec { + return providerTaskSpec{ + Name: "jimeng", + SubmitPath: func(request Request, _ map[string]any) string { + return configuredPath(request, "?Action=CVSubmitTask&Version=2022-08-31", "submitPath", "submit_path") + }, + PollPath: func(request Request, _ string, _ map[string]any) string { + return configuredPath(request, "?Action=CVSync2AsyncGetResult&Version=2022-08-31", "pollPath", "poll_path") + }, + Auth: "bearer", + TaskIDPaths: []string{"data.task_id"}, + StatusPaths: []string{"data.status"}, + SuccessStatuses: []string{"done"}, + DefaultSubmitBody: func(request Request, body map[string]any) map[string]any { + body["req_key"] = upstreamModelName(request.Candidate) + if body["prompt"] == nil { + body["prompt"] = mediaPromptText(body) + } + return body + }, + } +} + +func blackforestSpec() providerTaskSpec { + return providerTaskSpec{ + Name: "blackforest", + SubmitPath: func(request Request, body map[string]any) string { + return configuredPath(request, "/"+upstreamModelName(request.Candidate), "submitPath", "submit_path") + }, + PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string { return upstreamTaskID }, + Auth: "x-key", + TaskIDPaths: []string{"polling_url"}, + StatusPaths: []string{"status"}, + SuccessStatuses: []string{"ready"}, + FailureStatuses: []string{"error", "task not found"}, + } +} + +func hunyuanImageSpec() providerTaskSpec { + return providerTaskSpec{ + Name: "tencent-hunyuan-image", + SubmitPath: func(request Request, _ map[string]any) string { + return configuredPath(request, "?Action=SubmitHunyuanImageJob&Version=2023-09-01", "submitPath", "submit_path") + }, + PollPath: func(request Request, _ string, _ map[string]any) string { + return configuredPath(request, "?Action=QueryHunyuanImageJob&Version=2023-09-01&JobId=${taskId}", "pollPath", "poll_path") + }, + Auth: "bearer", + TaskIDPaths: []string{"Response.JobId"}, + StatusPaths: []string{"Response.Status"}, + SuccessStatuses: []string{"done"}, + FailureStatuses: []string{"fail"}, + DefaultSubmitBody: func(request Request, body map[string]any) map[string]any { + body["Prompt"] = mediaPromptText(body) + body["Model"] = upstreamModelName(request.Candidate) + return body + }, + } +} + +func hunyuanVideoSpec() providerTaskSpec { + return providerTaskSpec{ + Name: "tencent-hunyuan-video", + SubmitPath: func(request Request, _ map[string]any) string { + return configuredPath(request, "?Action=SubmitTextToVideoJob&Version=2024-01-01", "submitPath", "submit_path") + }, + PollPath: func(request Request, _ string, _ map[string]any) string { + return configuredPath(request, "?Action=QueryVideoJob&Version=2024-01-01&JobId=${taskId}", "pollPath", "poll_path") + }, + Auth: "bearer", + TaskIDPaths: []string{"Response.JobId"}, + StatusPaths: []string{"Response.Status"}, + SuccessStatuses: []string{"done"}, + FailureStatuses: []string{"fail"}, + DefaultSubmitBody: func(request Request, body map[string]any) map[string]any { + body["Prompt"] = mediaPromptText(body) + body["Model"] = upstreamModelName(request.Candidate) + return body + }, + } +} + +func minimaxSpec() providerTaskSpec { + return providerTaskSpec{ + Name: "minimax", + SubmitPath: func(Request, map[string]any) string { return "/video_generation" }, + PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string { + return "/query/video_generation?task_id=" + upstreamTaskID + }, + Auth: "bearer", + TaskIDPaths: []string{"task_id"}, + StatusPaths: []string{"status"}, + SuccessStatuses: []string{"success"}, + FailureStatuses: []string{"failed", "expired"}, + } +} + +func midjourneySpec() providerTaskSpec { + return providerTaskSpec{ + Name: "midjourney", + SubmitPath: func(request Request, body map[string]any) string { + return configuredPath(request, "/diffusion", "submitPath", "submit_path") + }, + PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string { return "/job/" + upstreamTaskID }, + Auth: "bearer", + TaskIDPaths: []string{"job_id", "id"}, + StatusPaths: []string{"status"}, + SuccessStatuses: []string{"success", "completed"}, + FailureStatuses: []string{"failed"}, + DefaultSubmitBody: func(request Request, body map[string]any) map[string]any { + if body["prompt"] == nil && body["text"] == nil { + body["prompt"] = mediaPromptText(body) + } + return body + }, + } +} + +func viduSpec() providerTaskSpec { + return providerTaskSpec{ + Name: "vidu", + SubmitPath: func(request Request, body map[string]any) string { + if path := configuredPath(request, "", "submitPath", "submit_path"); path != "" { + return path + } + taskType := firstNonEmptyString(body["type"], body["task_type"], "text2video") + if taskType == "multiframe" { + return "/multiframe" + } + return "/" + taskType + }, + PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string { + return "/tasks/" + upstreamTaskID + "/creations" + }, + Auth: "token", + TaskIDPaths: []string{"task_id"}, + StatusPaths: []string{"state", "status"}, + SuccessStatuses: []string{"success", "succeeded"}, + FailureStatuses: []string{"failed"}, + } +} + +func aliyunBailianSpec() providerTaskSpec { + return providerTaskSpec{ + Name: "aliyun-bailian", + SubmitPath: func(Request, map[string]any) string { return "/services/aigc/video-generation/video-synthesis" }, + PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string { return "/tasks/" + upstreamTaskID }, + Auth: "bearer", + TaskIDPaths: []string{"output.task_id"}, + StatusPaths: []string{"output.task_status"}, + SuccessStatuses: []string{"succeeded", "success"}, + FailureStatuses: []string{"failed"}, + } +} + +func newAPISpec() providerTaskSpec { + return providerTaskSpec{ + Name: "newapi", + SubmitPath: func(Request, map[string]any) string { return "/videos/generations" }, + PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string { + return "/videos/generations/" + upstreamTaskID + }, + Auth: "bearer", + TaskIDPaths: []string{"task_id"}, + StatusPaths: []string{"status"}, + SuccessStatuses: []string{"success"}, + FailureStatuses: []string{"failure", "failed"}, + } +} + +func configuredPath(request Request, fallback string, keys ...string) string { + for _, key := range keys { + if value := strings.TrimSpace(stringFromAny(request.Candidate.PlatformConfig[key])); value != "" { + return value + } + } + return fallback +} diff --git a/apps/api/internal/clients/provider_task.go b/apps/api/internal/clients/provider_task.go new file mode 100644 index 0000000..d24d717 --- /dev/null +++ b/apps/api/internal/clients/provider_task.go @@ -0,0 +1,453 @@ +package clients + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" +) + +type providerTaskSpec struct { + Name string + SubmitPath func(Request, map[string]any) string + PollPath func(Request, string, map[string]any) string + Auth string + TaskIDPaths []string + StatusPaths []string + SuccessStatuses []string + FailureStatuses []string + ProcessStatuses []string + DefaultSubmitBody func(Request, map[string]any) map[string]any +} + +type providerTaskClient struct { + HTTPClient *http.Client + Spec providerTaskSpec +} + +func (c providerTaskClient) Run(ctx context.Context, request Request) (Response, error) { + if request.Kind != "images.generations" && request.Kind != "images.edits" && request.Kind != "videos.generations" { + return Response{}, &ClientError{Code: "unsupported_kind", Message: "unsupported " + c.Spec.Name + " request kind", Retryable: false} + } + startedAt := time.Now() + payload := cloneBody(request.Body) + if c.Spec.DefaultSubmitBody != nil { + payload = c.Spec.DefaultSubmitBody(request, payload) + } else { + payload["model"] = upstreamModelName(request.Candidate) + } + + upstreamTaskID := strings.TrimSpace(request.RemoteTaskID) + requestID := upstreamTaskID + var submitResult map[string]any + if upstreamTaskID == "" { + result, id, err := c.submit(ctx, request, payload) + if err != nil { + return Response{}, annotateResponseError(err, id, startedAt, time.Now()) + } + submitResult = result + requestID = firstNonEmptyString(id, requestIDFromResult(result)) + if isProviderTaskFailure(c.Spec, result) { + return Response{}, providerTaskFailure(c.Spec, result, requestID, startedAt) + } + if isProviderTaskSuccess(c.Spec, result) && hasProviderTaskResult(result) { + return Response{ + Result: normalizeProviderTaskResult(request, c.Spec, result, ""), + RequestID: requestID, + Progress: providerProgress(request), + ResponseStartedAt: startedAt, + ResponseFinishedAt: time.Now(), + ResponseDurationMS: responseDurationMS(startedAt, time.Now()), + }, nil + } + upstreamTaskID = providerTaskID(c.Spec, result) + if upstreamTaskID == "" { + return Response{}, &ClientError{Code: "invalid_response", Message: c.Spec.Name + " task id is missing", RequestID: requestID, Retryable: false} + } + if request.OnRemoteTaskSubmitted != nil { + if err := request.OnRemoteTaskSubmitted(upstreamTaskID, map[string]any{"payload": payload, "submit": submitResult}); err != nil { + return Response{}, err + } + } + } else if request.RemoteTaskPayload != nil { + if existingPayload, ok := request.RemoteTaskPayload["payload"].(map[string]any); ok { + payload = existingPayload + } + } + + interval := providerPollInterval(request) + timeout := providerPollTimeout(request) + deadline := time.NewTimer(timeout) + defer deadline.Stop() + ticker := time.NewTicker(interval) + defer ticker.Stop() + + var lastResult map[string]any + for { + pollStarted := time.Now() + result, pollRequestID, err := c.poll(ctx, request, upstreamTaskID, payload) + pollFinished := time.Now() + if err != nil { + return Response{}, annotateResponseError(err, firstNonEmptyString(pollRequestID, requestID, upstreamTaskID), pollStarted, pollFinished) + } + lastResult = result + requestID = firstNonEmptyString(pollRequestID, requestID, requestIDFromResult(result), upstreamTaskID) + if isProviderTaskSuccess(c.Spec, result) { + finishedAt := time.Now() + return Response{ + Result: normalizeProviderTaskResult(request, c.Spec, result, upstreamTaskID), + RequestID: requestID, + Progress: append(providerProgress(request), Progress{Phase: "polling", Progress: 0.65, Message: "provider task polled", Payload: map[string]any{"upstreamTaskId": upstreamTaskID}}), + ResponseStartedAt: startedAt, + ResponseFinishedAt: finishedAt, + ResponseDurationMS: responseDurationMS(startedAt, finishedAt), + }, nil + } + if isProviderTaskFailure(c.Spec, result) { + return Response{}, providerTaskFailure(c.Spec, result, requestID, startedAt) + } + select { + case <-ctx.Done(): + return Response{}, &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: requestID, Retryable: true} + case <-deadline.C: + return Response{}, &ClientError{Code: "timeout", Message: fmt.Sprintf("%s task %s did not finish before timeout; last status: %s", c.Spec.Name, upstreamTaskID, providerTaskStatus(c.Spec, lastResult)), RequestID: requestID, Retryable: true} + case <-ticker.C: + } + } +} + +func (c providerTaskClient) submit(ctx context.Context, request Request, payload map[string]any) (map[string]any, string, error) { + path := c.Spec.SubmitPath(request, payload) + return providerPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), providerURL(request.Candidate.BaseURL, path), payload, request.Candidate.Credentials, c.Spec.Auth) +} + +func (c providerTaskClient) poll(ctx context.Context, request Request, upstreamTaskID string, payload map[string]any) (map[string]any, string, error) { + path := resolveProviderPathTemplate(c.Spec.PollPath(request, upstreamTaskID, payload), upstreamTaskID) + url := path + if !strings.HasPrefix(path, "http://") && !strings.HasPrefix(path, "https://") { + url = providerURL(request.Candidate.BaseURL, path) + } + if c.Spec.Name == "jimeng" { + body := map[string]any{"task_id": upstreamTaskID, "req_key": upstreamModelName(request.Candidate)} + return providerPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), url, body, request.Candidate.Credentials, c.Spec.Auth) + } + return providerGetJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), url, request.Candidate.Credentials, c.Spec.Auth) +} + +func providerPostJSON(ctx context.Context, client *http.Client, url string, body map[string]any, credentials map[string]any, auth string) (map[string]any, string, error) { + raw, _ := json.Marshal(body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(raw)) + if err != nil { + return nil, "", err + } + req.Header.Set("Content-Type", "application/json") + applyProviderAuth(req, credentials, auth) + resp, err := client.Do(req) + if err != nil { + return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true} + } + requestID := requestIDFromHTTPResponse(resp) + result, err := decodeHTTPResponse(resp) + return result, requestID, err +} + +func providerGetJSON(ctx context.Context, client *http.Client, url string, credentials map[string]any, auth string) (map[string]any, string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, "", err + } + applyProviderAuth(req, credentials, auth) + resp, err := client.Do(req) + if err != nil { + return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true} + } + requestID := requestIDFromHTTPResponse(resp) + result, err := decodeHTTPResponse(resp) + return result, requestID, err +} + +func applyProviderAuth(req *http.Request, credentials map[string]any, auth string) { + apiKey := credential(credentials, "apiKey", "api_key", "key", "token") + switch auth { + case "token": + if apiKey != "" { + req.Header.Set("Authorization", "Token "+apiKey) + } + case "x-key": + if apiKey != "" { + req.Header.Set("x-key", apiKey) + } + case "none": + default: + if apiKey != "" { + req.Header.Set("Authorization", "Bearer "+apiKey) + } + } +} + +func providerURL(base string, path string) string { + path = strings.TrimSpace(path) + if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { + return path + } + if path == "" { + path = "/" + } + if !strings.HasPrefix(path, "/") && !strings.HasPrefix(path, "?") { + path = "/" + path + } + return joinURL(base, path) +} + +func resolveProviderPathTemplate(path string, upstreamTaskID string) string { + replacements := [][2]string{ + {"${upstream_task_id}", upstreamTaskID}, + {"{{upstream_task_id}}", upstreamTaskID}, + {"{upstream_task_id}", upstreamTaskID}, + {"${task_id}", upstreamTaskID}, + {"{{task_id}}", upstreamTaskID}, + {"{task_id}", upstreamTaskID}, + {"${taskId}", upstreamTaskID}, + {"${taskID}", upstreamTaskID}, + {"{{taskId}}", upstreamTaskID}, + {"{{taskID}}", upstreamTaskID}, + {"{taskId}", upstreamTaskID}, + {"{taskID}", upstreamTaskID}, + } + for _, replacement := range replacements { + path = strings.ReplaceAll(path, replacement[0], replacement[1]) + } + return path +} + +func providerTaskID(spec providerTaskSpec, result map[string]any) string { + paths := append([]string{}, spec.TaskIDPaths...) + paths = append(paths, "task_id", "taskId", "id", "job_id", "Response.JobId", "output.task_id", "data.task_id", "polling_url") + for _, path := range paths { + if value := stringFromPathValue(valueAtPath(result, path)); value != "" { + return value + } + } + return "" +} + +func providerTaskStatus(spec providerTaskSpec, result map[string]any) string { + if result == nil { + return "" + } + if value, ok := valueAtPath(result, "status").(float64); ok { + if value == 2 { + return "success" + } + if value == 3 { + return "failed" + } + return "process" + } + paths := append([]string{}, spec.StatusPaths...) + paths = append(paths, "status", "state", "task_status", "output.task_status", "Response.Status", "data.status") + for _, path := range paths { + if value := stringFromPathValue(valueAtPath(result, path)); value != "" { + return strings.ToLower(value) + } + } + return "" +} + +func stringFromPathValue(value any) string { + if value == nil { + return "" + } + text := strings.TrimSpace(fmt.Sprint(value)) + if text == "" || text == "" { + return "" + } + return text +} + +func isProviderTaskSuccess(spec providerTaskSpec, result map[string]any) bool { + return containsStatus(append([]string{"success", "succeeded", "completed", "complete", "done", "ready", "succeed", "succeeded", "suceeded", "done", "done"}, spec.SuccessStatuses...), providerTaskStatus(spec, result)) +} + +func isProviderTaskFailure(spec providerTaskSpec, result map[string]any) bool { + return containsStatus(append([]string{"failed", "failure", "error", "cancelled", "canceled", "fail", "expired", "task not found"}, spec.FailureStatuses...), providerTaskStatus(spec, result)) +} + +func containsStatus(values []string, status string) bool { + status = strings.ToLower(strings.TrimSpace(status)) + for _, value := range values { + if strings.ToLower(strings.TrimSpace(value)) == status { + return true + } + } + return false +} + +func hasProviderTaskResult(result map[string]any) bool { + return result["data"] != nil || valueAtPath(result, "output.image_urls") != nil || valueAtPath(result, "output.video_url") != nil || valueAtPath(result, "Response.ResultVideoUrl") != nil || valueAtPath(result, "Response.ResultImages") != nil || result["urls"] != nil +} + +func normalizeProviderTaskResult(request Request, spec providerTaskSpec, result map[string]any, upstreamTaskID string) map[string]any { + out := cloneMapAny(result) + out["status"] = "success" + if upstreamTaskID != "" { + out["upstream_task_id"] = upstreamTaskID + } + if out["created"] == nil { + out["created"] = time.Now().UnixMilli() + } + if out["model"] == nil { + out["model"] = request.Model + } + if _, ok := out["data"].([]any); !ok { + if out["data"] != nil { + out["raw_data"] = out["data"] + } + out["data"] = providerTaskData(request, result) + } + return out +} + +func providerTaskData(request Request, result map[string]any) []any { + fileType := "image" + if request.Kind == "videos.generations" || strings.Contains(request.ModelType, "video") { + fileType = "video" + } + urlValues := []any{} + for _, path := range []string{ + "urls", + "image_urls", + "data.image_urls", + "data.images", + "output.image_urls", + "output.video_url", + "output.output", + "data.output", + "data.video_url", + "video_url", + "preview_url", + "Response.ResultImages", + "Response.ResultVideoUrl", + } { + appendURLValues(&urlValues, valueAtPath(result, path)) + } + data := make([]any, 0, len(urlValues)) + for _, raw := range urlValues { + if url := strings.TrimSpace(fmt.Sprint(raw)); url != "" { + data = append(data, map[string]any{"type": fileType, "url": url}) + } + } + if len(data) == 0 { + if base64Values := valueAtPath(result, "data.binary_data_base64"); base64Values != nil { + values := []any{} + appendURLValues(&values, base64Values) + for _, raw := range values { + if content := strings.TrimSpace(fmt.Sprint(raw)); content != "" { + data = append(data, map[string]any{"type": fileType, "content": content, "uploaded": false}) + } + } + } + } + return data +} + +func appendURLValues(out *[]any, value any) { + switch typed := value.(type) { + case nil: + case string: + *out = append(*out, typed) + case []any: + for _, item := range typed { + appendURLValues(out, item) + } + case []string: + for _, item := range typed { + *out = append(*out, item) + } + case map[string]any: + for _, key := range []string{"url", "image_url", "imageUrl", "video_url", "videoUrl", "content", "output"} { + if item := strings.TrimSpace(fmt.Sprint(typed[key])); item != "" && item != "" { + *out = append(*out, item) + return + } + } + } +} + +func providerTaskFailure(spec providerTaskSpec, result map[string]any, requestID string, startedAt time.Time) error { + message := firstNonEmptyString(valueAtPath(result, "message"), valueAtPath(result, "error.message"), valueAtPath(result, "error"), valueAtPath(result, "Response.ErrorMessage"), valueAtPath(result, "comment"), spec.Name+" task failed") + return &ClientError{ + Code: firstNonEmptyString(valueAtPath(result, "code"), valueAtPath(result, "error_code"), valueAtPath(result, "Response.ErrorCode"), "provider_failed"), + Message: message, + RequestID: requestID, + ResponseStartedAt: startedAt, + ResponseFinishedAt: time.Now(), + ResponseDurationMS: responseDurationMS(startedAt, time.Now()), + Retryable: false, + } +} + +func providerPollInterval(request Request) time.Duration { + return durationFromConfig(request.Candidate.PlatformConfig, 2*time.Second, "pollIntervalMs", "poll_interval_ms") +} + +func providerPollTimeout(request Request) time.Duration { + return durationFromConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms", "timeoutMs") +} + +func durationFromConfig(config map[string]any, fallback time.Duration, keys ...string) time.Duration { + for _, key := range keys { + switch value := config[key].(type) { + case int: + if value > 0 { + return time.Duration(value) * time.Millisecond + } + case int64: + if value > 0 { + return time.Duration(value) * time.Millisecond + } + case float64: + if value > 0 { + return time.Duration(value) * time.Millisecond + } + case string: + if parsed, err := time.ParseDuration(value); err == nil && parsed > 0 { + return parsed + } + } + } + return fallback +} + +func valueAtPath(values map[string]any, path string) any { + if values == nil || strings.TrimSpace(path) == "" { + return nil + } + var current any = values + for _, part := range strings.Split(path, ".") { + object, ok := current.(map[string]any) + if !ok { + return nil + } + current = object[part] + } + return current +} + +func mediaPromptText(body map[string]any) string { + if prompt := strings.TrimSpace(stringFromAny(body["prompt"])); prompt != "" { + return prompt + } + content, _ := body["content"].([]any) + for _, item := range content { + if part, ok := item.(map[string]any); ok && strings.TrimSpace(stringFromAny(part["type"])) == "text" { + if text := strings.TrimSpace(stringFromAny(part["text"])); text != "" { + return text + } + } + } + return "" +} diff --git a/apps/api/internal/clients/provider_task_test.go b/apps/api/internal/clients/provider_task_test.go new file mode 100644 index 0000000..f12fc62 --- /dev/null +++ b/apps/api/internal/clients/provider_task_test.go @@ -0,0 +1,339 @@ +package clients + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func TestProviderTaskClientsSubmitAndPoll(t *testing.T) { + cases := []struct { + name string + client Client + provider string + specType string + submitMatch func(*http.Request) bool + submitResponse string + pollMatch func(*http.Request) bool + pollResponse string + authHeader string + resultURL string + }{ + { + name: "jimeng", + client: JimengClient{}, + provider: "jimeng", + specType: "jimeng", + submitMatch: func(r *http.Request) bool { + return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "CVSubmitTask" + }, + submitResponse: `{"code":10000,"data":{"task_id":"remote-1"}}`, + pollMatch: func(r *http.Request) bool { + return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "CVSync2AsyncGetResult" + }, + pollResponse: `{"code":10000,"data":{"status":"done","video_url":"https://cdn.example/jimeng.mp4"}}`, + authHeader: "Bearer test-key", + resultURL: "https://cdn.example/jimeng.mp4", + }, + { + name: "blackforest", + client: BlackforestClient{}, + provider: "blackforest", + specType: "blackforest", + submitMatch: func(r *http.Request) bool { + return r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/provider-model") + }, + submitResponse: `{"polling_url":"__SERVER__/poll/remote-1"}`, + pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/poll/remote-1" }, + pollResponse: `{"status":"Ready","urls":["https://cdn.example/flux.png"]}`, + authHeader: "test-key", + resultURL: "https://cdn.example/flux.png", + }, + { + name: "hunyuan-image", + client: HunyuanImageClient{}, + provider: "tencent-hunyuan-image", + specType: "tencent-hunyuan-image", + submitMatch: func(r *http.Request) bool { + return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "SubmitHunyuanImageJob" + }, + submitResponse: `{"Response":{"JobId":"remote-1"}}`, + pollMatch: func(r *http.Request) bool { + return r.Method == http.MethodGet && r.URL.Query().Get("Action") == "QueryHunyuanImageJob" + }, + pollResponse: `{"Response":{"Status":"DONE","ResultImages":["https://cdn.example/hunyuan.png"]}}`, + authHeader: "Bearer test-key", + resultURL: "https://cdn.example/hunyuan.png", + }, + { + name: "hunyuan-video", + client: HunyuanVideoClient{}, + provider: "tencent-hunyuan-video", + specType: "tencent-hunyuan-video", + submitMatch: func(r *http.Request) bool { + return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "SubmitTextToVideoJob" + }, + submitResponse: `{"Response":{"JobId":"remote-1"}}`, + pollMatch: func(r *http.Request) bool { + return r.Method == http.MethodGet && r.URL.Query().Get("Action") == "QueryVideoJob" + }, + pollResponse: `{"Response":{"Status":"DONE","ResultVideoUrl":"https://cdn.example/hunyuan.mp4"}}`, + authHeader: "Bearer test-key", + resultURL: "https://cdn.example/hunyuan.mp4", + }, + { + name: "minimax", + client: MinimaxClient{}, + provider: "minimax", + specType: "minimax", + submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/video_generation" }, + submitResponse: `{"task_id":123}`, + pollMatch: func(r *http.Request) bool { + return r.Method == http.MethodGet && r.URL.Path == "/query/video_generation" && r.URL.Query().Get("task_id") == "123" + }, + pollResponse: `{"status":"Success","file_id":"file-1","video_url":"https://cdn.example/minimax.mp4"}`, + authHeader: "Bearer test-key", + resultURL: "https://cdn.example/minimax.mp4", + }, + { + name: "midjourney", + client: MidjourneyClient{}, + provider: "midjourney", + specType: "midjourney", + submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/diffusion" }, + submitResponse: `{"job_id":"remote-1"}`, + pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/job/remote-1" }, + pollResponse: `{"status":"completed","output":{"image_urls":["https://cdn.example/mj.png"]}}`, + authHeader: "Bearer test-key", + resultURL: "https://cdn.example/mj.png", + }, + { + name: "vidu", + client: ViduClient{}, + provider: "vidu", + specType: "vidu", + submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/text2video" }, + submitResponse: `{"task_id":"remote-1"}`, + pollMatch: func(r *http.Request) bool { + return r.Method == http.MethodGet && r.URL.Path == "/tasks/remote-1/creations" + }, + pollResponse: `{"state":"success","video_url":"https://cdn.example/vidu.mp4"}`, + authHeader: "Token test-key", + resultURL: "https://cdn.example/vidu.mp4", + }, + { + name: "aliyun-bailian", + client: AliyunBailianClient{}, + provider: "aliyun-bailian", + specType: "aliyun-bailian", + submitMatch: func(r *http.Request) bool { + return r.Method == http.MethodPost && r.URL.Path == "/services/aigc/video-generation/video-synthesis" + }, + submitResponse: `{"output":{"task_id":"remote-1"}}`, + pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/tasks/remote-1" }, + pollResponse: `{"output":{"task_status":"SUCCEEDED","video_url":"https://cdn.example/aliyun.mp4"}}`, + authHeader: "Bearer test-key", + resultURL: "https://cdn.example/aliyun.mp4", + }, + { + name: "newapi", + client: NewAPIClient{}, + provider: "newapi", + specType: "newapi", + submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/videos/generations" }, + submitResponse: `{"task_id":"remote-1"}`, + pollMatch: func(r *http.Request) bool { + return r.Method == http.MethodGet && r.URL.Path == "/videos/generations/remote-1" + }, + pollResponse: `{"status":"SUCCESS","data":{"output":"https://cdn.example/newapi.mp4"}}`, + authHeader: "Bearer test-key", + resultURL: "https://cdn.example/newapi.mp4", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if tc.authHeader == "test-key" { + if r.Header.Get("x-key") != tc.authHeader { + t.Fatalf("unexpected x-key header: %q", r.Header.Get("x-key")) + } + } else if r.Header.Get("Authorization") != tc.authHeader { + t.Fatalf("unexpected auth header: %q", r.Header.Get("Authorization")) + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-request-id", "req-"+tc.name) + switch { + case tc.submitMatch(r): + _, _ = w.Write([]byte(strings.ReplaceAll(tc.submitResponse, "__SERVER__", "http://"+r.Host))) + case tc.pollMatch(r): + _, _ = w.Write([]byte(tc.pollResponse)) + default: + t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String()) + } + })) + defer server.Close() + + var submittedRemoteTaskID string + request := Request{ + Kind: "videos.generations", + ModelType: "video_generate", + Model: "alias-model", + Body: map[string]any{"model": "alias-model", "prompt": "hello"}, + Candidate: store.RuntimeModelCandidate{ + Provider: tc.provider, + SpecType: tc.specType, + BaseURL: server.URL, + Credentials: map[string]any{"apiKey": "test-key"}, + PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000}, + ModelName: "alias-model", + ProviderModelName: "provider-model", + ModelType: "video_generate", + ClientID: tc.name + ":test", + }, + OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error { + submittedRemoteTaskID = remoteTaskID + if payload["payload"] == nil || payload["submit"] == nil { + t.Fatalf("missing remote payload: %#v", payload) + } + return nil + }, + } + + response, err := tc.client.Run(context.Background(), request) + if err != nil { + t.Fatalf("run failed: %v", err) + } + data, ok := response.Result["data"].([]any) + if !ok || len(data) == 0 { + t.Fatalf("missing data: %#v", response.Result) + } + first, _ := data[0].(map[string]any) + if first["url"] != tc.resultURL { + t.Fatalf("unexpected result url: %#v", response.Result) + } + if response.RequestID != "req-"+tc.name { + t.Fatalf("unexpected request id: %q", response.RequestID) + } + if submittedRemoteTaskID == "" { + t.Fatalf("expected remote task submission") + } + }) + } +} + +func TestProviderTaskClientFailureAndRetryableErrors(t *testing.T) { + t.Run("poll failure", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-request-id", "req-failed") + switch r.URL.Path { + case "/videos/generations": + _, _ = w.Write([]byte(`{"task_id":"remote-1"}`)) + case "/videos/generations/remote-1": + _, _ = w.Write([]byte(`{"status":"failed","code":"UPSTREAM_FAILED","message":"provider rejected"}`)) + default: + t.Fatalf("unexpected request: %s", r.URL.String()) + } + })) + defer server.Close() + + _, err := (NewAPIClient{}).Run(context.Background(), Request{ + Kind: "videos.generations", + ModelType: "video_generate", + Model: "alias-model", + Body: map[string]any{"model": "alias-model", "prompt": "hello"}, + Candidate: store.RuntimeModelCandidate{ + Provider: "newapi", + SpecType: "newapi", + BaseURL: server.URL, + Credentials: map[string]any{"apiKey": "test-key"}, + PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000}, + ModelName: "alias-model", + ProviderModelName: "provider-model", + ModelType: "video_generate", + }, + }) + var clientErr *ClientError + if !errors.As(err, &clientErr) || clientErr.Code != "UPSTREAM_FAILED" || clientErr.Retryable { + t.Fatalf("expected non-retryable upstream failure, got %#v", err) + } + if clientErr.RequestID != "req-failed" { + t.Fatalf("unexpected request id: %q", clientErr.RequestID) + } + }) + + t.Run("submit rate limit", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("x-request-id", "req-rate-limit") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"error":{"message":"slow down"}}`)) + })) + defer server.Close() + + _, err := (NewAPIClient{}).Run(context.Background(), Request{ + Kind: "videos.generations", + ModelType: "video_generate", + Model: "alias-model", + Body: map[string]any{"model": "alias-model", "prompt": "hello"}, + Candidate: store.RuntimeModelCandidate{ + Provider: "newapi", + SpecType: "newapi", + BaseURL: server.URL, + Credentials: map[string]any{"apiKey": "test-key"}, + PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000}, + ModelName: "alias-model", + ProviderModelName: "provider-model", + ModelType: "video_generate", + }, + }) + var clientErr *ClientError + if !errors.As(err, &clientErr) || !clientErr.Retryable || clientErr.RequestID != "req-rate-limit" { + t.Fatalf("expected retryable rate limit with request id, got %#v", err) + } + }) +} + +func TestProviderTaskClientResumeSkipsSubmit(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + t.Fatal("submit should not run for resumed remote task") + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"status":"success","data":{"output":"https://cdn.example/resume.mp4"}}`)) + })) + defer server.Close() + + response, err := (NewAPIClient{}).Run(context.Background(), Request{ + Kind: "videos.generations", + ModelType: "video_generate", + Model: "alias-model", + Body: map[string]any{"model": "alias-model", "prompt": "hello"}, + RemoteTaskID: "remote-1", + RemoteTaskPayload: map[string]any{"payload": map[string]any{"prompt": "old"}}, + Candidate: store.RuntimeModelCandidate{ + Provider: "newapi", + SpecType: "newapi", + BaseURL: server.URL, + Credentials: map[string]any{"apiKey": "test-key"}, + PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000}, + ModelName: "alias-model", + ProviderModelName: "provider-model", + ModelType: "video_generate", + }, + }) + if err != nil { + t.Fatalf("run failed: %v", err) + } + data := response.Result["data"].([]any) + first := data[0].(map[string]any) + if first["url"] != "https://cdn.example/resume.mp4" || response.Result["upstream_task_id"] != "remote-1" { + t.Fatalf("unexpected response: %#v", response.Result) + } +} diff --git a/apps/api/internal/clients/universal.go b/apps/api/internal/clients/universal.go new file mode 100644 index 0000000..f2a998a --- /dev/null +++ b/apps/api/internal/clients/universal.go @@ -0,0 +1,481 @@ +package clients + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" + + scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script" +) + +type UniversalClient struct { + HTTPClient *http.Client + ScriptExecutor *scriptengine.Executor +} + +func (c UniversalClient) Run(ctx context.Context, request Request) (Response, error) { + executor := c.ScriptExecutor + if executor == nil { + executor = &scriptengine.Executor{} + } + startedAt := time.Now() + modelType := strings.TrimSpace(request.ModelType) + if modelType == "" { + modelType = strings.TrimSpace(request.Candidate.ModelType) + } + + payload := cloneBody(request.Body) + upstreamTaskID := strings.TrimSpace(request.RemoteTaskID) + submitRequestID := upstreamTaskID + var submitResult map[string]any + + if upstreamTaskID == "" { + var err error + payload, err = c.universalGetParams(ctx, executor, request, modelType) + if err != nil { + return Response{}, err + } + submitResult, submitRequestID, err = c.universalSubmit(ctx, executor, request, modelType, payload) + if err != nil { + return Response{}, annotateResponseError(err, submitRequestID, startedAt, time.Now()) + } + if isUniversalSuccess(submitResult) && submitResult["data"] != nil { + return Response{ + Result: normalizeUniversalResult(request, submitResult, ""), + RequestID: firstNonEmptyString(submitRequestID, requestIDFromResult(submitResult)), + Progress: providerProgress(request), + ResponseStartedAt: startedAt, + ResponseFinishedAt: time.Now(), + ResponseDurationMS: responseDurationMS(startedAt, time.Now()), + }, nil + } + if isUniversalFailure(submitResult) { + return Response{}, universalFailureError(submitResult, firstNonEmptyString(submitRequestID, requestIDFromResult(submitResult)), startedAt) + } + upstreamTaskID = universalTaskID(submitResult) + if upstreamTaskID == "" { + return Response{}, &ClientError{Code: "invalid_response", Message: "universal task id is missing", RequestID: submitRequestID, Retryable: false} + } + if request.OnRemoteTaskSubmitted != nil { + if err := request.OnRemoteTaskSubmitted(upstreamTaskID, map[string]any{"payload": payload, "submit": submitResult}); err != nil { + return Response{}, err + } + } + } else if request.RemoteTaskPayload != nil { + if existingPayload, ok := request.RemoteTaskPayload["payload"].(map[string]any); ok { + payload = existingPayload + } + } + + result, requestID, err := c.universalPollUntilDone(ctx, executor, request, modelType, upstreamTaskID, payload, firstNonEmptyString(submitRequestID, upstreamTaskID), startedAt) + if err != nil { + return Response{}, err + } + finishedAt := time.Now() + return Response{ + Result: normalizeUniversalResult(request, result, upstreamTaskID), + RequestID: firstNonEmptyString(requestID, submitRequestID, requestIDFromResult(result), upstreamTaskID), + Progress: universalProgress(request, upstreamTaskID), + ResponseStartedAt: startedAt, + ResponseFinishedAt: finishedAt, + ResponseDurationMS: responseDurationMS(startedAt, finishedAt), + }, nil +} + +func (c UniversalClient) universalGetParams(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string) (map[string]any, error) { + if scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customGetParamsScript", "custom_get_params_script"); scriptText != "" { + scriptContext := universalScriptContext(request, modelType, nil) + out, err := executor.Execute(ctx, scriptengine.Options{ + Script: scriptText, + Args: []any{cloneBody(request.Body), scriptContext}, + ContextData: scriptContext, + ScriptName: "custom_get_params_script:" + modelType, + PreferredEntryNames: []string{"getGenerateParams", "getParams", "main", "handler"}, + Timeout: 30 * time.Second, + HTTPClient: httpClient(request.HTTPClient, c.HTTPClient), + }) + if err != nil { + return nil, universalScriptError(err) + } + if params, ok := out.(map[string]any); ok && params != nil { + if params["_originalParams"] == nil { + params["_originalParams"] = cloneBody(request.Body) + } + return params, nil + } + return nil, &ClientError{Code: "invalid_response", Message: "custom get params script must return an object", Retryable: false} + } + body := universalDefaultPayload(request) + body["_originalParams"] = cloneBody(request.Body) + return body, nil +} + +func (c UniversalClient) universalSubmit(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, payload map[string]any) (map[string]any, string, error) { + if scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customSubmitScript", "custom_submit_script"); scriptText != "" { + scriptContext := universalScriptContext(request, modelType, payload) + out, err := executor.Execute(ctx, scriptengine.Options{ + Script: scriptText, + Args: []any{cloneBody(payload), scriptContext}, + ContextData: scriptContext, + ScriptName: "custom_submit_script:" + modelType, + PreferredEntryNames: []string{"submitTask", "submitParams", "submit", "main", "handler"}, + Timeout: 30 * time.Second, + HTTPClient: httpClient(request.HTTPClient, c.HTTPClient), + }) + if err != nil { + return nil, "", universalScriptError(err) + } + result, ok := out.(map[string]any) + if !ok || result == nil { + return nil, "", &ClientError{Code: "invalid_response", Message: "custom submit script must return an object", Retryable: false} + } + return result, requestIDFromResult(result), nil + } + endpoint := universalSubmitEndpoint(request) + result, requestID, err := universalPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), request.Candidate.BaseURL, endpoint, universalStripPrivatePayload(payload), request.Candidate.Credentials) + return result, requestID, err +} + +func (c UniversalClient) universalPollUntilDone(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, upstreamTaskID string, payload map[string]any, requestID string, startedAt time.Time) (map[string]any, string, error) { + interval := universalDurationConfig(request.Candidate.PlatformConfig, 2*time.Second, "pollIntervalMs", "poll_interval_ms") + timeout := universalDurationConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms", "timeoutMs") + deadline := time.NewTimer(timeout) + defer deadline.Stop() + ticker := time.NewTicker(interval) + defer ticker.Stop() + + var lastResult map[string]any + for { + pollStarted := time.Now() + result, pollRequestID, err := c.universalPoll(ctx, executor, request, modelType, upstreamTaskID, payload) + pollFinished := time.Now() + if err != nil { + return nil, "", annotateResponseError(err, firstNonEmptyString(pollRequestID, requestID, upstreamTaskID), pollStarted, pollFinished) + } + lastResult = result + requestID = firstNonEmptyString(pollRequestID, requestID, requestIDFromResult(result), upstreamTaskID) + if isUniversalSuccess(result) { + return result, requestID, nil + } + if isUniversalFailure(result) { + return nil, "", universalFailureError(result, requestID, startedAt) + } + select { + case <-ctx.Done(): + return nil, "", &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: requestID, Retryable: true} + case <-deadline.C: + return nil, "", &ClientError{Code: "timeout", Message: fmt.Sprintf("universal task %s did not finish before timeout; last status: %s", upstreamTaskID, universalStatus(lastResult)), RequestID: requestID, Retryable: true} + case <-ticker.C: + } + } +} + +func (c UniversalClient) universalPoll(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, upstreamTaskID string, payload map[string]any) (map[string]any, string, error) { + if scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customPollScript", "custom_poll_script"); scriptText != "" { + scriptContext := universalScriptContext(request, modelType, payload) + out, err := executor.Execute(ctx, scriptengine.Options{ + Script: scriptText, + Args: []any{upstreamTaskID, scriptContext}, + ContextData: scriptContext, + ScriptName: "custom_poll_script:" + modelType, + PreferredEntryNames: []string{"pollTask", "poll", "main", "handler"}, + Timeout: 30 * time.Second, + HTTPClient: httpClient(request.HTTPClient, c.HTTPClient), + }) + if err != nil { + return nil, "", universalScriptError(err) + } + result, ok := out.(map[string]any) + if !ok || result == nil { + return nil, "", &ClientError{Code: "invalid_response", Message: "custom poll script must return an object", Retryable: false} + } + return result, requestIDFromResult(result), nil + } + pollURL := resolveUniversalTaskURL(request.Candidate.PlatformConfig, upstreamTaskID) + if pollURL == "" { + return nil, "", &ClientError{Code: "missing_configuration", Message: "universal getTaskURL is required", Retryable: false} + } + return universalGetJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), pollURL, request.Candidate.Credentials) +} + +func universalScriptContext(request Request, modelType string, payload map[string]any) map[string]any { + baseURL := strings.TrimRight(strings.TrimSpace(request.Candidate.BaseURL), "/") + getTaskURL := universalConfigString(request.Candidate.PlatformConfig, "getTaskURL", "get_task_url") + context := map[string]any{ + "__easyaiScriptContext": true, + "baseURL": baseURL, + "getTaskURL": getTaskURL, + "authValues": cloneMapAny(request.Candidate.Credentials), + "headers": map[string]any{}, + "payload": cloneMapAny(payload), + "type": modelType, + "options": map[string]any{ + "task_id": request.RemoteTaskID, + "upstream_task_id": request.RemoteTaskID, + "model": request.Model, + "providerModelName": request.Candidate.ProviderModelName, + "platformId": request.Candidate.PlatformID, + "platformModelId": request.Candidate.PlatformModelID, + "canonicalModelKey": request.Candidate.CanonicalModelKey, + "modelType": modelType, + "timeout": universalDurationConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms").Milliseconds(), + }, + "env": cloneMapAny(request.Candidate.PlatformConfig), + "candidate": universalCandidateSnapshot(request), + } + context["createRequestURL"] = func(path string, base ...string) string { + selectedBase := baseURL + if len(base) > 0 && strings.TrimSpace(base[0]) != "" { + selectedBase = strings.TrimRight(strings.TrimSpace(base[0]), "/") + } + if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { + return path + } + return selectedBase + "/" + strings.TrimLeft(path, "/") + } + context["creatRequestURL"] = context["createRequestURL"] + context["resolveGetTaskURL"] = func(taskID string) string { + return resolveUniversalTaskURL(request.Candidate.PlatformConfig, taskID) + } + return context +} + +func universalCandidateSnapshot(request Request) map[string]any { + return map[string]any{ + "modelName": request.Candidate.ModelName, + "modelAlias": request.Candidate.ModelAlias, + "providerModelName": request.Candidate.ProviderModelName, + "provider": request.Candidate.Provider, + "platformId": request.Candidate.PlatformID, + "platformModelId": request.Candidate.PlatformModelID, + "capabilities": cloneMapAny(request.Candidate.Capabilities), + } +} + +func universalDefaultPayload(request Request) map[string]any { + body := cloneBody(request.Body) + body["model"] = upstreamModelName(request.Candidate) + if request.Kind == "images.generations" { + if n := firstPresent(body["n"], body["numImages"]); n != nil { + body["numImages"] = n + } + if aspectRatio := strings.TrimSpace(stringFromAny(body["aspect_ratio"])); aspectRatio != "" { + body["aspectRatio"] = aspectRatio + } + } + return body +} + +func universalSubmitEndpoint(request Request) string { + if endpoint := universalConfigString(request.Candidate.PlatformConfig, "submitPath", "submit_path"); endpoint != "" { + return endpoint + } + switch request.Kind { + case "images.generations": + return "/images/generations" + case "images.edits": + return "/images/edits" + case "videos.generations": + return "/video/generations" + default: + return "/" + strings.ReplaceAll(request.Kind, ".", "/") + } +} + +func universalPostJSON(ctx context.Context, client *http.Client, baseURL string, endpoint string, body map[string]any, credentials map[string]any) (map[string]any, string, error) { + raw, _ := json.Marshal(body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, providerURL(baseURL, endpoint), bytes.NewReader(raw)) + if err != nil { + return nil, "", err + } + req.Header.Set("Content-Type", "application/json") + if apiKey := credential(credentials, "apiKey", "api_key", "key", "token"); apiKey != "" { + req.Header.Set("Authorization", "Bearer "+apiKey) + } + resp, err := client.Do(req) + if err != nil { + return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true} + } + requestID := requestIDFromHTTPResponse(resp) + result, err := decodeHTTPResponse(resp) + return result, requestID, err +} + +func universalGetJSON(ctx context.Context, client *http.Client, url string, credentials map[string]any) (map[string]any, string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, "", err + } + if apiKey := credential(credentials, "apiKey", "api_key", "key", "token"); apiKey != "" { + req.Header.Set("Authorization", "Bearer "+apiKey) + } + resp, err := client.Do(req) + if err != nil { + return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true} + } + requestID := requestIDFromHTTPResponse(resp) + result, err := decodeHTTPResponse(resp) + return result, requestID, err +} + +func normalizeUniversalResult(request Request, result map[string]any, upstreamTaskID string) map[string]any { + out := cloneMapAny(result) + if out["created"] == nil { + out["created"] = time.Now().UnixMilli() + } + if out["task_id"] == nil { + out["task_id"] = upstreamTaskID + } + if out["upstream_task_id"] == nil { + out["upstream_task_id"] = upstreamTaskID + } + if out["model"] == nil { + out["model"] = request.Model + } + if out["status"] == nil { + out["status"] = "success" + } + return out +} + +func universalScriptError(err error) error { + var scriptErr *scriptengine.Error + if strings.TrimSpace(err.Error()) == "" { + return &ClientError{Code: "script_error", Message: "script execution failed", Retryable: false} + } + if errors.As(err, &scriptErr) { + return &ClientError{Code: scriptErr.ErrorCode(), Message: scriptErr.Error(), Retryable: scriptErr.ErrorCode() == "script_timeout"} + } + return &ClientError{Code: "script_error", Message: err.Error(), Retryable: false} +} + +func universalFailureError(result map[string]any, requestID string, startedAt time.Time) error { + message := firstNonEmptyString(result["message"], result["error"], result["error_message"], "universal task failed") + return &ClientError{ + Code: firstNonEmptyString(result["code"], result["error_code"], "provider_failed"), + Message: message, + RequestID: requestID, + ResponseStartedAt: startedAt, + ResponseFinishedAt: time.Now(), + ResponseDurationMS: responseDurationMS(startedAt, time.Now()), + Retryable: false, + } +} + +func isUniversalSuccess(result map[string]any) bool { + switch universalStatus(result) { + case "success", "succeeded", "completed", "complete", "done": + return true + default: + return false + } +} + +func isUniversalFailure(result map[string]any) bool { + switch universalStatus(result) { + case "failed", "failure", "error", "cancelled", "canceled": + return true + default: + return false + } +} + +func universalStatus(result map[string]any) string { + return strings.ToLower(strings.TrimSpace(firstNonEmptyString(result["status"], result["state"], result["task_status"]))) +} + +func universalTaskID(result map[string]any) string { + return firstNonEmptyString(result["upstream_task_id"], result["task_id"], result["taskId"], result["id"]) +} + +func universalProgress(request Request, upstreamTaskID string) []Progress { + progress := providerProgress(request) + progress = append(progress, Progress{Phase: "polling", Progress: 0.65, Message: "provider task polled", Payload: map[string]any{"upstreamTaskId": upstreamTaskID}}) + return progress +} + +func universalStripPrivatePayload(payload map[string]any) map[string]any { + out := cloneMapAny(payload) + for _, key := range []string{"_originalParams", "_resolution", "_duration"} { + delete(out, key) + } + return out +} + +func universalSceneScript(config map[string]any, modelType string, keys ...string) string { + for _, key := range keys { + value := config[key] + switch typed := value.(type) { + case string: + if strings.TrimSpace(typed) != "" { + return strings.TrimSpace(typed) + } + case map[string]any: + if script := firstNonEmptyString(typed[modelType], typed["common"]); script != "" { + return script + } + } + } + return "" +} + +func universalConfigString(config map[string]any, keys ...string) string { + for _, key := range keys { + if value := strings.TrimSpace(fmt.Sprint(config[key])); value != "" && value != "" { + return value + } + } + return "" +} + +func universalDurationConfig(config map[string]any, fallback time.Duration, keys ...string) time.Duration { + for _, key := range keys { + switch value := config[key].(type) { + case int: + if value > 0 { + return time.Duration(value) * time.Millisecond + } + case int64: + if value > 0 { + return time.Duration(value) * time.Millisecond + } + case float64: + if value > 0 { + return time.Duration(value) * time.Millisecond + } + case string: + if parsed, err := time.ParseDuration(value); err == nil && parsed > 0 { + return parsed + } + } + } + return fallback +} + +func resolveUniversalTaskURL(config map[string]any, upstreamTaskID string) string { + template := universalConfigString(config, "getTaskURL", "get_task_url") + out := strings.TrimSpace(template) + replacements := [][2]string{ + {"${upstream_task_id}", upstreamTaskID}, + {"{{upstream_task_id}}", upstreamTaskID}, + {"{upstream_task_id}", upstreamTaskID}, + {"${task_id}", upstreamTaskID}, + {"{{task_id}}", upstreamTaskID}, + {"{task_id}", upstreamTaskID}, + {"${taskId}", upstreamTaskID}, + {"${taskID}", upstreamTaskID}, + {"{{taskId}}", upstreamTaskID}, + {"{{taskID}}", upstreamTaskID}, + {"{taskId}", upstreamTaskID}, + {"{taskID}", upstreamTaskID}, + } + for _, replacement := range replacements { + out = strings.ReplaceAll(out, replacement[0], replacement[1]) + } + return out +} diff --git a/apps/api/internal/clients/universal_test.go b/apps/api/internal/clients/universal_test.go new file mode 100644 index 0000000..325004d --- /dev/null +++ b/apps/api/internal/clients/universal_test.go @@ -0,0 +1,132 @@ +package clients + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func TestUniversalClientRunsCustomScripts(t *testing.T) { + request := Request{ + Kind: "videos.generations", + ModelType: "video_generate", + Model: "custom-video", + Body: map[string]any{"model": "custom-video", "prompt": "hello"}, + Candidate: testUniversalCandidate(map[string]any{ + "customGetParamsScript": map[string]any{ + "video_generate": `async function getGenerateParams(params, context) { + return { prompt: params.prompt + "-payload", model: context.candidate.providerModelName }; + }`, + }, + "customSubmitScript": map[string]any{ + "video_generate": `async function submitTask(payload) { + return { status: "submitted", task_id: "task-" + payload.prompt }; + }`, + }, + "customPollScript": map[string]any{ + "video_generate": `async function pollTask(taskId) { + return { status: "success", upstream_task_id: taskId, data: [{ url: "https://cdn.example/video.mp4" }] }; + }`, + }, + "pollIntervalMs": 1, + "pollTimeoutMs": 1000, + }), + } + var submitted string + request.OnRemoteTaskSubmitted = func(remoteTaskID string, payload map[string]any) error { + submitted = remoteTaskID + return nil + } + + response, err := (UniversalClient{}).Run(context.Background(), request) + if err != nil { + t.Fatalf("run failed: %v", err) + } + if submitted != "task-hello-payload" { + t.Fatalf("unexpected remote task id: %q", submitted) + } + if response.Result["upstream_task_id"] != "task-hello-payload" { + t.Fatalf("unexpected result: %#v", response.Result) + } +} + +func TestUniversalClientDefaultSubmitAndPoll(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer test-key" { + t.Fatalf("missing authorization header") + } + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/video/generations": + _, _ = w.Write([]byte(`{"status":"submitted","task_id":"remote-1"}`)) + case "/tasks/remote-1": + _, _ = w.Write([]byte(`{"status":"success","data":[{"url":"https://cdn.example/default.mp4"}]}`)) + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + })) + defer server.Close() + + request := Request{ + Kind: "videos.generations", + ModelType: "video_generate", + Model: "default-video", + Body: map[string]any{"model": "default-video", "prompt": "hello"}, + Candidate: testUniversalCandidate(map[string]any{ + "getTaskURL": server.URL + "/tasks/{{task_id}}", + "pollIntervalMs": 1, + "pollTimeoutMs": 1000, + }), + } + request.Candidate.BaseURL = server.URL + + response, err := (UniversalClient{}).Run(context.Background(), request) + if err != nil { + t.Fatalf("run failed: %v", err) + } + if response.Result["upstream_task_id"] != "remote-1" { + t.Fatalf("unexpected result: %#v", response.Result) + } +} + +func TestUniversalClientResumeSkipsSubmit(t *testing.T) { + request := Request{ + Kind: "videos.generations", + ModelType: "video_generate", + Model: "resume-video", + Body: map[string]any{"model": "resume-video", "prompt": "hello"}, + RemoteTaskID: "existing-1", + RemoteTaskPayload: map[string]any{"payload": map[string]any{"prompt": "old"}}, + Candidate: testUniversalCandidate(map[string]any{ + "customSubmitScript": `async function submitTask() { throw new Error("submit should not run"); }`, + "customPollScript": `async function pollTask(taskId) { return { status: "success", upstream_task_id: taskId, data: [{ url: "https://cdn.example/resume.mp4" }] }; }`, + "pollIntervalMs": 1, + "pollTimeoutMs": 1000, + }), + } + + response, err := (UniversalClient{}).Run(context.Background(), request) + if err != nil { + t.Fatalf("run failed: %v", err) + } + if response.Result["upstream_task_id"] != "existing-1" { + t.Fatalf("unexpected result: %#v", response.Result) + } +} + +func testUniversalCandidate(config map[string]any) store.RuntimeModelCandidate { + return store.RuntimeModelCandidate{ + Provider: "universal", + SpecType: "universal", + BaseURL: "https://provider.example", + Credentials: map[string]any{"apiKey": "test-key"}, + PlatformConfig: config, + ModelName: "alias-model", + ProviderModelName: "provider-model", + ModelType: "video_generate", + ClientID: "universal:test", + } +} diff --git a/apps/api/internal/runner/param_processor_script.go b/apps/api/internal/runner/param_processor_script.go new file mode 100644 index 0000000..6914a81 --- /dev/null +++ b/apps/api/internal/runner/param_processor_script.go @@ -0,0 +1,196 @@ +package runner + +import ( + "context" + "fmt" + "reflect" + "strings" + + scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func (s *Service) preprocessRequestWithScripts(ctx context.Context, kind string, body map[string]any, candidate store.RuntimeModelCandidate) parameterPreprocessResult { + if platformConfigBool(candidate.PlatformConfig, "skipParamNormalization", "skip_param_normalization") { + modelType := strings.TrimSpace(candidate.ModelType) + if modelType == "" { + modelType = modelTypeFromKind(kind, body) + } + input := cloneMap(body) + return parameterPreprocessResult{ + Body: cloneMap(body), + Log: parameterPreprocessingLog{ + ModelType: modelType, + Input: input, + Output: cloneMap(body), + Changed: false, + Changes: []parameterPreprocessChange{}, + Model: preprocessingModelSnapshot(candidate), + }, + } + } + + result := preprocessRequestWithLog(kind, body, candidate) + if result.Err != nil { + return result + } + scriptText := platformConfigString(candidate.PlatformConfig, "customPreprocessScript", "custom_preprocess_script") + if strings.TrimSpace(scriptText) == "" || s.scriptExecutor == nil { + return result + } + + before := cloneMap(result.Body) + scriptContext := s.scriptContext(candidate, result.Log.ModelType, nil, map[string]any{ + "modelCapability": effectiveModelCapability(candidate), + "platformModel": result.Log.Model, + "platform": candidate.PlatformConfig, + }) + out, err := s.scriptExecutor.Execute(ctx, scriptengine.Options{ + Script: scriptText, + Args: []any{cloneMap(result.Body), result.Log.ModelType, scriptContext}, + ContextData: scriptContext, + ScriptName: "custom_preprocess_script:" + result.Log.ModelType, + PreferredEntryNames: []string{"preprocessParams", "preprocess", "main", "handler"}, + Timeout: scriptengine.PreprocessTimeout, + }) + if err != nil { + result.Log.recordScriptChange("CustomPreprocessScript", "error", "$", before, result.Body, err.Error()) + result.Log.Output = cloneMap(result.Body) + result.Log.Changed = len(result.Log.Changes) > 0 + result.Err = err + return result + } + rewritten, ok := out.(map[string]any) + if !ok || rewritten == nil { + result.Log.Output = cloneMap(result.Body) + result.Log.Changed = len(result.Log.Changes) > 0 + return result + } + merged := cloneMap(result.Body) + for key, value := range rewritten { + merged[key] = value + } + if !mapsEqual(before, merged) { + result.Log.recordScriptChange("CustomPreprocessScript", "rewrite", "$", before, merged, "platform custom preprocess script returned parameter updates") + } + result.Body = merged + result.Log.Output = cloneMap(merged) + result.Log.Changed = len(result.Log.Changes) > 0 + return result +} + +func (s *Service) scriptContext(candidate store.RuntimeModelCandidate, modelType string, payload map[string]any, extra map[string]any) map[string]any { + getTaskURL := platformConfigString(candidate.PlatformConfig, "getTaskURL", "get_task_url") + baseURL := strings.TrimRight(strings.TrimSpace(candidate.BaseURL), "/") + env := cloneMap(candidate.PlatformConfig) + context := map[string]any{ + "__easyaiScriptContext": true, + "baseURL": baseURL, + "getTaskURL": getTaskURL, + "authValues": cloneMap(candidate.Credentials), + "headers": map[string]any{}, + "payload": cloneMap(payload), + "type": modelType, + "options": map[string]any{ + "model": candidate.ModelName, + "providerModelName": candidate.ProviderModelName, + "platformId": candidate.PlatformID, + "platformModelId": candidate.PlatformModelID, + "canonicalModelKey": candidate.CanonicalModelKey, + "sourceProviderCode": candidate.Provider, + }, + "env": env, + "candidate": preprocessingModelSnapshot(candidate), + } + context["createRequestURL"] = func(path string, base ...string) string { + selectedBase := baseURL + if len(base) > 0 && strings.TrimSpace(base[0]) != "" { + selectedBase = strings.TrimRight(strings.TrimSpace(base[0]), "/") + } + if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { + return path + } + return selectedBase + "/" + strings.TrimLeft(path, "/") + } + context["creatRequestURL"] = context["createRequestURL"] + context["resolveGetTaskURL"] = func(taskID string) string { + return resolveTaskURLTemplate(getTaskURL, taskID, "") + } + for key, value := range extra { + context[key] = value + } + return context +} + +func preprocessingModelSnapshot(candidate store.RuntimeModelCandidate) map[string]any { + return map[string]any{ + "modelName": candidate.ModelName, + "modelAlias": candidate.ModelAlias, + "providerModelName": candidate.ProviderModelName, + "provider": candidate.Provider, + "platformId": candidate.PlatformID, + "platformModelId": candidate.PlatformModelID, + "capabilities": cloneMap(candidate.Capabilities), + } +} + +func (log *parameterPreprocessingLog) recordScriptChange(processor string, action string, path string, before any, after any, reason string) { + if log == nil { + return + } + log.Changes = append(log.Changes, parameterPreprocessChange{ + Processor: processor, + Action: action, + Path: path, + Before: cloneAny(before), + After: cloneAny(after), + Reason: reason, + }) +} + +func platformConfigString(config map[string]any, keys ...string) string { + for _, key := range keys { + if value := strings.TrimSpace(fmt.Sprint(config[key])); value != "" && value != "" { + return value + } + } + return "" +} + +func platformConfigBool(config map[string]any, keys ...string) bool { + for _, key := range keys { + switch value := config[key].(type) { + case bool: + return value + case string: + return strings.EqualFold(strings.TrimSpace(value), "true") + } + } + return false +} + +func resolveTaskURLTemplate(template string, upstreamTaskID string, taskID string) string { + out := strings.TrimSpace(template) + replacements := [][2]string{ + {"${upstream_task_id}", upstreamTaskID}, + {"{{upstream_task_id}}", upstreamTaskID}, + {"{upstream_task_id}", upstreamTaskID}, + {"${task_id}", taskID}, + {"{{task_id}}", taskID}, + {"{task_id}", taskID}, + {"${taskId}", upstreamTaskID}, + {"${taskID}", upstreamTaskID}, + {"{{taskId}}", upstreamTaskID}, + {"{{taskID}}", upstreamTaskID}, + {"{taskId}", upstreamTaskID}, + {"{taskID}", upstreamTaskID}, + } + for _, replacement := range replacements { + out = strings.ReplaceAll(out, replacement[0], replacement[1]) + } + return out +} + +func mapsEqual(left map[string]any, right map[string]any) bool { + return reflect.DeepEqual(left, right) +} diff --git a/apps/api/internal/runner/param_processor_script_test.go b/apps/api/internal/runner/param_processor_script_test.go new file mode 100644 index 0000000..454cc15 --- /dev/null +++ b/apps/api/internal/runner/param_processor_script_test.go @@ -0,0 +1,64 @@ +package runner + +import ( + "context" + "testing" + + scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func TestPreprocessRequestWithCustomScript(t *testing.T) { + service := &Service{scriptExecutor: &scriptengine.Executor{}} + candidate := store.RuntimeModelCandidate{ + Provider: "universal", + ModelName: "image-model", + ModelType: "image_generate", + Capabilities: map[string]any{ + "image_generate": map[string]any{"max_output_images": 4}, + }, + PlatformConfig: map[string]any{ + "customPreprocessScript": `(params, type, context) => { + return { prompt: params.prompt + "-" + type, n: 2, provider: context.candidate.provider }; + }`, + }, + } + + result := service.preprocessRequestWithScripts(context.Background(), "images.generations", map[string]any{"prompt": "hello", "n": 8}, candidate) + if result.Err != nil { + t.Fatalf("unexpected preprocess error: %v", result.Err) + } + if result.Body["prompt"] != "hello-image_generate" || result.Body["n"].(float64) != 2 { + t.Fatalf("unexpected body: %#v", result.Body) + } + if !result.Log.Changed || len(result.Log.Changes) == 0 { + t.Fatalf("expected script change in log: %#v", result.Log) + } +} + +func TestPreprocessRequestSkipParamNormalizationSkipsCustomScript(t *testing.T) { + service := &Service{scriptExecutor: &scriptengine.Executor{}} + candidate := store.RuntimeModelCandidate{ + ModelName: "image-model", + ModelType: "image_generate", + Provider: "universal", + Capabilities: map[string]any{ + "image_generate": map[string]any{"max_output_images": 1}, + }, + PlatformConfig: map[string]any{ + "skipParamNormalization": true, + "customPreprocessScript": `(params) => ({ prompt: "changed", n: 1 })`, + }, + } + + result := service.preprocessRequestWithScripts(context.Background(), "images.generations", map[string]any{"prompt": "hello", "n": 9}, candidate) + if result.Err != nil { + t.Fatalf("unexpected preprocess error: %v", result.Err) + } + if result.Body["prompt"] != "hello" || result.Body["n"].(int) != 9 { + t.Fatalf("skip should keep raw body, got %#v", result.Body) + } + if result.Log.Changed || len(result.Log.Changes) != 0 { + t.Fatalf("skip should not record changes: %#v", result.Log) + } +} diff --git a/apps/api/internal/runner/service.go b/apps/api/internal/runner/service.go index 85d4be3..c626ca7 100644 --- a/apps/api/internal/runner/service.go +++ b/apps/api/internal/runner/service.go @@ -12,18 +12,20 @@ import ( "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" "github.com/easyai/easyai-ai-gateway/apps/api/internal/config" + scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script" "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" "github.com/jackc/pgx/v5" "github.com/riverqueue/river" ) type Service struct { - cfg config.Config - store *store.Store - logger *slog.Logger - clients map[string]clients.Client - httpClients *httpClientCache - riverClient *river.Client[pgx.Tx] + cfg config.Config + store *store.Store + logger *slog.Logger + clients map[string]clients.Client + scriptExecutor *scriptengine.Executor + httpClients *httpClientCache + riverClient *river.Client[pgx.Tx] } type Result struct { @@ -47,17 +49,29 @@ func (e *TaskQueuedError) Is(target error) bool { func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service { httpClients := newHTTPClientCache() + scriptExecutor := &scriptengine.Executor{Logger: logger} return &Service{ - cfg: cfg, - store: db, - logger: logger, + cfg: cfg, + store: db, + logger: logger, + scriptExecutor: scriptExecutor, clients: map[string]clients.Client{ - "openai": clients.OpenAIClient{HTTPClient: httpClients.none}, - "gemini": clients.GeminiClient{HTTPClient: httpClients.none}, - "volces": clients.VolcesClient{HTTPClient: httpClients.none}, - "keling": clients.KelingClient{HTTPClient: httpClients.none}, - "kling": clients.KelingClient{HTTPClient: httpClients.none}, - "simulation": clients.SimulationClient{}, + "openai": clients.OpenAIClient{HTTPClient: httpClients.none}, + "aliyun-bailian": clients.AliyunBailianClient{HTTPClient: httpClients.none}, + "blackforest": clients.BlackforestClient{HTTPClient: httpClients.none}, + "gemini": clients.GeminiClient{HTTPClient: httpClients.none}, + "jimeng": clients.JimengClient{HTTPClient: httpClients.none}, + "midjourney": clients.MidjourneyClient{HTTPClient: httpClients.none}, + "minimax": clients.MinimaxClient{HTTPClient: httpClients.none}, + "newapi": clients.NewAPIClient{HTTPClient: httpClients.none}, + "tencent-hunyuan-image": clients.HunyuanImageClient{HTTPClient: httpClients.none}, + "tencent-hunyuan-video": clients.HunyuanVideoClient{HTTPClient: httpClients.none}, + "vidu": clients.ViduClient{HTTPClient: httpClients.none}, + "volces": clients.VolcesClient{HTTPClient: httpClients.none}, + "keling": clients.KelingClient{HTTPClient: httpClients.none}, + "kling": clients.KelingClient{HTTPClient: httpClients.none}, + "universal": clients.UniversalClient{HTTPClient: httpClients.none, ScriptExecutor: scriptExecutor}, + "simulation": clients.SimulationClient{}, }, httpClients: httpClients, } @@ -147,7 +161,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut attemptNo := task.AttemptCount var firstPreprocessing parameterPreprocessingLog if len(candidates) > 0 { - preprocessing := preprocessRequestWithLog(task.Kind, body, candidates[0]) + preprocessing := s.preprocessRequestWithScripts(ctx, task.Kind, body, candidates[0]) firstCandidateBody = preprocessing.Body firstPreprocessing = preprocessing.Log normalizedModelType = candidates[0].ModelType @@ -225,7 +239,7 @@ candidatesLoop: var candidateErr error for clientAttempt := 1; clientAttempt <= clientAttempts; clientAttempt++ { nextAttemptNo := attemptNo + 1 - preprocessing := preprocessRequestWithLog(task.Kind, body, candidate) + preprocessing := s.preprocessRequestWithScripts(ctx, task.Kind, body, candidate) preprocessingLog := preprocessing.Log lastPreprocessing = &preprocessingLog if preprocessing.Err != nil { @@ -1090,8 +1104,13 @@ func parameterPreprocessClientError(err error) *clients.ClientError { if err == nil { return nil } + code := "invalid_parameter" + var coded interface{ ErrorCode() string } + if errors.As(err, &coded) && strings.TrimSpace(coded.ErrorCode()) != "" { + code = coded.ErrorCode() + } return &clients.ClientError{ - Code: "invalid_parameter", + Code: code, Message: err.Error(), StatusCode: 400, Retryable: false, diff --git a/apps/api/internal/script/executor.go b/apps/api/internal/script/executor.go new file mode 100644 index 0000000..1af9d23 --- /dev/null +++ b/apps/api/internal/script/executor.go @@ -0,0 +1,530 @@ +package script + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "regexp" + "strings" + "sync" + "time" + + "github.com/dop251/goja" + "github.com/dop251/goja_nodejs/eventloop" +) + +const ( + DefaultTimeout = 30 * time.Second + PreprocessTimeout = 10 * time.Second +) + +type Logger interface { + Debug(msg string, args ...any) + Info(msg string, args ...any) + Warn(msg string, args ...any) + Error(msg string, args ...any) +} + +type Executor struct { + HTTPClient *http.Client + Logger Logger +} + +type Options struct { + Script string + Args []any + ContextData map[string]any + ScriptName string + PreferredEntryNames []string + Timeout time.Duration + HTTPClient *http.Client +} + +type Error struct { + Code string + Message string +} + +func (e *Error) Error() string { + if e == nil { + return "" + } + if strings.TrimSpace(e.Message) != "" { + return e.Message + } + return e.Code +} + +func (e *Error) ErrorCode() string { + if e == nil || strings.TrimSpace(e.Code) == "" { + return "script_error" + } + return e.Code +} + +type result struct { + value any + err error +} + +var ( + functionDeclarationPattern = regexp.MustCompile(`(?:^|\n)\s*(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(`) + assignedFunctionPattern = regexp.MustCompile(`(?:^|\n)\s*(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?(?:function\b|\([^)]*\)\s*=>|[A-Za-z_$][\w$]*\s*=>)`) +) + +func (e Executor) Execute(ctx context.Context, opts Options) (any, error) { + scriptText := strings.TrimSpace(opts.Script) + if scriptText == "" { + return nil, &Error{Code: "script_empty", Message: "script is empty"} + } + scriptName := strings.TrimSpace(opts.ScriptName) + if scriptName == "" { + scriptName = "script" + } + timeout := opts.Timeout + if timeout <= 0 { + timeout = DefaultTimeout + } + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + loop := eventloop.NewEventLoop(eventloop.EnableConsole(false)) + loop.Start() + defer loop.Terminate() + + resultCh := make(chan result, 1) + var once sync.Once + finish := func(value any, err error) { + once.Do(func() { + resultCh <- result{value: value, err: err} + loop.StopNoWait() + }) + } + + ok := loop.RunOnLoop(func(vm *goja.Runtime) { + e.installRuntime(ctx, loop, vm, opts.HTTPClient, scriptName) + for key, value := range opts.ContextData { + _ = vm.Set(key, value) + } + value, err := e.invoke(vm, scriptText, opts.Args, opts.PreferredEntryNames, scriptName) + if err != nil { + finish(nil, err) + return + } + e.resolveValue(vm, value, finish) + }) + if !ok { + return nil, &Error{Code: "script_runtime_error", Message: "script event loop is not available"} + } + + select { + case out := <-resultCh: + if out.err != nil { + return nil, out.err + } + return normalizeExport(out.value), nil + case <-ctx.Done(): + loop.Terminate() + code := "script_timeout" + if errors.Is(ctx.Err(), context.Canceled) { + code = "script_cancelled" + } + return nil, &Error{Code: code, Message: fmt.Sprintf("%s exceeded %s", scriptName, timeout)} + } +} + +func (e Executor) invoke(vm *goja.Runtime, scriptText string, args []any, preferred []string, scriptName string) (goja.Value, error) { + if fnValue, err := vm.RunString("(" + scriptText + ")"); err == nil { + if fn, ok := goja.AssertFunction(fnValue); ok { + return fn(goja.Undefined(), values(vm, args)...) + } + } + + if _, err := vm.RunString(scriptText); err != nil { + return nil, &Error{Code: "script_compile_error", Message: err.Error()} + } + + for _, name := range entryCandidates(scriptText, preferred) { + fnValue, err := vm.RunString(fmt.Sprintf("(typeof %s === 'function' ? %s : undefined)", name, name)) + if err != nil || goja.IsUndefined(fnValue) || goja.IsNull(fnValue) { + continue + } + fn, ok := goja.AssertFunction(fnValue) + if !ok { + continue + } + return fn(goja.Undefined(), values(vm, args)...) + } + + return nil, &Error{Code: "script_entry_missing", Message: fmt.Sprintf("%s must expose an executable function", scriptName)} +} + +func (e Executor) resolveValue(vm *goja.Runtime, value goja.Value, finish func(any, error)) { + if value == nil { + finish(nil, nil) + return + } + if promise, ok := value.Export().(*goja.Promise); ok { + switch promise.State() { + case goja.PromiseStateFulfilled: + finish(exportValue(promise.Result()), nil) + case goja.PromiseStateRejected: + finish(nil, &Error{Code: "script_error", Message: stringify(promise.Result())}) + default: + obj := value.ToObject(vm) + thenFn, ok := goja.AssertFunction(obj.Get("then")) + if !ok { + finish(nil, &Error{Code: "script_error", Message: "promise.then is not callable"}) + return + } + onResolve := func(call goja.FunctionCall) goja.Value { + finish(exportValue(call.Argument(0)), nil) + return goja.Undefined() + } + onReject := func(call goja.FunctionCall) goja.Value { + finish(nil, &Error{Code: "script_error", Message: stringify(call.Argument(0))}) + return goja.Undefined() + } + _, _ = thenFn(obj, vm.ToValue(onResolve), vm.ToValue(onReject)) + } + return + } + finish(exportValue(value), nil) +} + +func (e Executor) installRuntime(ctx context.Context, loop *eventloop.EventLoop, vm *goja.Runtime, client *http.Client, scriptName string) { + vm.SetFieldNameMapper(goja.TagFieldNameMapper("json", true)) + e.installConsole(vm, scriptName) + e.installHTTP(ctx, loop, vm, firstHTTPClient(client, e.HTTPClient), scriptName) + _ = vm.Set("FormData", formDataConstructor(vm)) + _, _ = vm.RunString(` + function __easyaiGotRequest(method, url, options) { + return { + json: function() { return __easyaiHTTP(method, url, options || {}).then(function(resp) { return resp.json(); }); }, + text: function() { return __easyaiHTTP(method, url, options || {}).then(function(resp) { return resp.text(); }); } + }; + } + var got = { + get: function(url, options) { return __easyaiGotRequest("GET", url, options); }, + post: function(url, options) { return __easyaiGotRequest("POST", url, options); }, + put: function(url, options) { return __easyaiGotRequest("PUT", url, options); }, + patch: function(url, options) { return __easyaiGotRequest("PATCH", url, options); }, + delete: function(url, options) { return __easyaiGotRequest("DELETE", url, options); }, + extend: function() { return this; } + }; + function fetch(url, options) { + options = options || {}; + return __easyaiHTTP(options.method || "GET", url, options); + } + `) +} + +func (e Executor) installConsole(vm *goja.Runtime, scriptName string) { + log := func(level string, args ...any) { + if e.Logger == nil { + return + } + values := make([]any, 0, len(args)+1) + values = append(values, "script", scriptName) + values = append(values, args...) + switch level { + case "error": + e.Logger.Error("script console", values...) + case "warn": + e.Logger.Warn("script console", values...) + case "info": + e.Logger.Info("script console", values...) + default: + e.Logger.Debug("script console", values...) + } + } + _ = vm.Set("console", map[string]any{ + "log": func(args ...any) { log("debug", args...) }, + "debug": func(args ...any) { log("debug", args...) }, + "info": func(args ...any) { log("info", args...) }, + "warn": func(args ...any) { log("warn", args...) }, + "error": func(args ...any) { log("error", args...) }, + }) +} + +func (e Executor) installHTTP(ctx context.Context, loop *eventloop.EventLoop, vm *goja.Runtime, client *http.Client, scriptName string) { + _ = vm.Set("__easyaiHTTP", func(call goja.FunctionCall) goja.Value { + method := strings.ToUpper(strings.TrimSpace(call.Argument(0).String())) + if method == "" { + method = http.MethodGet + } + url := strings.TrimSpace(call.Argument(1).String()) + options := exportMap(call.Argument(2)) + promise, resolve, reject := vm.NewPromise() + go func() { + response, err := doHTTPRequest(ctx, client, method, url, options) + loop.RunOnLoop(func(runtime *goja.Runtime) { + if err != nil { + _ = reject(err.Error()) + return + } + _ = resolve(httpResponseObject(runtime, response)) + }) + }() + return vm.ToValue(promise) + }) +} + +func doHTTPRequest(ctx context.Context, client *http.Client, method string, url string, options map[string]any) (httpScriptResponse, error) { + if strings.TrimSpace(url) == "" { + return httpScriptResponse{}, errors.New("url is required") + } + var body io.Reader + headers := map[string]string{} + if rawHeaders, ok := options["headers"].(map[string]any); ok { + for key, value := range rawHeaders { + if text := strings.TrimSpace(fmt.Sprint(value)); text != "" { + headers[key] = text + } + } + } + if jsonBody, ok := options["json"]; ok { + raw, err := json.Marshal(jsonBody) + if err != nil { + return httpScriptResponse{}, err + } + body = bytes.NewReader(raw) + if _, ok := headers["Content-Type"]; !ok { + headers["Content-Type"] = "application/json" + } + } else if rawBody, ok := options["body"]; ok { + body, headers = requestBody(rawBody, headers) + } + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return httpScriptResponse{}, err + } + for key, value := range headers { + req.Header.Set(key, value) + } + resp, err := client.Do(req) + if err != nil { + return httpScriptResponse{}, err + } + defer resp.Body.Close() + raw, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024*1024)) + out := httpScriptResponse{ + Status: resp.Status, + StatusCode: resp.StatusCode, + OK: resp.StatusCode >= 200 && resp.StatusCode < 300, + Headers: map[string]any{}, + Body: string(raw), + } + for key, values := range resp.Header { + if len(values) == 1 { + out.Headers[key] = values[0] + } else { + out.Headers[key] = values + } + } + if len(raw) > 0 { + var parsed any + if json.Unmarshal(raw, &parsed) == nil { + out.JSON = parsed + } + } + return out, nil +} + +type httpScriptResponse struct { + Status string + StatusCode int + OK bool + Headers map[string]any + Body string + JSON any +} + +func httpResponseObject(vm *goja.Runtime, response httpScriptResponse) map[string]any { + return map[string]any{ + "status": response.StatusCode, + "statusCode": response.StatusCode, + "ok": response.OK, + "headers": response.Headers, + "text": func() string { + return response.Body + }, + "json": func() any { + if response.JSON != nil { + return response.JSON + } + var parsed any + if json.Unmarshal([]byte(response.Body), &parsed) == nil { + return parsed + } + panic(vm.NewTypeError("response body is not valid JSON")) + }, + } +} + +func requestBody(value any, headers map[string]string) (io.Reader, map[string]string) { + switch typed := value.(type) { + case string: + return strings.NewReader(typed), headers + case []byte: + return bytes.NewReader(typed), headers + case map[string]any: + if fields, ok := typed["__easyaiFormData"].([]any); ok { + var buf bytes.Buffer + writer := multipart.NewWriter(&buf) + for _, rawField := range fields { + field, ok := rawField.(map[string]any) + if !ok { + continue + } + _ = writer.WriteField(strings.TrimSpace(fmt.Sprint(field["name"])), fmt.Sprint(field["value"])) + } + _ = writer.Close() + headers["Content-Type"] = writer.FormDataContentType() + return &buf, headers + } + raw, _ := json.Marshal(typed) + headers["Content-Type"] = "application/json" + return bytes.NewReader(raw), headers + default: + raw, _ := json.Marshal(typed) + headers["Content-Type"] = "application/json" + return bytes.NewReader(raw), headers + } +} + +func formDataConstructor(vm *goja.Runtime) func(goja.ConstructorCall) *goja.Object { + return func(call goja.ConstructorCall) *goja.Object { + obj := call.This + _ = obj.Set("__easyaiFormData", []any{}) + _ = obj.Set("append", func(name string, value any) { + fields := exportSlice(obj.Get("__easyaiFormData")) + fields = append(fields, map[string]any{"name": name, "value": value}) + _ = obj.Set("__easyaiFormData", fields) + }) + return obj + } +} + +func entryCandidates(scriptText string, preferred []string) []string { + values := make([]string, 0, len(preferred)+4) + appendUnique := func(value string) { + value = strings.TrimSpace(value) + if value == "" { + return + } + for _, existing := range values { + if existing == value { + return + } + } + values = append(values, value) + } + for _, value := range preferred { + appendUnique(value) + } + for _, match := range functionDeclarationPattern.FindAllStringSubmatch(scriptText, -1) { + appendUnique(match[1]) + } + for _, match := range assignedFunctionPattern.FindAllStringSubmatch(scriptText, -1) { + appendUnique(match[1]) + } + appendUnique("main") + appendUnique("handler") + return values +} + +func values(vm *goja.Runtime, input []any) []goja.Value { + out := make([]goja.Value, 0, len(input)) + for _, item := range input { + out = append(out, toValue(vm, item)) + } + return out +} + +func toValue(vm *goja.Runtime, item any) goja.Value { + if values, ok := item.(map[string]any); ok { + copied := map[string]any{} + for key, value := range values { + if key == "__easyaiScriptContext" { + continue + } + copied[key] = value + } + obj := vm.ToValue(copied).ToObject(vm) + if marker, _ := values["__easyaiScriptContext"].(bool); marker { + _ = obj.Set("got", vm.Get("got")) + _ = obj.Set("fetch", vm.Get("fetch")) + _ = obj.Set("FormData", vm.Get("FormData")) + } + return obj + } + return vm.ToValue(item) +} + +func exportValue(value goja.Value) any { + if value == nil || goja.IsUndefined(value) || goja.IsNull(value) { + return nil + } + return value.Export() +} + +func exportMap(value goja.Value) map[string]any { + if value == nil || goja.IsUndefined(value) || goja.IsNull(value) { + return map[string]any{} + } + if typed, ok := normalizeExport(value.Export()).(map[string]any); ok { + return typed + } + return map[string]any{} +} + +func exportSlice(value goja.Value) []any { + if value == nil || goja.IsUndefined(value) || goja.IsNull(value) { + return []any{} + } + if typed, ok := normalizeExport(value.Export()).([]any); ok { + return typed + } + return []any{} +} + +func normalizeExport(value any) any { + raw, err := json.Marshal(value) + if err != nil { + return value + } + var out any + if json.Unmarshal(raw, &out) != nil { + return value + } + return out +} + +func firstHTTPClient(values ...*http.Client) *http.Client { + for _, value := range values { + if value != nil { + return value + } + } + return http.DefaultClient +} + +func stringify(value goja.Value) string { + if value == nil || goja.IsUndefined(value) || goja.IsNull(value) { + return "script rejected" + } + if exported, ok := normalizeExport(value.Export()).(map[string]any); ok { + for _, key := range []string{"message", "error", "error_message"} { + if message := strings.TrimSpace(fmt.Sprint(exported[key])); message != "" && message != "" { + return message + } + } + } + return strings.TrimSpace(value.String()) +} diff --git a/apps/api/internal/script/executor_test.go b/apps/api/internal/script/executor_test.go new file mode 100644 index 0000000..57d4f2b --- /dev/null +++ b/apps/api/internal/script/executor_test.go @@ -0,0 +1,116 @@ +package script + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestExecutorRunsFunctionExpression(t *testing.T) { + out, err := (Executor{}).Execute(context.Background(), Options{ + Script: `(params) => ({ prompt: params.prompt.toUpperCase(), n: 2 })`, + Args: []any{map[string]any{"prompt": "hello"}}, + ScriptName: "custom_preprocess_script", + Timeout: time.Second, + }) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + result := out.(map[string]any) + if result["prompt"] != "HELLO" || result["n"].(float64) != 2 { + t.Fatalf("unexpected result: %#v", result) + } +} + +func TestExecutorSelectsPreferredEntry(t *testing.T) { + out, err := (Executor{}).Execute(context.Background(), Options{ + Script: ` + function helper() { return { wrong: true }; } + async function submitTask(payload, context) { + return { status: "submitted", task_id: payload.id, baseURL: context.baseURL }; + } + `, + Args: []any{map[string]any{"id": "task-1"}, map[string]any{"baseURL": "https://example.test"}}, + ContextData: map[string]any{"baseURL": "https://example.test"}, + PreferredEntryNames: []string{"submitTask", "submit"}, + ScriptName: "custom_submit_script:video_generate", + Timeout: time.Second, + }) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + result := out.(map[string]any) + if result["task_id"] != "task-1" || result["baseURL"] != "https://example.test" { + t.Fatalf("unexpected result: %#v", result) + } +} + +func TestExecutorGotJSONHelper(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("unexpected method: %s", r.Method) + } + if r.Header.Get("Authorization") != "Bearer test" { + t.Fatalf("missing authorization header") + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"status":"success","task_id":"remote-1"}`)) + })) + defer server.Close() + + out, err := (Executor{}).Execute(context.Background(), Options{ + Script: ` + async function submitTask(payload, context) { + return await got.post(context.baseURL + "/tasks", { + headers: { Authorization: "Bearer " + context.authValues.apiKey }, + json: payload + }).json(); + } + `, + Args: []any{map[string]any{"prompt": "hello"}, map[string]any{"baseURL": server.URL, "authValues": map[string]any{"apiKey": "test"}}}, + ContextData: map[string]any{"baseURL": server.URL, "authValues": map[string]any{"apiKey": "test"}}, + PreferredEntryNames: []string{"submitTask"}, + ScriptName: "custom_submit_script:image_generate", + Timeout: 2 * time.Second, + }) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + result := out.(map[string]any) + if result["task_id"] != "remote-1" { + t.Fatalf("unexpected result: %#v", result) + } +} + +func TestExecutorTimeout(t *testing.T) { + _, err := (Executor{}).Execute(context.Background(), Options{ + Script: `async function main() { await new Promise((resolve) => setTimeout(resolve, 200)); return true; }`, + ScriptName: "custom_poll_script", + Timeout: 25 * time.Millisecond, + }) + if err == nil { + t.Fatal("expected timeout") + } + scriptErr, ok := err.(*Error) + if !ok || scriptErr.Code != "script_timeout" { + t.Fatalf("expected script_timeout, got %#v", err) + } +} + +func TestExecutorRejectedPromiseMessage(t *testing.T) { + _, err := (Executor{}).Execute(context.Background(), Options{ + Script: `async function main() { throw new Error("boom"); }`, + ScriptName: "custom_submit_script", + Timeout: time.Second, + }) + if err == nil { + t.Fatal("expected rejection") + } + scriptErr, ok := err.(*Error) + if !ok || scriptErr.Code != "script_error" || !strings.Contains(scriptErr.Message, "boom") { + t.Fatalf("expected script_error with boom, got %#v", err) + } +} diff --git a/apps/api/migrations/0039_exclude_easyai_media_catalog.sql b/apps/api/migrations/0039_exclude_easyai_media_catalog.sql new file mode 100644 index 0000000..adb65d4 --- /dev/null +++ b/apps/api/migrations/0039_exclude_easyai_media_catalog.sql @@ -0,0 +1,19 @@ +-- EasyAI/server-main is intentionally not migrated as an AI Gateway runtime +-- provider. Keep its historical catalog rows for traceability, but hide them +-- from fresh admin selection and mark the exclusion reason explicitly. +UPDATE base_model_catalog +SET status = 'deprecated', + metadata = COALESCE(metadata, '{}'::jsonb) || jsonb_build_object( + 'selectable', false, + 'migrationExcludedReason', 'excluded from AI Gateway media runtime migration to avoid gateway-to-server-main loopback', + 'migrationExcludedAt', '0039_exclude_easyai_media_catalog' + ) +WHERE provider_key = 'easyai' + AND model_type ?| ARRAY[ + 'image_generate', + 'image_edit', + 'video_generate', + 'image_to_video', + 'omni_video', + 'video_edit' + ]; diff --git a/docs/media-client-migration.md b/docs/media-client-migration.md new file mode 100644 index 0000000..f008b25 --- /dev/null +++ b/docs/media-client-migration.md @@ -0,0 +1,34 @@ +# Media Client Migration Status + +This document tracks the server-main media runtime migration into the AI Gateway. + +## Runtime Scope + +- Included model types: `image_generate`, `image_edit`, `video_generate`, `image_to_video`, `omni_video`, `video_edit`. +- Excluded provider: `easyai`, because routing AI Gateway media tasks back into server-main would create a loopback dependency. +- Universal custom scripts are supported through `integration_platforms.config`: + - `customPreprocessScript` + - `customGetParamsScript` + - `customSubmitScript` + - `customPollScript` + - `getTaskURL` + - `skipParamNormalization` + +## Migrated Clients + +- `universal`: custom preprocess/get params/submit/poll scripts, default submit/poll, remote task resume. +- `jimeng`: async submit/poll skeleton with Jimeng task id and status mapping. +- `blackforest`: submit with `x-key`, `polling_url` polling, image result normalization. +- `tencent-hunyuan-image`: Tencent-style `Response.JobId`/`Response.Status` image task mapping. +- `tencent-hunyuan-video`: Tencent-style `Response.JobId`/`Response.Status` video task mapping. +- `minimax`: video submit/query task mapping. +- `midjourney`: diffusion submit, job polling, original and Aliyun-style status/result mapping. +- `vidu`: Token auth, typed submit path, creations polling. +- `aliyun-bailian`: video synthesis submit and task polling. +- `newapi`: `/videos/generations` submit and task polling. + +## Notes + +- Provider-specific advanced parameter shaping remains isolated inside each client/spec. +- Tencent and Jimeng production deployments should configure exact submit/poll paths and credentials in platform config when they differ from the default server-main-compatible paths. +- Each migrated client has an `httptest` submit/poll coverage case in `internal/clients`.