454 lines
15 KiB
Go
454 lines
15 KiB
Go
package clients
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type providerTaskSpec struct {
|
|
Name string
|
|
SubmitPath func(Request, map[string]any) string
|
|
PollPath func(Request, string, map[string]any) string
|
|
Auth string
|
|
TaskIDPaths []string
|
|
StatusPaths []string
|
|
SuccessStatuses []string
|
|
FailureStatuses []string
|
|
ProcessStatuses []string
|
|
DefaultSubmitBody func(Request, map[string]any) map[string]any
|
|
}
|
|
|
|
type providerTaskClient struct {
|
|
HTTPClient *http.Client
|
|
Spec providerTaskSpec
|
|
}
|
|
|
|
func (c providerTaskClient) Run(ctx context.Context, request Request) (Response, error) {
|
|
if request.Kind != "images.generations" && request.Kind != "images.edits" && request.Kind != "videos.generations" {
|
|
return Response{}, &ClientError{Code: "unsupported_kind", Message: "unsupported " + c.Spec.Name + " request kind", Retryable: false}
|
|
}
|
|
startedAt := time.Now()
|
|
payload := cloneBody(request.Body)
|
|
if c.Spec.DefaultSubmitBody != nil {
|
|
payload = c.Spec.DefaultSubmitBody(request, payload)
|
|
} else {
|
|
payload["model"] = upstreamModelName(request.Candidate)
|
|
}
|
|
|
|
upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
|
|
requestID := upstreamTaskID
|
|
var submitResult map[string]any
|
|
if upstreamTaskID == "" {
|
|
result, id, err := c.submit(ctx, request, payload)
|
|
if err != nil {
|
|
return Response{}, annotateResponseError(err, id, startedAt, time.Now())
|
|
}
|
|
submitResult = result
|
|
requestID = firstNonEmptyString(id, requestIDFromResult(result))
|
|
if isProviderTaskFailure(c.Spec, result) {
|
|
return Response{}, providerTaskFailure(c.Spec, result, requestID, startedAt)
|
|
}
|
|
if isProviderTaskSuccess(c.Spec, result) && hasProviderTaskResult(result) {
|
|
return Response{
|
|
Result: normalizeProviderTaskResult(request, c.Spec, result, ""),
|
|
RequestID: requestID,
|
|
Progress: providerProgress(request),
|
|
ResponseStartedAt: startedAt,
|
|
ResponseFinishedAt: time.Now(),
|
|
ResponseDurationMS: responseDurationMS(startedAt, time.Now()),
|
|
}, nil
|
|
}
|
|
upstreamTaskID = providerTaskID(c.Spec, result)
|
|
if upstreamTaskID == "" {
|
|
return Response{}, &ClientError{Code: "invalid_response", Message: c.Spec.Name + " task id is missing", RequestID: requestID, Retryable: false}
|
|
}
|
|
if request.OnRemoteTaskSubmitted != nil {
|
|
if err := request.OnRemoteTaskSubmitted(upstreamTaskID, map[string]any{"payload": payload, "submit": submitResult}); err != nil {
|
|
return Response{}, err
|
|
}
|
|
}
|
|
} else if request.RemoteTaskPayload != nil {
|
|
if existingPayload, ok := request.RemoteTaskPayload["payload"].(map[string]any); ok {
|
|
payload = existingPayload
|
|
}
|
|
}
|
|
|
|
interval := providerPollInterval(request)
|
|
timeout := providerPollTimeout(request)
|
|
deadline := time.NewTimer(timeout)
|
|
defer deadline.Stop()
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
|
|
var lastResult map[string]any
|
|
for {
|
|
pollStarted := time.Now()
|
|
result, pollRequestID, err := c.poll(ctx, request, upstreamTaskID, payload)
|
|
pollFinished := time.Now()
|
|
if err != nil {
|
|
return Response{}, annotateResponseError(err, firstNonEmptyString(pollRequestID, requestID, upstreamTaskID), pollStarted, pollFinished)
|
|
}
|
|
lastResult = result
|
|
requestID = firstNonEmptyString(pollRequestID, requestID, requestIDFromResult(result), upstreamTaskID)
|
|
if isProviderTaskSuccess(c.Spec, result) {
|
|
finishedAt := time.Now()
|
|
return Response{
|
|
Result: normalizeProviderTaskResult(request, c.Spec, result, upstreamTaskID),
|
|
RequestID: requestID,
|
|
Progress: append(providerProgress(request), Progress{Phase: "polling", Progress: 0.65, Message: "provider task polled", Payload: map[string]any{"upstreamTaskId": upstreamTaskID}}),
|
|
ResponseStartedAt: startedAt,
|
|
ResponseFinishedAt: finishedAt,
|
|
ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
|
|
}, nil
|
|
}
|
|
if isProviderTaskFailure(c.Spec, result) {
|
|
return Response{}, providerTaskFailure(c.Spec, result, requestID, startedAt)
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return Response{}, &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: requestID, Retryable: true}
|
|
case <-deadline.C:
|
|
return Response{}, &ClientError{Code: "timeout", Message: fmt.Sprintf("%s task %s did not finish before timeout; last status: %s", c.Spec.Name, upstreamTaskID, providerTaskStatus(c.Spec, lastResult)), RequestID: requestID, Retryable: true}
|
|
case <-ticker.C:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c providerTaskClient) submit(ctx context.Context, request Request, payload map[string]any) (map[string]any, string, error) {
|
|
path := c.Spec.SubmitPath(request, payload)
|
|
return providerPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), providerURL(request.Candidate.BaseURL, path), payload, request.Candidate.Credentials, c.Spec.Auth)
|
|
}
|
|
|
|
func (c providerTaskClient) poll(ctx context.Context, request Request, upstreamTaskID string, payload map[string]any) (map[string]any, string, error) {
|
|
path := resolveProviderPathTemplate(c.Spec.PollPath(request, upstreamTaskID, payload), upstreamTaskID)
|
|
url := path
|
|
if !strings.HasPrefix(path, "http://") && !strings.HasPrefix(path, "https://") {
|
|
url = providerURL(request.Candidate.BaseURL, path)
|
|
}
|
|
if c.Spec.Name == "jimeng" {
|
|
body := map[string]any{"task_id": upstreamTaskID, "req_key": upstreamModelName(request.Candidate)}
|
|
return providerPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), url, body, request.Candidate.Credentials, c.Spec.Auth)
|
|
}
|
|
return providerGetJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), url, request.Candidate.Credentials, c.Spec.Auth)
|
|
}
|
|
|
|
func providerPostJSON(ctx context.Context, client *http.Client, url string, body map[string]any, credentials map[string]any, auth string) (map[string]any, string, error) {
|
|
raw, _ := json.Marshal(body)
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(raw))
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
applyProviderAuth(req, credentials, auth)
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
|
}
|
|
requestID := requestIDFromHTTPResponse(resp)
|
|
result, err := decodeHTTPResponse(resp)
|
|
return result, requestID, err
|
|
}
|
|
|
|
func providerGetJSON(ctx context.Context, client *http.Client, url string, credentials map[string]any, auth string) (map[string]any, string, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
applyProviderAuth(req, credentials, auth)
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
|
}
|
|
requestID := requestIDFromHTTPResponse(resp)
|
|
result, err := decodeHTTPResponse(resp)
|
|
return result, requestID, err
|
|
}
|
|
|
|
func applyProviderAuth(req *http.Request, credentials map[string]any, auth string) {
|
|
apiKey := credential(credentials, "apiKey", "api_key", "key", "token")
|
|
switch auth {
|
|
case "token":
|
|
if apiKey != "" {
|
|
req.Header.Set("Authorization", "Token "+apiKey)
|
|
}
|
|
case "x-key":
|
|
if apiKey != "" {
|
|
req.Header.Set("x-key", apiKey)
|
|
}
|
|
case "none":
|
|
default:
|
|
if apiKey != "" {
|
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
}
|
|
}
|
|
}
|
|
|
|
func providerURL(base string, path string) string {
|
|
path = strings.TrimSpace(path)
|
|
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
|
return path
|
|
}
|
|
if path == "" {
|
|
path = "/"
|
|
}
|
|
if !strings.HasPrefix(path, "/") && !strings.HasPrefix(path, "?") {
|
|
path = "/" + path
|
|
}
|
|
return joinURL(base, path)
|
|
}
|
|
|
|
// providerTaskIDPlaceholders lists every accepted task-id placeholder spelling.
// Order matters: replacements run sequentially, so the "${…}" and "{{…}}" forms
// are consumed before the bare "{…}" form they contain.
var providerTaskIDPlaceholders = []string{
	"${upstream_task_id}", "{{upstream_task_id}}", "{upstream_task_id}",
	"${task_id}", "{{task_id}}", "{task_id}",
	"${taskId}", "${taskID}",
	"{{taskId}}", "{{taskID}}",
	"{taskId}", "{taskID}",
}

