package clients

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"time"

	scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
)

// UniversalClient submits a request to an upstream provider, either through
// platform-configured scripts or through plain JSON HTTP calls, and polls the
// upstream task until it reports a terminal status.
type UniversalClient struct {
	// HTTPClient is handed to httpClient together with the per-request client
	// to select the client used for upstream calls and script execution.
	HTTPClient *http.Client
	// ScriptExecutor runs the custom scripts. When nil, Run falls back to a
	// zero-value scriptengine.Executor.
	ScriptExecutor *scriptengine.Executor
}

// Run executes the request against the upstream provider.
//
// When request.RemoteTaskID is empty, Run builds the payload, submits it and
// then either returns immediately (the submit result is already successful
// and carries "data"), fails (the submit result reports a failure status or
// has no task id), or records the new task through
// request.OnRemoteTaskSubmitted and starts polling. When RemoteTaskID is set,
// submission is skipped and polling resumes for that task, reusing the payload
// stored under RemoteTaskPayload["payload"] when it is a map.
func (c UniversalClient) Run(ctx context.Context, request Request) (Response, error) {
	executor := c.ScriptExecutor
	if executor == nil {
		executor = &scriptengine.Executor{}
	}
	startedAt := time.Now()

	modelType := strings.TrimSpace(request.ModelType)
	if modelType == "" {
		modelType = strings.TrimSpace(request.Candidate.ModelType)
	}

	payload := cloneBody(request.Body)
	upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
	submitRequestID := upstreamTaskID
	var submitResult map[string]any

	if upstreamTaskID == "" {
		var err error
		payload, err = c.universalGetParams(ctx, executor, request, modelType)
		if err != nil {
			return Response{}, err
		}
		submitResult, submitRequestID, err = c.universalSubmit(ctx, executor, request, modelType, payload)
		if err != nil {
			return Response{}, annotateResponseError(err, submitRequestID, startedAt, time.Now())
		}

		// Synchronous completion: the submit call already returned the data.
		if isUniversalSuccess(submitResult) && submitResult["data"] != nil {
			// Capture the finish time once so the timestamp and the duration
			// reported in the response always agree.
			finishedAt := time.Now()
			return Response{
				Result:             normalizeUniversalResult(request, submitResult, ""),
				RequestID:          firstNonEmptyString(submitRequestID, requestIDFromResult(submitResult)),
				Progress:           providerProgress(request),
				ResponseStartedAt:  startedAt,
				ResponseFinishedAt: finishedAt,
				ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
			}, nil
		}
		if isUniversalFailure(submitResult) {
			return Response{}, universalFailureError(submitResult, firstNonEmptyString(submitRequestID, requestIDFromResult(submitResult)), startedAt)
		}

		upstreamTaskID = universalTaskID(submitResult)
		if upstreamTaskID == "" {
			return Response{}, &ClientError{Code: "invalid_response", Message: "universal task id is missing", RequestID: submitRequestID, Retryable: false}
		}
		if request.OnRemoteTaskSubmitted != nil {
			if err := request.OnRemoteTaskSubmitted(upstreamTaskID, map[string]any{"payload": payload, "submit": submitResult}); err != nil {
				return Response{}, err
			}
		}
	} else if request.RemoteTaskPayload != nil {
		// Resuming an existing task: prefer the payload recorded at submit time.
		if existingPayload, ok := request.RemoteTaskPayload["payload"].(map[string]any); ok {
			payload = existingPayload
		}
	}

	result, requestID, err := c.universalPollUntilDone(ctx, executor, request, modelType, upstreamTaskID, payload, firstNonEmptyString(submitRequestID, upstreamTaskID), startedAt)
	if err != nil {
		return Response{}, err
	}
	finishedAt := time.Now()
	return Response{
		Result:             normalizeUniversalResult(request, result, upstreamTaskID),
		RequestID:          firstNonEmptyString(requestID, submitRequestID, requestIDFromResult(result), upstreamTaskID),
		Progress:           universalProgress(request, upstreamTaskID),
		ResponseStartedAt:  startedAt,
		ResponseFinishedAt: finishedAt,
		ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
	}, nil
}

