easyai-ai-gateway/apps/api/internal/clients/provider_task.go

454 lines
15 KiB
Go

package clients
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"time"
)
// providerTaskSpec describes how to drive one provider's submit-then-poll
// task API; providerTaskClient.Run consumes it.
type providerTaskSpec struct {
// Name identifies the provider in error messages. poll also switches to a
// POST-based status request when Name equals "jimeng".
Name string
// SubmitPath returns the submit endpoint path (an absolute http(s) URL is
// used as-is) for the request and its prepared payload.
SubmitPath func(Request, map[string]any) string
// PollPath returns the status endpoint for an upstream task id. Task-id
// placeholders such as "{task_id}" or "${taskId}" in the result are
// substituted, and absolute http(s) URLs are used as-is.
PollPath func(Request, string, map[string]any) string
// Auth selects how the API key is sent: "token" ("Authorization: Token …"),
// "x-key" (x-key header), "none" (no header), anything else
// ("Authorization: Bearer …").
Auth string
// TaskIDPaths are dot-separated response paths tried for the upstream task
// id before the built-in defaults.
TaskIDPaths []string
// StatusPaths are dot-separated response paths tried for the task status
// before the built-in defaults.
StatusPaths []string
// SuccessStatuses and FailureStatuses extend the built-in terminal status
// lists. Matching is case-insensitive and ignores surrounding whitespace.
SuccessStatuses []string
FailureStatuses []string
// ProcessStatuses lists in-progress statuses.
// NOTE(review): not referenced in this part of the file — confirm whether
// anything reads it.
ProcessStatuses []string
// DefaultSubmitBody, when set, builds the submit payload from the request
// and a clone of its body. When nil, Run only sets payload["model"].
DefaultSubmitBody func(Request, map[string]any) map[string]any
}
// providerTaskClient runs asynchronous provider tasks described by Spec.
type providerTaskClient struct {
// HTTPClient is passed to httpClient together with request.HTTPClient for
// submit and poll calls; presumably the fallback when the request does not
// carry its own client — confirm in httpClient.
HTTPClient *http.Client
// Spec holds the provider-specific endpoints, auth mode and status rules.
Spec providerTaskSpec
}
// Run drives one asynchronous provider task for an image or video request.
//
// Only the "images.generations", "images.edits" and "videos.generations"
// kinds are accepted. When request.RemoteTaskID is empty, the payload is
// submitted first. A submit response that already reports success and carries
// result data is returned without polling. Otherwise the upstream task id is
// extracted from the submit response and handed to
// request.OnRemoteTaskSubmitted (if set), together with the payload and the
// raw submit response.
//
// When request.RemoteTaskID is set, submission is skipped. In that case the
// "payload" entry of request.RemoteTaskPayload, if it is a map, replaces the
// freshly built payload for poll-path construction.
//
// Polling runs once immediately and then on every providerPollInterval tick.
// It stops when the task reports a success or failure status, when ctx is
// cancelled, or when providerPollTimeout elapses (measured from the start of
// polling).
func (c providerTaskClient) Run(ctx context.Context, request Request) (Response, error) {
	if request.Kind != "images.generations" && request.Kind != "images.edits" && request.Kind != "videos.generations" {
		return Response{}, &ClientError{Code: "unsupported_kind", Message: "unsupported " + c.Spec.Name + " request kind", Retryable: false}
	}
	startedAt := time.Now()
	payload := cloneBody(request.Body)
	if c.Spec.DefaultSubmitBody != nil {
		payload = c.Spec.DefaultSubmitBody(request, payload)
	} else {
		payload["model"] = upstreamModelName(request.Candidate)
	}
	upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
	requestID := upstreamTaskID
	if upstreamTaskID == "" {
		submitResult, submitRequestID, err := c.submit(ctx, request, payload)
		if err != nil {
			return Response{}, annotateResponseError(err, submitRequestID, startedAt, time.Now())
		}
		requestID = firstNonEmptyString(submitRequestID, requestIDFromResult(submitResult))
		if isProviderTaskFailure(c.Spec, submitResult) {
			return Response{}, providerTaskFailure(c.Spec, submitResult, requestID, startedAt)
		}
		// Some providers finish synchronously: skip polling when the submit
		// response already reports success and carries result data.
		if isProviderTaskSuccess(c.Spec, submitResult) && hasProviderTaskResult(submitResult) {
			// One timestamp so FinishedAt and DurationMS agree.
			finishedAt := time.Now()
			return Response{
				Result:             normalizeProviderTaskResult(request, c.Spec, submitResult, ""),
				RequestID:          requestID,
				Progress:           providerProgress(request),
				ResponseStartedAt:  startedAt,
				ResponseFinishedAt: finishedAt,
				ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
			}, nil
		}
		upstreamTaskID = providerTaskID(c.Spec, submitResult)
		if upstreamTaskID == "" {
			return Response{}, &ClientError{Code: "invalid_response", Message: c.Spec.Name + " task id is missing", RequestID: requestID, Retryable: false}
		}
		if request.OnRemoteTaskSubmitted != nil {
			if err := request.OnRemoteTaskSubmitted(upstreamTaskID, map[string]any{"payload": payload, "submit": submitResult}); err != nil {
				return Response{}, err
			}
		}
	} else if request.RemoteTaskPayload != nil {
		// Resuming an already-submitted task: reuse the stored payload, if any.
		if existingPayload, ok := request.RemoteTaskPayload["payload"].(map[string]any); ok {
			payload = existingPayload
		}
	}

	interval := providerPollInterval(request)
	timeout := providerPollTimeout(request)
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		pollStarted := time.Now()
		result, pollRequestID, err := c.poll(ctx, request, upstreamTaskID, payload)
		pollFinished := time.Now()
		if err != nil {
			return Response{}, annotateResponseError(err, firstNonEmptyString(pollRequestID, requestID, upstreamTaskID), pollStarted, pollFinished)
		}
		requestID = firstNonEmptyString(pollRequestID, requestID, requestIDFromResult(result), upstreamTaskID)
		if isProviderTaskSuccess(c.Spec, result) {
			finishedAt := time.Now()
			return Response{
				Result:             normalizeProviderTaskResult(request, c.Spec, result, upstreamTaskID),
				RequestID:          requestID,
				Progress:           append(providerProgress(request), Progress{Phase: "polling", Progress: 0.65, Message: "provider task polled", Payload: map[string]any{"upstreamTaskId": upstreamTaskID}}),
				ResponseStartedAt:  startedAt,
				ResponseFinishedAt: finishedAt,
				ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
			}, nil
		}
		if isProviderTaskFailure(c.Spec, result) {
			return Response{}, providerTaskFailure(c.Spec, result, requestID, startedAt)
		}
		select {
		case <-ctx.Done():
			finishedAt := time.Now()
			return Response{}, &ClientError{
				Code:               "cancelled",
				Message:            ctx.Err().Error(),
				RequestID:          requestID,
				ResponseStartedAt:  startedAt,
				ResponseFinishedAt: finishedAt,
				ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
				Retryable:          true,
			}
		case <-deadline.C:
			finishedAt := time.Now()
			return Response{}, &ClientError{
				Code:               "timeout",
				Message:            fmt.Sprintf("%s task %s did not finish before timeout; last status: %s", c.Spec.Name, upstreamTaskID, providerTaskStatus(c.Spec, result)),
				RequestID:          requestID,
				ResponseStartedAt:  startedAt,
				ResponseFinishedAt: finishedAt,
				ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
				Retryable:          true,
			}
		case <-ticker.C:
		}
	}
}
// submit POSTs the prepared payload to the provider's submit endpoint.
// It returns the decoded JSON response and the request id taken from the HTTP
// response.
func (c providerTaskClient) submit(ctx context.Context, request Request, payload map[string]any) (map[string]any, string, error) {
	submitPath := c.Spec.SubmitPath(request, payload)
	client := httpClient(request.HTTPClient, c.HTTPClient)
	endpoint := providerURL(request.Candidate.BaseURL, submitPath)
	candidate := request.Candidate
	return providerPostJSON(ctx, client, endpoint, payload, candidate.Credentials, c.Spec.Auth)
}
// poll fetches the current status of upstreamTaskID.
//
// Placeholders in the spec's poll path are replaced with the task id. A
// relative result is joined onto the candidate's base URL. The "jimeng"
// provider is polled with a POST carrying task_id and req_key; every other
// provider uses a GET.
func (c providerTaskClient) poll(ctx context.Context, request Request, upstreamTaskID string, payload map[string]any) (map[string]any, string, error) {
	target := resolveProviderPathTemplate(c.Spec.PollPath(request, upstreamTaskID, payload), upstreamTaskID)
	absolute := strings.HasPrefix(target, "http://") || strings.HasPrefix(target, "https://")
	if !absolute {
		target = providerURL(request.Candidate.BaseURL, target)
	}
	if c.Spec.Name != "jimeng" {
		return providerGetJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), target, request.Candidate.Credentials, c.Spec.Auth)
	}
	statusQuery := map[string]any{
		"task_id": upstreamTaskID,
		"req_key": upstreamModelName(request.Candidate),
	}
	return providerPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), target, statusQuery, request.Candidate.Credentials, c.Spec.Auth)
}
// providerPostJSON sends body as a JSON POST to url with provider auth applied.
// It returns the decoded response, the request id read from the HTTP
// response, and any error.
//
// A body that cannot be JSON-encoded is reported as an error instead of being
// sent as an empty request. Transport failures become a retryable "network"
// ClientError.
func providerPostJSON(ctx context.Context, client *http.Client, url string, body map[string]any, credentials map[string]any, auth string) (map[string]any, string, error) {
	raw, err := json.Marshal(body)
	if err != nil {
		return nil, "", fmt.Errorf("encoding provider request body: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(raw))
	if err != nil {
		return nil, "", err
	}
	req.Header.Set("Content-Type", "application/json")
	applyProviderAuth(req, credentials, auth)
	resp, err := client.Do(req)
	if err != nil {
		return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
	}
	// Guarantee the body is released even if decoding bails out early; a
	// second Close on a net/http response body is harmless.
	defer resp.Body.Close()
	requestID := requestIDFromHTTPResponse(resp)
	result, err := decodeHTTPResponse(resp)
	return result, requestID, err
}
// providerGetJSON issues an authenticated GET to url.
// It returns the decoded response, the request id read from the HTTP
// response, and any error. Transport failures become a retryable "network"
// ClientError.
func providerGetJSON(ctx context.Context, client *http.Client, url string, credentials map[string]any, auth string) (map[string]any, string, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, "", err
	}
	applyProviderAuth(req, credentials, auth)
	resp, err := client.Do(req)
	if err != nil {
		return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
	}
	// Guarantee the body is released even if decoding bails out early; a
	// second Close on a net/http response body is harmless.
	defer resp.Body.Close()
	requestID := requestIDFromHTTPResponse(resp)
	result, err := decodeHTTPResponse(resp)
	return result, requestID, err
}
// applyProviderAuth attaches the candidate's API key to req according to auth.
//
// The key is looked up under "apiKey", "api_key", "key" or "token". The auth
// mode decides how it is sent:
//   - "token": Authorization: Token <key>
//   - "x-key": x-key header
//   - "none": nothing is sent
//   - otherwise: Authorization: Bearer <key>
//
// Nothing is set when no key is configured.
func applyProviderAuth(req *http.Request, credentials map[string]any, auth string) {
	key := credential(credentials, "apiKey", "api_key", "key", "token")
	if key == "" {
		return
	}
	switch auth {
	case "none":
		// Explicitly unauthenticated provider.
	case "x-key":
		req.Header.Set("x-key", key)
	case "token":
		req.Header.Set("Authorization", "Token "+key)
	default:
		req.Header.Set("Authorization", "Bearer "+key)
	}
}
// providerURL resolves path against base.
//
// Surrounding whitespace is trimmed. Absolute http(s) URLs are returned
// unchanged. An empty path becomes "/", and a path not starting with "/" or
// "?" gets a leading "/" before being joined onto base.
func providerURL(base string, path string) string {
	cleaned := strings.TrimSpace(path)
	switch {
	case strings.HasPrefix(cleaned, "http://"), strings.HasPrefix(cleaned, "https://"):
		return cleaned
	case cleaned == "":
		cleaned = "/"
	case cleaned[0] != '/' && cleaned[0] != '?':
		cleaned = "/" + cleaned
	}
	return joinURL(base, cleaned)
}
// resolveProviderPathTemplate substitutes every supported task-id placeholder
// in path with upstreamTaskID.
//
// Placeholders are replaced one form at a time, in a fixed order. Within each
// name the "${…}" and "{{…}}" forms are handled before the bare "{…}" form,
// so the wrapped variants are never partially consumed.
func resolveProviderPathTemplate(path string, upstreamTaskID string) string {
	placeholders := []string{
		"${upstream_task_id}", "{{upstream_task_id}}", "{upstream_task_id}",
		"${task_id}", "{{task_id}}", "{task_id}",
		"${taskId}", "${taskID}",
		"{{taskId}}", "{{taskID}}",
		"{taskId}", "{taskID}",
	}
	resolved := path
	for i := range placeholders {
		resolved = strings.ReplaceAll(resolved, placeholders[i], upstreamTaskID)
	}
	return resolved
}
// providerTaskID returns the first non-empty upstream task id found in result.
// It checks the spec's TaskIDPaths first, then a fixed set of common paths
// (including "polling_url", which some providers return instead of an id).
func providerTaskID(spec providerTaskSpec, result map[string]any) string {
	builtin := []string{"task_id", "taskId", "id", "job_id", "Response.JobId", "output.task_id", "data.task_id", "polling_url"}
	for _, group := range [][]string{spec.TaskIDPaths, builtin} {
		for _, candidatePath := range group {
			id := stringFromPathValue(valueAtPath(result, candidatePath))
			if id != "" {
				return id
			}
		}
	}
	return ""
}
// providerTaskStatus extracts a lower-cased task status from result.
//
// A top-level "status" holding a float64 is treated as a numeric code:
// 2 → "success", 3 → "failed", anything else → "process". (This presumably
// relies on JSON numbers decoding to float64 — confirm against
// decodeHTTPResponse.) Otherwise the spec's StatusPaths and then the built-in
// paths are tried in order. An empty string is returned when nothing matches
// or result is nil.
func providerTaskStatus(spec providerTaskSpec, result map[string]any) string {
	if result == nil {
		return ""
	}
	if code, numeric := valueAtPath(result, "status").(float64); numeric {
		switch code {
		case 2:
			return "success"
		case 3:
			return "failed"
		default:
			return "process"
		}
	}
	builtin := []string{"status", "state", "task_status", "output.task_status", "Response.Status", "data.status"}
	for _, group := range [][]string{spec.StatusPaths, builtin} {
		for _, candidatePath := range group {
			if status := stringFromPathValue(valueAtPath(result, candidatePath)); status != "" {
				return strings.ToLower(status)
			}
		}
	}
	return ""
}
// stringFromPathValue renders a value pulled out of a decoded JSON response
// as a trimmed string. It returns "" for nil, blank, or "<nil>" renderings.
//
// Floating-point numbers are formatted in plain decimal notation.
// fmt.Sprint switches to exponent form for 7+ digit values, which would turn a
// numeric task id such as 1234567890123 (decoded as float64) into
// "1.234567890123e+12" and break poll URLs built from it.
func stringFromPathValue(value any) string {
	var text string
	switch typed := value.(type) {
	case nil:
		return ""
	case float64:
		text = strconv.FormatFloat(typed, 'f', -1, 64)
	case float32:
		text = strconv.FormatFloat(float64(typed), 'f', -1, 32)
	default:
		text = fmt.Sprint(value)
	}
	text = strings.TrimSpace(text)
	if text == "" || text == "<nil>" {
		return ""
	}
	return text
}
// isProviderTaskSuccess reports whether result's status is a terminal success.
// The status must match one of the built-in success values or
// spec.SuccessStatuses; comparison is case-insensitive and ignores surrounding
// whitespace.
//
// The built-in list is deduplicated (the original repeated "succeeded" and
// "done"). Spec-specific values are checked separately, so no merged slice is
// built per call.
func isProviderTaskSuccess(spec providerTaskSpec, result map[string]any) bool {
	status := providerTaskStatus(spec, result)
	// "suceeded" (sic) is intentionally kept: presumably a provider emits this
	// misspelling — TODO confirm before removing.
	builtin := []string{"success", "succeeded", "succeed", "suceeded", "completed", "complete", "done", "ready"}
	return containsStatus(builtin, status) || containsStatus(spec.SuccessStatuses, status)
}
// isProviderTaskFailure reports whether result's status is a terminal failure.
// The status must match one of spec.FailureStatuses or the built-in failure
// values; comparison is case-insensitive and ignores surrounding whitespace.
func isProviderTaskFailure(spec providerTaskSpec, result map[string]any) bool {
	status := providerTaskStatus(spec, result)
	if containsStatus(spec.FailureStatuses, status) {
		return true
	}
	builtin := []string{"failed", "failure", "error", "cancelled", "canceled", "fail", "expired", "task not found"}
	return containsStatus(builtin, status)
}
// containsStatus reports whether status equals any entry of values after both
// sides are trimmed and lower-cased.
func containsStatus(values []string, status string) bool {
	normalize := func(s string) string { return strings.ToLower(strings.TrimSpace(s)) }
	want := normalize(status)
	found := false
	for i := 0; i < len(values) && !found; i++ {
		found = normalize(values[i]) == want
	}
	return found
}
// hasProviderTaskResult reports whether result carries any output payload at
// one of the known locations ("data", "urls", output image/video URLs, or the
// Response.Result* fields).
func hasProviderTaskResult(result map[string]any) bool {
	resultPaths := []string{
		"data",
		"output.image_urls",
		"output.video_url",
		"Response.ResultVideoUrl",
		"Response.ResultImages",
		"urls",
	}
	for _, p := range resultPaths {
		if valueAtPath(result, p) != nil {
			return true
		}
	}
	return false
}
// normalizeProviderTaskResult returns a copy of result shaped as a successful
// gateway response.
//
// The copy has status "success", the upstream task id (when known), and a
// "created" timestamp in Unix milliseconds and a "model" only when absent. If
// "data" is not already a []any, any existing value is preserved under
// "raw_data" and "data" is rebuilt from the original result by
// providerTaskData. spec is currently unused.
func normalizeProviderTaskResult(request Request, spec providerTaskSpec, result map[string]any, upstreamTaskID string) map[string]any {
	normalized := cloneMapAny(result)
	normalized["status"] = "success"
	if upstreamTaskID != "" {
		normalized["upstream_task_id"] = upstreamTaskID
	}
	if normalized["created"] == nil {
		normalized["created"] = time.Now().UnixMilli()
	}
	if normalized["model"] == nil {
		normalized["model"] = request.Model
	}
	if _, alreadyList := normalized["data"].([]any); alreadyList {
		return normalized
	}
	if previous := normalized["data"]; previous != nil {
		normalized["raw_data"] = previous
	}
	normalized["data"] = providerTaskData(request, result)
	return normalized
}
// providerTaskData builds the gateway "data" list from a provider result.
//
// Each item is {"type": "video"|"image", "url": …}. The type is "video" for
// "videos.generations" requests or model types containing "video", and
// "image" otherwise. URLs are gathered from a fixed set of known paths. When
// none are found, base64 payloads under data.binary_data_base64 are emitted as
// {"type", "content", "uploaded": false} items instead. The returned slice is
// never nil.
func providerTaskData(request Request, result map[string]any) []any {
	fileType := "image"
	isVideo := request.Kind == "videos.generations" || strings.Contains(request.ModelType, "video")
	if isVideo {
		fileType = "video"
	}

	knownURLPaths := []string{
		"urls",
		"image_urls",
		"data.image_urls",
		"data.images",
		"output.image_urls",
		"output.video_url",
		"output.output",
		"data.output",
		"data.video_url",
		"video_url",
		"preview_url",
		"Response.ResultImages",
		"Response.ResultVideoUrl",
	}
	found := []any{}
	for _, p := range knownURLPaths {
		appendURLValues(&found, valueAtPath(result, p))
	}

	items := make([]any, 0, len(found))
	for _, candidate := range found {
		link := strings.TrimSpace(fmt.Sprint(candidate))
		if link == "" {
			continue
		}
		items = append(items, map[string]any{"type": fileType, "url": link})
	}
	if len(items) > 0 {
		return items
	}

	encoded := valueAtPath(result, "data.binary_data_base64")
	if encoded == nil {
		return items
	}
	blobs := []any{}
	appendURLValues(&blobs, encoded)
	for _, candidate := range blobs {
		content := strings.TrimSpace(fmt.Sprint(candidate))
		if content == "" {
			continue
		}
		items = append(items, map[string]any{"type": fileType, "content": content, "uploaded": false})
	}
	return items
}
// appendURLValues flattens value into *out, collecting candidate URL (or
// base64 content) strings:
//   - string: appended as-is;
//   - []any / []string: every element is processed in order;
//   - map[string]any: the first of "url", "image_url", "imageUrl",
//     "video_url", "videoUrl", "content", "output" that yields a non-empty
//     value wins, and nothing else from that map is collected.
//
// Any other top-level type, including nil, is ignored.
//
// Nested lists and objects found under those map keys are walked recursively.
// Previously they were passed through fmt.Sprint, which produced junk "URLs"
// such as "[https://…]" or "map[url:…]".
func appendURLValues(out *[]any, value any) {
	switch typed := value.(type) {
	case nil:
	case string:
		*out = append(*out, typed)
	case []any:
		for _, item := range typed {
			appendURLValues(out, item)
		}
	case []string:
		for _, item := range typed {
			*out = append(*out, item)
		}
	case map[string]any:
		for _, key := range []string{"url", "image_url", "imageUrl", "video_url", "videoUrl", "content", "output"} {
			switch field := typed[key].(type) {
			case nil:
				continue
			case []any, []string, map[string]any:
				before := len(*out)
				appendURLValues(out, field)
				if len(*out) > before {
					return
				}
			default:
				if item := strings.TrimSpace(fmt.Sprint(field)); item != "" && item != "<nil>" {
					*out = append(*out, item)
					return
				}
			}
		}
	}
}
// providerTaskFailure converts a provider failure response into a
// non-retryable ClientError.
//
// The message comes from the first non-empty of message, error.message,
// error, Response.ErrorMessage, comment, or "<name> task failed". The code
// comes from code, error_code, Response.ErrorCode, or "provider_failed".
// A single finish timestamp is taken so ResponseFinishedAt and
// ResponseDurationMS agree.
func providerTaskFailure(spec providerTaskSpec, result map[string]any, requestID string, startedAt time.Time) error {
	message := firstNonEmptyString(valueAtPath(result, "message"), valueAtPath(result, "error.message"), valueAtPath(result, "error"), valueAtPath(result, "Response.ErrorMessage"), valueAtPath(result, "comment"), spec.Name+" task failed")
	code := firstNonEmptyString(valueAtPath(result, "code"), valueAtPath(result, "error_code"), valueAtPath(result, "Response.ErrorCode"), "provider_failed")
	finishedAt := time.Now()
	return &ClientError{
		Code:               code,
		Message:            message,
		RequestID:          requestID,
		ResponseStartedAt:  startedAt,
		ResponseFinishedAt: finishedAt,
		ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
		Retryable:          false,
	}
}
// providerPollInterval returns the delay between status polls. It is read from
// the candidate's platform config under "pollIntervalMs" or
// "poll_interval_ms" and defaults to 2s.
func providerPollInterval(request Request) time.Duration {
	const defaultInterval = 2 * time.Second
	config := request.Candidate.PlatformConfig
	return durationFromConfig(config, defaultInterval, "pollIntervalMs", "poll_interval_ms")
}
// providerPollTimeout returns how long polling may run before giving up. It is
// read from the candidate's platform config under "pollTimeoutMs",
// "poll_timeout_ms" or "timeoutMs" and defaults to 10 minutes.
func providerPollTimeout(request Request) time.Duration {
	const defaultTimeout = 10 * time.Minute
	config := request.Candidate.PlatformConfig
	return durationFromConfig(config, defaultTimeout, "pollTimeoutMs", "poll_timeout_ms", "timeoutMs")
}
// durationFromConfig returns the first strictly positive duration found in
// config under keys, in order, or fallback if none qualifies.
//
// Accepted value forms:
//   - int, int64, float64, json.Number: a count of milliseconds;
//   - string: a Go duration ("2s", "500ms"), or, failing that, a plain number
//     of milliseconds ("5000").
//
// Fractional milliseconds are scaled before conversion. Previously 0.5 passed
// the > 0 check but truncated to a zero Duration, which makes time.NewTicker
// panic. Non-positive results are always skipped.
func durationFromConfig(config map[string]any, fallback time.Duration, keys ...string) time.Duration {
	fromMillis := func(ms float64) time.Duration {
		return time.Duration(ms * float64(time.Millisecond))
	}
	for _, key := range keys {
		var candidate time.Duration
		switch value := config[key].(type) {
		case int:
			candidate = time.Duration(value) * time.Millisecond
		case int64:
			candidate = time.Duration(value) * time.Millisecond
		case float64:
			candidate = fromMillis(value)
		case json.Number:
			if ms, err := value.Float64(); err == nil {
				candidate = fromMillis(ms)
			}
		case string:
			trimmed := strings.TrimSpace(value)
			if parsed, err := time.ParseDuration(trimmed); err == nil {
				candidate = parsed
			} else if ms, err := strconv.ParseFloat(trimmed, 64); err == nil {
				// Keys are conventionally "*Ms", so a bare number is milliseconds.
				candidate = fromMillis(ms)
			}
		}
		if candidate > 0 {
			return candidate
		}
	}
	return fallback
}
// valueAtPath walks a dot-separated path ("a.b.c") through nested
// map[string]any values.
// It returns nil when values is nil, the path is blank, or any intermediate
// segment is missing or not a map.
func valueAtPath(values map[string]any, path string) any {
	if values == nil || strings.TrimSpace(path) == "" {
		return nil
	}
	var node any = values
	remaining := path
	for {
		segment, rest, hasMore := strings.Cut(remaining, ".")
		container, isMap := node.(map[string]any)
		if !isMap {
			return nil
		}
		node = container[segment]
		if !hasMore {
			return node
		}
		remaining = rest
	}
}
// mediaPromptText returns the prompt text for a media request.
// It uses the trimmed top-level "prompt" when non-empty. Otherwise it uses
// the first non-empty "text" of a {"type": "text"} entry in the "content"
// list, or "" when neither exists.
func mediaPromptText(body map[string]any) string {
	prompt := strings.TrimSpace(stringFromAny(body["prompt"]))
	if prompt != "" {
		return prompt
	}
	parts, _ := body["content"].([]any)
	for _, entry := range parts {
		part, isObject := entry.(map[string]any)
		if !isObject || strings.TrimSpace(stringFromAny(part["type"])) != "text" {
			continue
		}
		text := strings.TrimSpace(stringFromAny(part["text"]))
		if text != "" {
			return text
		}
	}
	return ""
}