// resolveProviderPathTemplate substitutes upstreamTaskID for every task-id
// placeholder in path and returns the result.
func resolveProviderPathTemplate(path string, upstreamTaskID string) string {
	resolved := path
	for _, placeholder := range providerTaskIDPlaceholders {
		resolved = strings.ReplaceAll(resolved, placeholder, upstreamTaskID)
	}
	return resolved
}
|
|
|
|
func providerTaskID(spec providerTaskSpec, result map[string]any) string {
|
|
paths := append([]string{}, spec.TaskIDPaths...)
|
|
paths = append(paths, "task_id", "taskId", "id", "job_id", "Response.JobId", "output.task_id", "data.task_id", "polling_url")
|
|
for _, path := range paths {
|
|
if value := stringFromPathValue(valueAtPath(result, path)); value != "" {
|
|
return value
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func providerTaskStatus(spec providerTaskSpec, result map[string]any) string {
|
|
if result == nil {
|
|
return ""
|
|
}
|
|
if value, ok := valueAtPath(result, "status").(float64); ok {
|
|
if value == 2 {
|
|
return "success"
|
|
}
|
|
if value == 3 {
|
|
return "failed"
|
|
}
|
|
return "process"
|
|
}
|
|
paths := append([]string{}, spec.StatusPaths...)
|
|
paths = append(paths, "status", "state", "task_status", "output.task_status", "Response.Status", "data.status")
|
|
for _, path := range paths {
|
|
if value := stringFromPathValue(valueAtPath(result, path)); value != "" {
|
|
return strings.ToLower(value)
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// stringFromPathValue renders a decoded JSON value as trimmed text, returning
// "" for nil, blank, or "<nil>" renderings.
//
// Integral float64 values (how encoding/json decodes numbers into any) are
// printed as plain integers. fmt.Sprint would render e.g. 12345678 as
// "1.2345678e+07", corrupting numeric task ids, codes and statuses. Other
// values keep the fmt.Sprint rendering.
func stringFromPathValue(value any) string {
	if value == nil {
		return ""
	}
	var text string
	// The ±1e18 bound keeps the int64 conversion well-defined; NaN fails it.
	if number, isFloat := value.(float64); isFloat && number > -1e18 && number < 1e18 && number == float64(int64(number)) {
		text = fmt.Sprintf("%d", int64(number))
	} else {
		text = strings.TrimSpace(fmt.Sprint(value))
	}
	if text == "" || text == "<nil>" {
		return ""
	}
	return text
}
|
|
|
|
func isProviderTaskSuccess(spec providerTaskSpec, result map[string]any) bool {
|
|
return containsStatus(append([]string{"success", "succeeded", "completed", "complete", "done", "ready", "succeed", "succeeded", "suceeded", "done", "done"}, spec.SuccessStatuses...), providerTaskStatus(spec, result))
|
|
}
|
|
|
|
func isProviderTaskFailure(spec providerTaskSpec, result map[string]any) bool {
|
|
return containsStatus(append([]string{"failed", "failure", "error", "cancelled", "canceled", "fail", "expired", "task not found"}, spec.FailureStatuses...), providerTaskStatus(spec, result))
|
|
}
|
|
|
|
// containsStatus reports whether status matches any entry of values after
// trimming whitespace and lower-casing both sides.
func containsStatus(values []string, status string) bool {
	normalize := func(s string) string { return strings.ToLower(strings.TrimSpace(s)) }
	want := normalize(status)
	for i := range values {
		if normalize(values[i]) == want {
			return true
		}
	}
	return false
}
|
|
|
|
func hasProviderTaskResult(result map[string]any) bool {
|
|
return result["data"] != nil || valueAtPath(result, "output.image_urls") != nil || valueAtPath(result, "output.video_url") != nil || valueAtPath(result, "Response.ResultVideoUrl") != nil || valueAtPath(result, "Response.ResultImages") != nil || result["urls"] != nil
|
|
}
|
|
|
|
func normalizeProviderTaskResult(request Request, spec providerTaskSpec, result map[string]any, upstreamTaskID string) map[string]any {
|
|
out := cloneMapAny(result)
|
|
out["status"] = "success"
|
|
if upstreamTaskID != "" {
|
|
out["upstream_task_id"] = upstreamTaskID
|
|
}
|
|
if out["created"] == nil {
|
|
out["created"] = time.Now().UnixMilli()
|
|
}
|
|
if out["model"] == nil {
|
|
out["model"] = request.Model
|
|
}
|
|
if _, ok := out["data"].([]any); !ok {
|
|
if out["data"] != nil {
|
|
out["raw_data"] = out["data"]
|
|
}
|
|
out["data"] = providerTaskData(request, result)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func providerTaskData(request Request, result map[string]any) []any {
|
|
fileType := "image"
|
|
if request.Kind == "videos.generations" || strings.Contains(request.ModelType, "video") {
|
|
fileType = "video"
|
|
}
|
|
urlValues := []any{}
|
|
for _, path := range []string{
|
|
"urls",
|
|
"image_urls",
|
|
"data.image_urls",
|
|
"data.images",
|
|
"output.image_urls",
|
|
"output.video_url",
|
|
"output.output",
|
|
"data.output",
|
|
"data.video_url",
|
|
"video_url",
|
|
"preview_url",
|
|
"Response.ResultImages",
|
|
"Response.ResultVideoUrl",
|
|
} {
|
|
appendURLValues(&urlValues, valueAtPath(result, path))
|
|
}
|
|
data := make([]any, 0, len(urlValues))
|
|
for _, raw := range urlValues {
|
|
if url := strings.TrimSpace(fmt.Sprint(raw)); url != "" {
|
|
data = append(data, map[string]any{"type": fileType, "url": url})
|
|
}
|
|
}
|
|
if len(data) == 0 {
|
|
if base64Values := valueAtPath(result, "data.binary_data_base64"); base64Values != nil {
|
|
values := []any{}
|
|
appendURLValues(&values, base64Values)
|
|
for _, raw := range values {
|
|
if content := strings.TrimSpace(fmt.Sprint(raw)); content != "" {
|
|
data = append(data, map[string]any{"type": fileType, "content": content, "uploaded": false})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return data
|
|
}
|
|
|
|
// appendURLValues flattens a URL-bearing JSON value into *out.
//
// Strings are appended as-is (providerTaskData trims and drops empties later),
// []any values are walked recursively, and []string items are appended. For an
// object, the first populated key among url, image_url, imageUrl, video_url,
// videoUrl, content and output is used (see appendObjectURLField). Other
// top-level types are ignored.
func appendURLValues(out *[]any, value any) {
	switch typed := value.(type) {
	case string:
		*out = append(*out, typed)
	case []any:
		for _, item := range typed {
			appendURLValues(out, item)
		}
	case []string:
		for _, item := range typed {
			*out = append(*out, item)
		}
	case map[string]any:
		for _, key := range []string{"url", "image_url", "imageUrl", "video_url", "videoUrl", "content", "output"} {
			if appendObjectURLField(out, typed[key]) {
				return
			}
		}
	}
}

// appendObjectURLField appends the URL(s) held by one object field. It reports
// whether anything was added so the caller can stop at the first populated key.
// Nested lists/objects are walked recursively instead of being stringified by
// fmt.Sprint (which produced values like "[a b]" or "map[url:...]").
func appendObjectURLField(out *[]any, field any) bool {
	before := len(*out)
	switch typed := field.(type) {
	case nil:
	case string:
		if text := strings.TrimSpace(typed); text != "" {
			*out = append(*out, text)
		}
	case []any, []string, map[string]any:
		appendURLValues(out, typed)
	default:
		// Scalars (numbers, bools) keep their fmt.Sprint rendering.
		if text := strings.TrimSpace(fmt.Sprint(typed)); text != "" && text != "<nil>" {
			*out = append(*out, text)
		}
	}
	return len(*out) > before
}
|
|
|
|
func providerTaskFailure(spec providerTaskSpec, result map[string]any, requestID string, startedAt time.Time) error {
|
|
message := firstNonEmptyString(valueAtPath(result, "message"), valueAtPath(result, "error.message"), valueAtPath(result, "error"), valueAtPath(result, "Response.ErrorMessage"), valueAtPath(result, "comment"), spec.Name+" task failed")
|
|
return &ClientError{
|
|
Code: firstNonEmptyString(valueAtPath(result, "code"), valueAtPath(result, "error_code"), valueAtPath(result, "Response.ErrorCode"), "provider_failed"),
|
|
Message: message,
|
|
RequestID: requestID,
|
|
ResponseStartedAt: startedAt,
|
|
ResponseFinishedAt: time.Now(),
|
|
ResponseDurationMS: responseDurationMS(startedAt, time.Now()),
|
|
Retryable: false,
|
|
}
|
|
}
|
|
|
|
func providerPollInterval(request Request) time.Duration {
|
|
return durationFromConfig(request.Candidate.PlatformConfig, 2*time.Second, "pollIntervalMs", "poll_interval_ms")
|
|
}
|
|
|
|
func providerPollTimeout(request Request) time.Duration {
|
|
return durationFromConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms", "timeoutMs")
|
|
}
|
|
|
|
// durationFromConfig returns the first strictly positive duration found under
// keys in config, or fallback if none qualifies.
//
// Numeric values (int, int64, float64) and bare numeric strings or json.Number
// values are milliseconds, matching the "...Ms" key names. Other strings are
// parsed with time.ParseDuration (e.g. "2s"). Fractional milliseconds are kept
// rather than truncated. Values that would round to zero or overflow are
// skipped, so callers such as time.NewTicker never receive a non-positive
// duration.
func durationFromConfig(config map[string]any, fallback time.Duration, keys ...string) time.Duration {
	for _, key := range keys {
		if d, ok := durationFromConfigValue(config[key]); ok {
			return d
		}
	}
	return fallback
}

// durationFromConfigValue converts one config value; ok is false when the value
// is missing, unsupported, or not strictly positive.
func durationFromConfigValue(value any) (time.Duration, bool) {
	switch v := value.(type) {
	case int:
		return durationFromMilliseconds(float64(v))
	case int64:
		return durationFromMilliseconds(float64(v))
	case float64:
		return durationFromMilliseconds(v)
	case json.Number:
		return durationFromConfigString(string(v))
	case string:
		return durationFromConfigString(v)
	}
	return 0, false
}

// durationFromConfigString accepts Go duration syntax ("1.5s") or a bare number
// of milliseconds ("5000").
func durationFromConfigString(text string) (time.Duration, bool) {
	text = strings.TrimSpace(text)
	if d, err := time.ParseDuration(text); err == nil {
		return d, d > 0
	}
	if ms, err := json.Number(text).Float64(); err == nil {
		return durationFromMilliseconds(ms)
	}
	return 0, false
}

// durationFromMilliseconds converts ms to a Duration, rejecting NaN,
// non-positive, overflowing, and sub-nanosecond values.
func durationFromMilliseconds(ms float64) (time.Duration, bool) {
	const maxMillis = float64((1<<63 - 1) / int64(time.Millisecond))
	if !(ms > 0) || ms >= maxMillis {
		return 0, false
	}
	d := time.Duration(ms * float64(time.Millisecond))
	return d, d > 0
}
|
|
|
|
// valueAtPath walks a dot-separated path (e.g. "output.task_id") through nested
// map[string]any values. It returns nil when values is nil, the path is
// blank, or any intermediate node is not a map.
func valueAtPath(values map[string]any, path string) any {
	if values == nil || strings.TrimSpace(path) == "" {
		return nil
	}
	var node any = values
	rest := path
	for {
		key, tail, more := strings.Cut(rest, ".")
		object, isMap := node.(map[string]any)
		if !isMap {
			return nil
		}
		node = object[key]
		if !more {
			return node
		}
		rest = tail
	}
}
|
|
|
|
func mediaPromptText(body map[string]any) string {
|
|
if prompt := strings.TrimSpace(stringFromAny(body["prompt"])); prompt != "" {
|
|
return prompt
|
|
}
|
|
content, _ := body["content"].([]any)
|
|
for _, item := range content {
|
|
if part, ok := item.(map[string]any); ok && strings.TrimSpace(stringFromAny(part["type"])) == "text" {
|
|
if text := strings.TrimSpace(stringFromAny(part["text"])); text != "" {
|
|
return text
|
|
}
|
|
}
|
|
}
|
|
return ""
|
|
}
|