// universalGetParams builds the payload that will be submitted upstream.
//
// If the platform config defines a get-params script for modelType, the script
// is executed with a copy of the request body and its result must be a
// non-nil object; otherwise the default payload is used. In both cases a copy
// of the original request body is stored under "_originalParams" (a value the
// script already set is left untouched).
func (c UniversalClient) universalGetParams(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string) (map[string]any, error) {
	scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customGetParamsScript", "custom_get_params_script")
	if scriptText == "" {
		body := universalDefaultPayload(request)
		body["_originalParams"] = cloneBody(request.Body)
		return body, nil
	}

	scriptContext := universalScriptContext(request, modelType, nil)
	out, err := executor.Execute(ctx, scriptengine.Options{
		Script:              scriptText,
		Args:                []any{cloneBody(request.Body), scriptContext},
		ContextData:         scriptContext,
		ScriptName:          "custom_get_params_script:" + modelType,
		PreferredEntryNames: []string{"getGenerateParams", "getParams", "main", "handler"},
		Timeout:             30 * time.Second,
		HTTPClient:          httpClient(request.HTTPClient, c.HTTPClient),
	})
	if err != nil {
		return nil, universalScriptError(err)
	}
	params, ok := out.(map[string]any)
	if !ok || params == nil {
		return nil, &ClientError{Code: "invalid_response", Message: "custom get params script must return an object", Retryable: false}
	}
	if params["_originalParams"] == nil {
		params["_originalParams"] = cloneBody(request.Body)
	}
	return params, nil
}

// universalSubmit sends the payload upstream and returns the decoded result
// together with a request id.
//
// A configured submit script takes precedence and must return a non-nil
// object. Without a script the payload, stripped of its private keys, is
// POSTed as JSON to the submit endpoint.
func (c UniversalClient) universalSubmit(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, payload map[string]any) (map[string]any, string, error) {
	scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customSubmitScript", "custom_submit_script")
	if scriptText == "" {
		return universalPostJSON(
			ctx,
			httpClient(request.HTTPClient, c.HTTPClient),
			request.Candidate.BaseURL,
			universalSubmitEndpoint(request),
			universalStripPrivatePayload(payload),
			request.Candidate.Credentials,
		)
	}

	scriptContext := universalScriptContext(request, modelType, payload)
	out, err := executor.Execute(ctx, scriptengine.Options{
		Script:              scriptText,
		Args:                []any{cloneBody(payload), scriptContext},
		ContextData:         scriptContext,
		ScriptName:          "custom_submit_script:" + modelType,
		PreferredEntryNames: []string{"submitTask", "submitParams", "submit", "main", "handler"},
		Timeout:             30 * time.Second,
		HTTPClient:          httpClient(request.HTTPClient, c.HTTPClient),
	})
	if err != nil {
		return nil, "", universalScriptError(err)
	}
	result, ok := out.(map[string]any)
	if !ok || result == nil {
		return nil, "", &ClientError{Code: "invalid_response", Message: "custom submit script must return an object", Retryable: false}
	}
	return result, requestIDFromResult(result), nil
}

