package clients import ( "bytes" "context" "encoding/json" "fmt" "net/http" "strings" "time" ) type providerTaskSpec struct { Name string SubmitPath func(Request, map[string]any) string PollPath func(Request, string, map[string]any) string Auth string TaskIDPaths []string StatusPaths []string SuccessStatuses []string FailureStatuses []string ProcessStatuses []string DefaultSubmitBody func(Request, map[string]any) map[string]any } type providerTaskClient struct { HTTPClient *http.Client Spec providerTaskSpec } func (c providerTaskClient) Run(ctx context.Context, request Request) (Response, error) { if request.Kind != "images.generations" && request.Kind != "images.edits" && request.Kind != "videos.generations" { return Response{}, &ClientError{Code: "unsupported_kind", Message: "unsupported " + c.Spec.Name + " request kind", Retryable: false} } startedAt := time.Now() payload := cloneBody(request.Body) if c.Spec.DefaultSubmitBody != nil { payload = c.Spec.DefaultSubmitBody(request, payload) } else { payload["model"] = upstreamModelName(request.Candidate) } upstreamTaskID := strings.TrimSpace(request.RemoteTaskID) requestID := upstreamTaskID var submitResult map[string]any if upstreamTaskID == "" { result, id, err := c.submit(ctx, request, payload) if err != nil { return Response{}, annotateResponseError(err, id, startedAt, time.Now()) } submitResult = result requestID = firstNonEmptyString(id, requestIDFromResult(result)) if isProviderTaskFailure(c.Spec, result) { return Response{}, providerTaskFailure(c.Spec, result, requestID, startedAt) } if isProviderTaskSuccess(c.Spec, result) && hasProviderTaskResult(result) { return Response{ Result: normalizeProviderTaskResult(request, c.Spec, result, ""), RequestID: requestID, Progress: providerProgress(request), ResponseStartedAt: startedAt, ResponseFinishedAt: time.Now(), ResponseDurationMS: responseDurationMS(startedAt, time.Now()), }, nil } upstreamTaskID = providerTaskID(c.Spec, result) if upstreamTaskID == "" { return Response{}, &ClientError{Code: "invalid_response", Message: c.Spec.Name + " task id is missing", RequestID: requestID, Retryable: false} } if request.OnRemoteTaskSubmitted != nil { if err := request.OnRemoteTaskSubmitted(upstreamTaskID, map[string]any{"payload": payload, "submit": submitResult}); err != nil { return Response{}, err } } } else if request.RemoteTaskPayload != nil { if existingPayload, ok := request.RemoteTaskPayload["payload"].(map[string]any); ok { payload = existingPayload } } interval := providerPollInterval(request) timeout := providerPollTimeout(request) deadline := time.NewTimer(timeout) defer deadline.Stop() ticker := time.NewTicker(interval) defer ticker.Stop() var lastResult map[string]any for { pollStarted := time.Now() result, pollRequestID, err := c.poll(ctx, request, upstreamTaskID, payload) pollFinished := time.Now() if err != nil { return Response{}, annotateResponseError(err, firstNonEmptyString(pollRequestID, requestID, upstreamTaskID), pollStarted, pollFinished) } lastResult = result requestID = firstNonEmptyString(pollRequestID, requestID, requestIDFromResult(result), upstreamTaskID) if isProviderTaskSuccess(c.Spec, result) { finishedAt := time.Now() return Response{ Result: normalizeProviderTaskResult(request, c.Spec, result, upstreamTaskID), RequestID: requestID, Progress: append(providerProgress(request), Progress{Phase: "polling", Progress: 0.65, Message: "provider task polled", Payload: map[string]any{"upstreamTaskId": upstreamTaskID}}), ResponseStartedAt: startedAt, ResponseFinishedAt: finishedAt, ResponseDurationMS: responseDurationMS(startedAt, finishedAt), }, nil } if isProviderTaskFailure(c.Spec, result) { return Response{}, providerTaskFailure(c.Spec, result, requestID, startedAt) } select { case <-ctx.Done(): return Response{}, &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: requestID, Retryable: true} case <-deadline.C: return Response{}, &ClientError{Code: "timeout", Message: fmt.Sprintf("%s task %s did not finish before timeout; last status: %s", c.Spec.Name, upstreamTaskID, providerTaskStatus(c.Spec, lastResult)), RequestID: requestID, Retryable: true} case <-ticker.C: } } } func (c providerTaskClient) submit(ctx context.Context, request Request, payload map[string]any) (map[string]any, string, error) { path := c.Spec.SubmitPath(request, payload) return providerPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), providerURL(request.Candidate.BaseURL, path), payload, request.Candidate.Credentials, c.Spec.Auth) } func (c providerTaskClient) poll(ctx context.Context, request Request, upstreamTaskID string, payload map[string]any) (map[string]any, string, error) { path := resolveProviderPathTemplate(c.Spec.PollPath(request, upstreamTaskID, payload), upstreamTaskID) url := path if !strings.HasPrefix(path, "http://") && !strings.HasPrefix(path, "https://") { url = providerURL(request.Candidate.BaseURL, path) } if c.Spec.Name == "jimeng" { body := map[string]any{"task_id": upstreamTaskID, "req_key": upstreamModelName(request.Candidate)} return providerPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), url, body, request.Candidate.Credentials, c.Spec.Auth) } return providerGetJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), url, request.Candidate.Credentials, c.Spec.Auth) } func providerPostJSON(ctx context.Context, client *http.Client, url string, body map[string]any, credentials map[string]any, auth string) (map[string]any, string, error) { raw, _ := json.Marshal(body) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(raw)) if err != nil { return nil, "", err } req.Header.Set("Content-Type", "application/json") applyProviderAuth(req, credentials, auth) resp, err := client.Do(req) if err != nil { return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true} } requestID := requestIDFromHTTPResponse(resp) result, err := decodeHTTPResponse(resp) return result, requestID, err } func providerGetJSON(ctx context.Context, client *http.Client, url string, credentials map[string]any, auth string) (map[string]any, string, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, "", err } applyProviderAuth(req, credentials, auth) resp, err := client.Do(req) if err != nil { return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true} } requestID := requestIDFromHTTPResponse(resp) result, err := decodeHTTPResponse(resp) return result, requestID, err } func applyProviderAuth(req *http.Request, credentials map[string]any, auth string) { apiKey := credential(credentials, "apiKey", "api_key", "key", "token") switch auth { case "token": if apiKey != "" { req.Header.Set("Authorization", "Token "+apiKey) } case "x-key": if apiKey != "" { req.Header.Set("x-key", apiKey) } case "none": default: if apiKey != "" { req.Header.Set("Authorization", "Bearer "+apiKey) } } } func providerURL(base string, path string) string { path = strings.TrimSpace(path) if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { return path } if path == "" { path = "/" } if !strings.HasPrefix(path, "/") && !strings.HasPrefix(path, "?") { path = "/" + path } return joinURL(base, path) } func resolveProviderPathTemplate(path string, upstreamTaskID string) string { replacements := [][2]string{ {"${upstream_task_id}", upstreamTaskID}, {"{{upstream_task_id}}", upstreamTaskID}, {"{upstream_task_id}", upstreamTaskID}, {"${task_id}", upstreamTaskID}, {"{{task_id}}", upstreamTaskID}, {"{task_id}", upstreamTaskID}, {"${taskId}", upstreamTaskID}, {"${taskID}", upstreamTaskID}, {"{{taskId}}", upstreamTaskID}, {"{{taskID}}", upstreamTaskID}, {"{taskId}", upstreamTaskID}, {"{taskID}", upstreamTaskID}, } for _, replacement := range replacements { path = strings.ReplaceAll(path, replacement[0], replacement[1]) } return path } func providerTaskID(spec providerTaskSpec, result map[string]any) string { paths := append([]string{}, spec.TaskIDPaths...) paths = append(paths, "task_id", "taskId", "id", "job_id", "Response.JobId", "output.task_id", "data.task_id", "polling_url") for _, path := range paths { if value := stringFromPathValue(valueAtPath(result, path)); value != "" { return value } } return "" } func providerTaskStatus(spec providerTaskSpec, result map[string]any) string { if result == nil { return "" } if value, ok := valueAtPath(result, "status").(float64); ok { if value == 2 { return "success" } if value == 3 { return "failed" } return "process" } paths := append([]string{}, spec.StatusPaths...) paths = append(paths, "status", "state", "task_status", "output.task_status", "Response.Status", "data.status") for _, path := range paths { if value := stringFromPathValue(valueAtPath(result, path)); value != "" { return strings.ToLower(value) } } return "" } func stringFromPathValue(value any) string { if value == nil { return "" } text := strings.TrimSpace(fmt.Sprint(value)) if text == "" || text == "" { return "" } return text } func isProviderTaskSuccess(spec providerTaskSpec, result map[string]any) bool { return containsStatus(append([]string{"success", "succeeded", "completed", "complete", "done", "ready", "succeed", "succeeded", "suceeded", "done", "done"}, spec.SuccessStatuses...), providerTaskStatus(spec, result)) } func isProviderTaskFailure(spec providerTaskSpec, result map[string]any) bool { return containsStatus(append([]string{"failed", "failure", "error", "cancelled", "canceled", "fail", "expired", "task not found"}, spec.FailureStatuses...), providerTaskStatus(spec, result)) } func containsStatus(values []string, status string) bool { status = strings.ToLower(strings.TrimSpace(status)) for _, value := range values { if strings.ToLower(strings.TrimSpace(value)) == status { return true } } return false } func hasProviderTaskResult(result map[string]any) bool { return result["data"] != nil || valueAtPath(result, "output.image_urls") != nil || valueAtPath(result, "output.video_url") != nil || valueAtPath(result, "Response.ResultVideoUrl") != nil || valueAtPath(result, "Response.ResultImages") != nil || result["urls"] != nil } func normalizeProviderTaskResult(request Request, spec providerTaskSpec, result map[string]any, upstreamTaskID string) map[string]any { out := cloneMapAny(result) out["status"] = "success" if upstreamTaskID != "" { out["upstream_task_id"] = upstreamTaskID } if out["created"] == nil { out["created"] = time.Now().UnixMilli() } if out["model"] == nil { out["model"] = request.Model } if _, ok := out["data"].([]any); !ok { if out["data"] != nil { out["raw_data"] = out["data"] } out["data"] = providerTaskData(request, result) } return out } func providerTaskData(request Request, result map[string]any) []any { fileType := "image" if request.Kind == "videos.generations" || strings.Contains(request.ModelType, "video") { fileType = "video" } urlValues := []any{} for _, path := range []string{ "urls", "image_urls", "data.image_urls", "data.images", "output.image_urls", "output.video_url", "output.output", "data.output", "data.video_url", "video_url", "preview_url", "Response.ResultImages", "Response.ResultVideoUrl", } { appendURLValues(&urlValues, valueAtPath(result, path)) } data := make([]any, 0, len(urlValues)) for _, raw := range urlValues { if url := strings.TrimSpace(fmt.Sprint(raw)); url != "" { data = append(data, map[string]any{"type": fileType, "url": url}) } } if len(data) == 0 { if base64Values := valueAtPath(result, "data.binary_data_base64"); base64Values != nil { values := []any{} appendURLValues(&values, base64Values) for _, raw := range values { if content := strings.TrimSpace(fmt.Sprint(raw)); content != "" { data = append(data, map[string]any{"type": fileType, "content": content, "uploaded": false}) } } } } return data } func appendURLValues(out *[]any, value any) { switch typed := value.(type) { case nil: case string: *out = append(*out, typed) case []any: for _, item := range typed { appendURLValues(out, item) } case []string: for _, item := range typed { *out = append(*out, item) } case map[string]any: for _, key := range []string{"url", "image_url", "imageUrl", "video_url", "videoUrl", "content", "output"} { if item := strings.TrimSpace(fmt.Sprint(typed[key])); item != "" && item != "" { *out = append(*out, item) return } } } } func providerTaskFailure(spec providerTaskSpec, result map[string]any, requestID string, startedAt time.Time) error { message := firstNonEmptyString(valueAtPath(result, "message"), valueAtPath(result, "error.message"), valueAtPath(result, "error"), valueAtPath(result, "Response.ErrorMessage"), valueAtPath(result, "comment"), spec.Name+" task failed") return &ClientError{ Code: firstNonEmptyString(valueAtPath(result, "code"), valueAtPath(result, "error_code"), valueAtPath(result, "Response.ErrorCode"), "provider_failed"), Message: message, RequestID: requestID, ResponseStartedAt: startedAt, ResponseFinishedAt: time.Now(), ResponseDurationMS: responseDurationMS(startedAt, time.Now()), Retryable: false, } } func providerPollInterval(request Request) time.Duration { return durationFromConfig(request.Candidate.PlatformConfig, 2*time.Second, "pollIntervalMs", "poll_interval_ms") } func providerPollTimeout(request Request) time.Duration { return durationFromConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms", "timeoutMs") } func durationFromConfig(config map[string]any, fallback time.Duration, keys ...string) time.Duration { for _, key := range keys { switch value := config[key].(type) { case int: if value > 0 { return time.Duration(value) * time.Millisecond } case int64: if value > 0 { return time.Duration(value) * time.Millisecond } case float64: if value > 0 { return time.Duration(value) * time.Millisecond } case string: if parsed, err := time.ParseDuration(value); err == nil && parsed > 0 { return parsed } } } return fallback } func valueAtPath(values map[string]any, path string) any { if values == nil || strings.TrimSpace(path) == "" { return nil } var current any = values for _, part := range strings.Split(path, ".") { object, ok := current.(map[string]any) if !ok { return nil } current = object[part] } return current } func mediaPromptText(body map[string]any) string { if prompt := strings.TrimSpace(stringFromAny(body["prompt"])); prompt != "" { return prompt } content, _ := body["content"].([]any) for _, item := range content { if part, ok := item.(map[string]any); ok && strings.TrimSpace(stringFromAny(part["type"])) == "text" { if text := strings.TrimSpace(stringFromAny(part["text"])); text != "" { return text } } } return "" }