// universalPollUntilDone polls the upstream task until it reports success or
// failure, the context is done, or the configured poll timeout elapses.
//
// The first poll happens immediately; subsequent polls wait for the poll
// interval (default 2s). The timeout defaults to 10 minutes. On success the
// final result and the most specific request id seen so far are returned.
func (c UniversalClient) universalPollUntilDone(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, upstreamTaskID string, payload map[string]any, requestID string, startedAt time.Time) (map[string]any, string, error) {
	interval := universalDurationConfig(request.Candidate.PlatformConfig, 2*time.Second, "pollIntervalMs", "poll_interval_ms")
	timeout := universalDurationConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms", "timeoutMs")

	deadline := time.NewTimer(timeout)
	defer deadline.Stop()
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	var lastResult map[string]any
	for {
		pollStarted := time.Now()
		result, pollRequestID, err := c.universalPoll(ctx, executor, request, modelType, upstreamTaskID, payload)
		pollFinished := time.Now()
		if err != nil {
			return nil, "", annotateResponseError(err, firstNonEmptyString(pollRequestID, requestID, upstreamTaskID), pollStarted, pollFinished)
		}

		lastResult = result
		requestID = firstNonEmptyString(pollRequestID, requestID, requestIDFromResult(result), upstreamTaskID)
		if isUniversalSuccess(result) {
			return result, requestID, nil
		}
		if isUniversalFailure(result) {
			return nil, "", universalFailureError(result, requestID, startedAt)
		}

		// Still pending: wait for the next tick unless cancelled or timed out.
		select {
		case <-ctx.Done():
			return nil, "", &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: requestID, Retryable: true}
		case <-deadline.C:
			return nil, "", &ClientError{Code: "timeout", Message: fmt.Sprintf("universal task %s did not finish before timeout; last status: %s", upstreamTaskID, universalStatus(lastResult)), RequestID: requestID, Retryable: true}
		case <-ticker.C:
		}
	}
}

// universalPoll performs one status lookup for upstreamTaskID.
//
// A configured poll script takes precedence and must return a non-nil object.
// Without a script the task URL is resolved from the platform config and
// fetched with a GET; a missing URL template is reported as a
// "missing_configuration" error.
func (c UniversalClient) universalPoll(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, upstreamTaskID string, payload map[string]any) (map[string]any, string, error) {
	scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customPollScript", "custom_poll_script")
	if scriptText != "" {
		scriptContext := universalScriptContext(request, modelType, payload)
		out, err := executor.Execute(ctx, scriptengine.Options{
			Script:              scriptText,
			Args:                []any{upstreamTaskID, scriptContext},
			ContextData:         scriptContext,
			ScriptName:          "custom_poll_script:" + modelType,
			PreferredEntryNames: []string{"pollTask", "poll", "main", "handler"},
			Timeout:             30 * time.Second,
			HTTPClient:          httpClient(request.HTTPClient, c.HTTPClient),
		})
		if err != nil {
			return nil, "", universalScriptError(err)
		}
		result, ok := out.(map[string]any)
		if !ok || result == nil {
			return nil, "", &ClientError{Code: "invalid_response", Message: "custom poll script must return an object", Retryable: false}
		}
		return result, requestIDFromResult(result), nil
	}

	pollURL := resolveUniversalTaskURL(request.Candidate.PlatformConfig, upstreamTaskID)
	if pollURL == "" {
		return nil, "", &ClientError{Code: "missing_configuration", Message: "universal getTaskURL is required", Retryable: false}
	}
	return universalGetJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), pollURL, request.Candidate.Credentials)
}

// universalScriptContext assembles the context object handed to custom
// scripts: connection details, credentials, a copy of the payload, request
// options, the platform config (as "env"), a candidate snapshot, and URL
// helper functions.
func universalScriptContext(request Request, modelType string, payload map[string]any) map[string]any {
	baseURL := strings.TrimRight(strings.TrimSpace(request.Candidate.BaseURL), "/")
	getTaskURL := universalConfigString(request.Candidate.PlatformConfig, "getTaskURL", "get_task_url")

	// Named scriptCtx rather than context so the imported context package is
	// not shadowed inside this function.
	scriptCtx := map[string]any{
		"__easyaiScriptContext": true,
		"baseURL":               baseURL,
		"getTaskURL":            getTaskURL,
		"authValues":            cloneMapAny(request.Candidate.Credentials),
		"headers":               map[string]any{},
		"payload":               cloneMapAny(payload),
		"type":                  modelType,
		"options": map[string]any{
			"task_id":           request.RemoteTaskID,
			"upstream_task_id":  request.RemoteTaskID,
			"model":             request.Model,
			"providerModelName": request.Candidate.ProviderModelName,
			"platformId":        request.Candidate.PlatformID,
			"platformModelId":   request.Candidate.PlatformModelID,
			"canonicalModelKey": request.Candidate.CanonicalModelKey,
			"modelType":         modelType,
			// Uses the same keys as the poll loop so scripts see the timeout
			// that is actually enforced.
			"timeout": universalDurationConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms", "timeoutMs").Milliseconds(),
		},
		"env":       cloneMapAny(request.Candidate.PlatformConfig),
		"candidate": universalCandidateSnapshot(request),
	}

	// createRequestURL joins path onto the candidate base URL, or onto the
	// optional explicit base. Absolute http(s) paths are returned unchanged.
	scriptCtx["createRequestURL"] = func(path string, base ...string) string {
		selectedBase := baseURL
		if len(base) > 0 && strings.TrimSpace(base[0]) != "" {
			selectedBase = strings.TrimRight(strings.TrimSpace(base[0]), "/")
		}
		if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
			return path
		}
		return selectedBase + "/" + strings.TrimLeft(path, "/")
	}
	// Same helper exposed under a second, misspelled key — presumably kept for
	// scripts that already use that spelling; confirm before removing.
	scriptCtx["creatRequestURL"] = scriptCtx["createRequestURL"]
	scriptCtx["resolveGetTaskURL"] = func(taskID string) string {
		return resolveUniversalTaskURL(request.Candidate.PlatformConfig, taskID)
	}
	return scriptCtx
}

// universalCandidateSnapshot returns the subset of candidate fields exposed
// to scripts under the "candidate" key.
func universalCandidateSnapshot(request Request) map[string]any {
	candidate := request.Candidate
	return map[string]any{
		"modelName":         candidate.ModelName,
		"modelAlias":        candidate.ModelAlias,
		"providerModelName": candidate.ProviderModelName,
		"provider":          candidate.Provider,
		"platformId":        candidate.PlatformID,
		"platformModelId":   candidate.PlatformModelID,
		"capabilities":      cloneMapAny(candidate.Capabilities),
	}
}

// universalDefaultPayload copies the request body, sets "model" to the
// upstream model name and, for image generation requests, mirrors "n" /
// "aspect_ratio" into the camel-cased "numImages" / "aspectRatio" keys.
func universalDefaultPayload(request Request) map[string]any {
	body := cloneBody(request.Body)
	body["model"] = upstreamModelName(request.Candidate)
	if request.Kind != "images.generations" {
		return body
	}
	if n := firstPresent(body["n"], body["numImages"]); n != nil {
		body["numImages"] = n
	}
	if aspectRatio := strings.TrimSpace(stringFromAny(body["aspect_ratio"])); aspectRatio != "" {
		body["aspectRatio"] = aspectRatio
	}
	return body
}

// universalSubmitEndpoint returns the configured submit path, or a default
// derived from the request kind.
func universalSubmitEndpoint(request Request) string {
	if endpoint := universalConfigString(request.Candidate.PlatformConfig, "submitPath", "submit_path"); endpoint != "" {
		return endpoint
	}
	switch request.Kind {
	case "images.generations":
		return "/images/generations"
	case "images.edits":
		return "/images/edits"
	case "videos.generations":
		return "/video/generations"
	default:
		return "/" + strings.ReplaceAll(request.Kind, ".", "/")
	}
}

// universalPostJSON POSTs body as JSON to baseURL+endpoint, adding a Bearer
// Authorization header when the credentials contain an API key, and returns
// the decoded response with the request id taken from the HTTP response.
// Transport failures are reported as retryable "network" errors.
func universalPostJSON(ctx context.Context, client *http.Client, baseURL string, endpoint string, body map[string]any, credentials map[string]any) (map[string]any, string, error) {
	raw, err := json.Marshal(body)
	if err != nil {
		// A payload that cannot be encoded will never succeed on retry, and
		// sending an empty body upstream would hide the real problem.
		return nil, "", &ClientError{Code: "invalid_request", Message: fmt.Sprintf("encode universal request body: %v", err), Retryable: false}
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, providerURL(baseURL, endpoint), bytes.NewReader(raw))
	if err != nil {
		return nil, "", err
	}
	req.Header.Set("Content-Type", "application/json")
	if apiKey := credential(credentials, "apiKey", "api_key", "key", "token"); apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
	}
	requestID := requestIDFromHTTPResponse(resp)
	result, err := decodeHTTPResponse(resp)
	return result, requestID, err
}

// universalGetJSON issues a GET to url, adding a Bearer Authorization header
// when the credentials contain an API key, and returns the decoded response
// with the request id taken from the HTTP response. Transport failures are
// reported as retryable "network" errors.
func universalGetJSON(ctx context.Context, client *http.Client, url string, credentials map[string]any) (map[string]any, string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, "", err
	}
	if apiKey := credential(credentials, "apiKey", "api_key", "key", "token"); apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
	}
	requestID := requestIDFromHTTPResponse(resp)
	result, err := decodeHTTPResponse(resp)
	return result, requestID, err
}

// normalizeUniversalResult returns a copy of result with defaults filled in
// for the "created", "task_id", "upstream_task_id", "model" and "status" keys
// when they are absent or nil.
func normalizeUniversalResult(request Request, result map[string]any, upstreamTaskID string) map[string]any {
	out := cloneMapAny(result)
	if out["created"] == nil {
		out["created"] = time.Now().UnixMilli()
	}
	if out["task_id"] == nil {
		out["task_id"] = upstreamTaskID
	}
	if out["upstream_task_id"] == nil {
		out["upstream_task_id"] = upstreamTaskID
	}
	if out["model"] == nil {
		out["model"] = request.Model
	}
	if out["status"] == nil {
		out["status"] = "success"
	}
	return out
}

// universalScriptError converts a script execution error into a ClientError.
// Errors with a blank message get a generic "script_error"; scriptengine
// errors keep their own code and are retryable only for "script_timeout".
// err must be non-nil.
func universalScriptError(err error) error {
	if strings.TrimSpace(err.Error()) == "" {
		return &ClientError{Code: "script_error", Message: "script execution failed", Retryable: false}
	}
	var scriptErr *scriptengine.Error
	if errors.As(err, &scriptErr) {
		return &ClientError{Code: scriptErr.ErrorCode(), Message: scriptErr.Error(), Retryable: scriptErr.ErrorCode() == "script_timeout"}
	}
	return &ClientError{Code: "script_error", Message: err.Error(), Retryable: false}
}

// universalFailureError builds a non-retryable ClientError from a failed
// upstream result, taking the message and code from the result when present.
func universalFailureError(result map[string]any, requestID string, startedAt time.Time) error {
	// One timestamp for both fields so the finish time and duration agree.
	finishedAt := time.Now()
	message := firstNonEmptyString(result["message"], result["error"], result["error_message"], "universal task failed")
	return &ClientError{
		Code:               firstNonEmptyString(result["code"], result["error_code"], "provider_failed"),
		Message:            message,
		RequestID:          requestID,
		ResponseStartedAt:  startedAt,
		ResponseFinishedAt: finishedAt,
		ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
		Retryable:          false,
	}
}

// isUniversalSuccess reports whether the result carries a success status.
func isUniversalSuccess(result map[string]any) bool {
	switch universalStatus(result) {
	case "success", "succeeded", "completed", "complete", "done":
		return true
	default:
		return false
	}
}

// isUniversalFailure reports whether the result carries a failure or
// cancellation status.
func isUniversalFailure(result map[string]any) bool {
	switch universalStatus(result) {
	case "failed", "failure", "error", "cancelled", "canceled":
		return true
	default:
		return false
	}
}

// universalStatus returns the lower-cased, trimmed status taken from the
// first non-empty of "status", "state" and "task_status".
func universalStatus(result map[string]any) string {
	status := firstNonEmptyString(result["status"], result["state"], result["task_status"])
	return strings.ToLower(strings.TrimSpace(status))
}

// universalTaskID returns the task id from the first non-empty of
// "upstream_task_id", "task_id", "taskId" and "id".
func universalTaskID(result map[string]any) string {
	return firstNonEmptyString(result["upstream_task_id"], result["task_id"], result["taskId"], result["id"])
}

// universalProgress returns the provider progress entries followed by a
// "polling" entry that carries the upstream task id.
func universalProgress(request Request, upstreamTaskID string) []Progress {
	polled := Progress{
		Phase:    "polling",
		Progress: 0.65,
		Message:  "provider task polled",
		Payload:  map[string]any{"upstreamTaskId": upstreamTaskID},
	}
	return append(providerProgress(request), polled)
}

// universalStripPrivatePayload returns a copy of payload without the
// underscore-prefixed bookkeeping keys that are removed before the payload is
// sent upstream.
func universalStripPrivatePayload(payload map[string]any) map[string]any {
	out := cloneMapAny(payload)
	for _, key := range []string{"_originalParams", "_resolution", "_duration"} {
		delete(out, key)
	}
	return out
}

// universalSceneScript looks up a script under the given config keys, in
// order. A string value is used as-is (trimmed); an object value is indexed by
// modelType with "common" as fallback. It returns "" when nothing is
// configured.
func universalSceneScript(config map[string]any, modelType string, keys ...string) string {
	for _, key := range keys {
		switch typed := config[key].(type) {
		case string:
			if trimmed := strings.TrimSpace(typed); trimmed != "" {
				return trimmed
			}
		case map[string]any:
			if script := firstNonEmptyString(typed[modelType], typed["common"]); script != "" {
				return script
			}
		}
	}
	return ""
}

// universalConfigString returns the trimmed string form of the first key in
// keys whose config value is present and non-blank, or "" when none is.
func universalConfigString(config map[string]any, keys ...string) string {
	for _, key := range keys {
		raw, ok := config[key]
		if !ok || raw == nil {
			// fmt.Sprint(nil) yields "<nil>", which must not be mistaken for
			// a configured value.
			continue
		}
		// The "<nil>" comparison additionally covers typed nil values, which
		// fmt.Sprint also renders as "<nil>".
		if value := strings.TrimSpace(fmt.Sprint(raw)); value != "" && value != "<nil>" {
			return value
		}
	}
	return ""
}

// universalDurationConfig returns the first strictly positive duration found
// under keys, or fallback. Numeric values are interpreted as milliseconds;
// string values are parsed with time.ParseDuration.
func universalDurationConfig(config map[string]any, fallback time.Duration, keys ...string) time.Duration {
	for _, key := range keys {
		var d time.Duration
		switch value := config[key].(type) {
		case int:
			d = time.Duration(value) * time.Millisecond
		case int64:
			d = time.Duration(value) * time.Millisecond
		case float64:
			// Scale before converting so fractional milliseconds are kept
			// instead of truncating to zero (a zero interval would make
			// time.NewTicker panic). The value > 0 guard also rejects NaN.
			if value > 0 {
				d = time.Duration(value * float64(time.Millisecond))
			}
		case string:
			if parsed, err := time.ParseDuration(strings.TrimSpace(value)); err == nil {
				d = parsed
			}
		}
		// Rejects zero, negative and overflowed values alike.
		if d > 0 {
			return d
		}
	}
	return fallback
}

// resolveUniversalTaskURL expands the configured task URL template by
// replacing every supported task-id placeholder with upstreamTaskID. It
// returns "" when no template is configured.
func resolveUniversalTaskURL(config map[string]any, upstreamTaskID string) string {
	out := strings.TrimSpace(universalConfigString(config, "getTaskURL", "get_task_url"))
	// Order matters: "${x}" and "{{x}}" are replaced before "{x}", which would
	// otherwise match inside them and leave stray characters behind.
	placeholders := []string{
		"${upstream_task_id}",
		"{{upstream_task_id}}",
		"{upstream_task_id}",
		"${task_id}",
		"{{task_id}}",
		"{task_id}",
		"${taskId}",
		"${taskID}",
		"{{taskId}}",
		"{{taskID}}",
		"{taskId}",
		"{taskID}",
	}
	for _, placeholder := range placeholders {
		out = strings.ReplaceAll(out, placeholder, upstreamTaskID)
	}
	return out
}