482 lines
18 KiB
Go
482 lines
18 KiB
Go
package clients
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
|
|
)
|
|
|
|
type UniversalClient struct {
|
|
HTTPClient *http.Client
|
|
ScriptExecutor *scriptengine.Executor
|
|
}
|
|
|
|
func (c UniversalClient) Run(ctx context.Context, request Request) (Response, error) {
|
|
executor := c.ScriptExecutor
|
|
if executor == nil {
|
|
executor = &scriptengine.Executor{}
|
|
}
|
|
startedAt := time.Now()
|
|
modelType := strings.TrimSpace(request.ModelType)
|
|
if modelType == "" {
|
|
modelType = strings.TrimSpace(request.Candidate.ModelType)
|
|
}
|
|
|
|
payload := cloneBody(request.Body)
|
|
upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
|
|
submitRequestID := upstreamTaskID
|
|
var submitResult map[string]any
|
|
|
|
if upstreamTaskID == "" {
|
|
var err error
|
|
payload, err = c.universalGetParams(ctx, executor, request, modelType)
|
|
if err != nil {
|
|
return Response{}, err
|
|
}
|
|
submitResult, submitRequestID, err = c.universalSubmit(ctx, executor, request, modelType, payload)
|
|
if err != nil {
|
|
return Response{}, annotateResponseError(err, submitRequestID, startedAt, time.Now())
|
|
}
|
|
if isUniversalSuccess(submitResult) && submitResult["data"] != nil {
|
|
return Response{
|
|
Result: normalizeUniversalResult(request, submitResult, ""),
|
|
RequestID: firstNonEmptyString(submitRequestID, requestIDFromResult(submitResult)),
|
|
Progress: providerProgress(request),
|
|
ResponseStartedAt: startedAt,
|
|
ResponseFinishedAt: time.Now(),
|
|
ResponseDurationMS: responseDurationMS(startedAt, time.Now()),
|
|
}, nil
|
|
}
|
|
if isUniversalFailure(submitResult) {
|
|
return Response{}, universalFailureError(submitResult, firstNonEmptyString(submitRequestID, requestIDFromResult(submitResult)), startedAt)
|
|
}
|
|
upstreamTaskID = universalTaskID(submitResult)
|
|
if upstreamTaskID == "" {
|
|
return Response{}, &ClientError{Code: "invalid_response", Message: "universal task id is missing", RequestID: submitRequestID, Retryable: false}
|
|
}
|
|
if request.OnRemoteTaskSubmitted != nil {
|
|
if err := request.OnRemoteTaskSubmitted(upstreamTaskID, map[string]any{"payload": payload, "submit": submitResult}); err != nil {
|
|
return Response{}, err
|
|
}
|
|
}
|
|
} else if request.RemoteTaskPayload != nil {
|
|
if existingPayload, ok := request.RemoteTaskPayload["payload"].(map[string]any); ok {
|
|
payload = existingPayload
|
|
}
|
|
}
|
|
|
|
result, requestID, err := c.universalPollUntilDone(ctx, executor, request, modelType, upstreamTaskID, payload, firstNonEmptyString(submitRequestID, upstreamTaskID), startedAt)
|
|
if err != nil {
|
|
return Response{}, err
|
|
}
|
|
finishedAt := time.Now()
|
|
return Response{
|
|
Result: normalizeUniversalResult(request, result, upstreamTaskID),
|
|
RequestID: firstNonEmptyString(requestID, submitRequestID, requestIDFromResult(result), upstreamTaskID),
|
|
Progress: universalProgress(request, upstreamTaskID),
|
|
ResponseStartedAt: startedAt,
|
|
ResponseFinishedAt: finishedAt,
|
|
ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
|
|
}, nil
|
|
}
|
|
|
|
func (c UniversalClient) universalGetParams(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string) (map[string]any, error) {
|
|
if scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customGetParamsScript", "custom_get_params_script"); scriptText != "" {
|
|
scriptContext := universalScriptContext(request, modelType, nil)
|
|
out, err := executor.Execute(ctx, scriptengine.Options{
|
|
Script: scriptText,
|
|
Args: []any{cloneBody(request.Body), scriptContext},
|
|
ContextData: scriptContext,
|
|
ScriptName: "custom_get_params_script:" + modelType,
|
|
PreferredEntryNames: []string{"getGenerateParams", "getParams", "main", "handler"},
|
|
Timeout: 30 * time.Second,
|
|
HTTPClient: httpClient(request.HTTPClient, c.HTTPClient),
|
|
})
|
|
if err != nil {
|
|
return nil, universalScriptError(err)
|
|
}
|
|
if params, ok := out.(map[string]any); ok && params != nil {
|
|
if params["_originalParams"] == nil {
|
|
params["_originalParams"] = cloneBody(request.Body)
|
|
}
|
|
return params, nil
|
|
}
|
|
return nil, &ClientError{Code: "invalid_response", Message: "custom get params script must return an object", Retryable: false}
|
|
}
|
|
body := universalDefaultPayload(request)
|
|
body["_originalParams"] = cloneBody(request.Body)
|
|
return body, nil
|
|
}
|
|
|
|
func (c UniversalClient) universalSubmit(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, payload map[string]any) (map[string]any, string, error) {
|
|
if scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customSubmitScript", "custom_submit_script"); scriptText != "" {
|
|
scriptContext := universalScriptContext(request, modelType, payload)
|
|
out, err := executor.Execute(ctx, scriptengine.Options{
|
|
Script: scriptText,
|
|
Args: []any{cloneBody(payload), scriptContext},
|
|
ContextData: scriptContext,
|
|
ScriptName: "custom_submit_script:" + modelType,
|
|
PreferredEntryNames: []string{"submitTask", "submitParams", "submit", "main", "handler"},
|
|
Timeout: 30 * time.Second,
|
|
HTTPClient: httpClient(request.HTTPClient, c.HTTPClient),
|
|
})
|
|
if err != nil {
|
|
return nil, "", universalScriptError(err)
|
|
}
|
|
result, ok := out.(map[string]any)
|
|
if !ok || result == nil {
|
|
return nil, "", &ClientError{Code: "invalid_response", Message: "custom submit script must return an object", Retryable: false}
|
|
}
|
|
return result, requestIDFromResult(result), nil
|
|
}
|
|
endpoint := universalSubmitEndpoint(request)
|
|
result, requestID, err := universalPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), request.Candidate.BaseURL, endpoint, universalStripPrivatePayload(payload), request.Candidate.Credentials)
|
|
return result, requestID, err
|
|
}
|
|
|
|
func (c UniversalClient) universalPollUntilDone(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, upstreamTaskID string, payload map[string]any, requestID string, startedAt time.Time) (map[string]any, string, error) {
|
|
interval := universalDurationConfig(request.Candidate.PlatformConfig, 2*time.Second, "pollIntervalMs", "poll_interval_ms")
|
|
timeout := universalDurationConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms", "timeoutMs")
|
|
deadline := time.NewTimer(timeout)
|
|
defer deadline.Stop()
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
|
|
var lastResult map[string]any
|
|
for {
|
|
pollStarted := time.Now()
|
|
result, pollRequestID, err := c.universalPoll(ctx, executor, request, modelType, upstreamTaskID, payload)
|
|
pollFinished := time.Now()
|
|
if err != nil {
|
|
return nil, "", annotateResponseError(err, firstNonEmptyString(pollRequestID, requestID, upstreamTaskID), pollStarted, pollFinished)
|
|
}
|
|
lastResult = result
|
|
requestID = firstNonEmptyString(pollRequestID, requestID, requestIDFromResult(result), upstreamTaskID)
|
|
if isUniversalSuccess(result) {
|
|
return result, requestID, nil
|
|
}
|
|
if isUniversalFailure(result) {
|
|
return nil, "", universalFailureError(result, requestID, startedAt)
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, "", &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: requestID, Retryable: true}
|
|
case <-deadline.C:
|
|
return nil, "", &ClientError{Code: "timeout", Message: fmt.Sprintf("universal task %s did not finish before timeout; last status: %s", upstreamTaskID, universalStatus(lastResult)), RequestID: requestID, Retryable: true}
|
|
case <-ticker.C:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c UniversalClient) universalPoll(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, upstreamTaskID string, payload map[string]any) (map[string]any, string, error) {
|
|
if scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customPollScript", "custom_poll_script"); scriptText != "" {
|
|
scriptContext := universalScriptContext(request, modelType, payload)
|
|
out, err := executor.Execute(ctx, scriptengine.Options{
|
|
Script: scriptText,
|
|
Args: []any{upstreamTaskID, scriptContext},
|
|
ContextData: scriptContext,
|
|
ScriptName: "custom_poll_script:" + modelType,
|
|
PreferredEntryNames: []string{"pollTask", "poll", "main", "handler"},
|
|
Timeout: 30 * time.Second,
|
|
HTTPClient: httpClient(request.HTTPClient, c.HTTPClient),
|
|
})
|
|
if err != nil {
|
|
return nil, "", universalScriptError(err)
|
|
}
|
|
result, ok := out.(map[string]any)
|
|
if !ok || result == nil {
|
|
return nil, "", &ClientError{Code: "invalid_response", Message: "custom poll script must return an object", Retryable: false}
|
|
}
|
|
return result, requestIDFromResult(result), nil
|
|
}
|
|
pollURL := resolveUniversalTaskURL(request.Candidate.PlatformConfig, upstreamTaskID)
|
|
if pollURL == "" {
|
|
return nil, "", &ClientError{Code: "missing_configuration", Message: "universal getTaskURL is required", Retryable: false}
|
|
}
|
|
return universalGetJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), pollURL, request.Candidate.Credentials)
|
|
}
|
|
|
|
func universalScriptContext(request Request, modelType string, payload map[string]any) map[string]any {
|
|
baseURL := strings.TrimRight(strings.TrimSpace(request.Candidate.BaseURL), "/")
|
|
getTaskURL := universalConfigString(request.Candidate.PlatformConfig, "getTaskURL", "get_task_url")
|
|
context := map[string]any{
|
|
"__easyaiScriptContext": true,
|
|
"baseURL": baseURL,
|
|
"getTaskURL": getTaskURL,
|
|
"authValues": cloneMapAny(request.Candidate.Credentials),
|
|
"headers": map[string]any{},
|
|
"payload": cloneMapAny(payload),
|
|
"type": modelType,
|
|
"options": map[string]any{
|
|
"task_id": request.RemoteTaskID,
|
|
"upstream_task_id": request.RemoteTaskID,
|
|
"model": request.Model,
|
|
"providerModelName": request.Candidate.ProviderModelName,
|
|
"platformId": request.Candidate.PlatformID,
|
|
"platformModelId": request.Candidate.PlatformModelID,
|
|
"canonicalModelKey": request.Candidate.CanonicalModelKey,
|
|
"modelType": modelType,
|
|
"timeout": universalDurationConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms").Milliseconds(),
|
|
},
|
|
"env": cloneMapAny(request.Candidate.PlatformConfig),
|
|
"candidate": universalCandidateSnapshot(request),
|
|
}
|
|
context["createRequestURL"] = func(path string, base ...string) string {
|
|
selectedBase := baseURL
|
|
if len(base) > 0 && strings.TrimSpace(base[0]) != "" {
|
|
selectedBase = strings.TrimRight(strings.TrimSpace(base[0]), "/")
|
|
}
|
|
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
|
return path
|
|
}
|
|
return selectedBase + "/" + strings.TrimLeft(path, "/")
|
|
}
|
|
context["creatRequestURL"] = context["createRequestURL"]
|
|
context["resolveGetTaskURL"] = func(taskID string) string {
|
|
return resolveUniversalTaskURL(request.Candidate.PlatformConfig, taskID)
|
|
}
|
|
return context
|
|
}
|
|
|
|
func universalCandidateSnapshot(request Request) map[string]any {
|
|
return map[string]any{
|
|
"modelName": request.Candidate.ModelName,
|
|
"modelAlias": request.Candidate.ModelAlias,
|
|
"providerModelName": request.Candidate.ProviderModelName,
|
|
"provider": request.Candidate.Provider,
|
|
"platformId": request.Candidate.PlatformID,
|
|
"platformModelId": request.Candidate.PlatformModelID,
|
|
"capabilities": cloneMapAny(request.Candidate.Capabilities),
|
|
}
|
|
}
|
|
|
|
func universalDefaultPayload(request Request) map[string]any {
|
|
body := cloneBody(request.Body)
|
|
body["model"] = upstreamModelName(request.Candidate)
|
|
if request.Kind == "images.generations" {
|
|
if n := firstPresent(body["n"], body["numImages"]); n != nil {
|
|
body["numImages"] = n
|
|
}
|
|
if aspectRatio := strings.TrimSpace(stringFromAny(body["aspect_ratio"])); aspectRatio != "" {
|
|
body["aspectRatio"] = aspectRatio
|
|
}
|
|
}
|
|
return body
|
|
}
|
|
|
|
func universalSubmitEndpoint(request Request) string {
|
|
if endpoint := universalConfigString(request.Candidate.PlatformConfig, "submitPath", "submit_path"); endpoint != "" {
|
|
return endpoint
|
|
}
|
|
switch request.Kind {
|
|
case "images.generations":
|
|
return "/images/generations"
|
|
case "images.edits":
|
|
return "/images/edits"
|
|
case "videos.generations":
|
|
return "/video/generations"
|
|
default:
|
|
return "/" + strings.ReplaceAll(request.Kind, ".", "/")
|
|
}
|
|
}
|
|
|
|
func universalPostJSON(ctx context.Context, client *http.Client, baseURL string, endpoint string, body map[string]any, credentials map[string]any) (map[string]any, string, error) {
|
|
raw, _ := json.Marshal(body)
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, providerURL(baseURL, endpoint), bytes.NewReader(raw))
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if apiKey := credential(credentials, "apiKey", "api_key", "key", "token"); apiKey != "" {
|
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
|
}
|
|
requestID := requestIDFromHTTPResponse(resp)
|
|
result, err := decodeHTTPResponse(resp)
|
|
return result, requestID, err
|
|
}
|
|
|
|
func universalGetJSON(ctx context.Context, client *http.Client, url string, credentials map[string]any) (map[string]any, string, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
if apiKey := credential(credentials, "apiKey", "api_key", "key", "token"); apiKey != "" {
|
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
|
}
|
|
requestID := requestIDFromHTTPResponse(resp)
|
|
result, err := decodeHTTPResponse(resp)
|
|
return result, requestID, err
|
|
}
|
|
|
|
func normalizeUniversalResult(request Request, result map[string]any, upstreamTaskID string) map[string]any {
|
|
out := cloneMapAny(result)
|
|
if out["created"] == nil {
|
|
out["created"] = time.Now().UnixMilli()
|
|
}
|
|
if out["task_id"] == nil {
|
|
out["task_id"] = upstreamTaskID
|
|
}
|
|
if out["upstream_task_id"] == nil {
|
|
out["upstream_task_id"] = upstreamTaskID
|
|
}
|
|
if out["model"] == nil {
|
|
out["model"] = request.Model
|
|
}
|
|
if out["status"] == nil {
|
|
out["status"] = "success"
|
|
}
|
|
return out
|
|
}
|
|
|
|
func universalScriptError(err error) error {
|
|
var scriptErr *scriptengine.Error
|
|
if strings.TrimSpace(err.Error()) == "" {
|
|
return &ClientError{Code: "script_error", Message: "script execution failed", Retryable: false}
|
|
}
|
|
if errors.As(err, &scriptErr) {
|
|
return &ClientError{Code: scriptErr.ErrorCode(), Message: scriptErr.Error(), Retryable: scriptErr.ErrorCode() == "script_timeout"}
|
|
}
|
|
return &ClientError{Code: "script_error", Message: err.Error(), Retryable: false}
|
|
}
|
|
|
|
func universalFailureError(result map[string]any, requestID string, startedAt time.Time) error {
|
|
message := firstNonEmptyString(result["message"], result["error"], result["error_message"], "universal task failed")
|
|
return &ClientError{
|
|
Code: firstNonEmptyString(result["code"], result["error_code"], "provider_failed"),
|
|
Message: message,
|
|
RequestID: requestID,
|
|
ResponseStartedAt: startedAt,
|
|
ResponseFinishedAt: time.Now(),
|
|
ResponseDurationMS: responseDurationMS(startedAt, time.Now()),
|
|
Retryable: false,
|
|
}
|
|
}
|
|
|
|
func isUniversalSuccess(result map[string]any) bool {
|
|
switch universalStatus(result) {
|
|
case "success", "succeeded", "completed", "complete", "done":
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func isUniversalFailure(result map[string]any) bool {
|
|
switch universalStatus(result) {
|
|
case "failed", "failure", "error", "cancelled", "canceled":
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func universalStatus(result map[string]any) string {
|
|
return strings.ToLower(strings.TrimSpace(firstNonEmptyString(result["status"], result["state"], result["task_status"])))
|
|
}
|
|
|
|
func universalTaskID(result map[string]any) string {
|
|
return firstNonEmptyString(result["upstream_task_id"], result["task_id"], result["taskId"], result["id"])
|
|
}
|
|
|
|
func universalProgress(request Request, upstreamTaskID string) []Progress {
|
|
progress := providerProgress(request)
|
|
progress = append(progress, Progress{Phase: "polling", Progress: 0.65, Message: "provider task polled", Payload: map[string]any{"upstreamTaskId": upstreamTaskID}})
|
|
return progress
|
|
}
|
|
|
|
func universalStripPrivatePayload(payload map[string]any) map[string]any {
|
|
out := cloneMapAny(payload)
|
|
for _, key := range []string{"_originalParams", "_resolution", "_duration"} {
|
|
delete(out, key)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func universalSceneScript(config map[string]any, modelType string, keys ...string) string {
|
|
for _, key := range keys {
|
|
value := config[key]
|
|
switch typed := value.(type) {
|
|
case string:
|
|
if strings.TrimSpace(typed) != "" {
|
|
return strings.TrimSpace(typed)
|
|
}
|
|
case map[string]any:
|
|
if script := firstNonEmptyString(typed[modelType], typed["common"]); script != "" {
|
|
return script
|
|
}
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// universalConfigString returns the first config value under keys whose
// fmt.Sprint rendering, after trimming, is non-empty.
//
// Missing and nil values are skipped; they render as "<nil>", so that text is
// treated as empty too. Non-string values are rendered with fmt.Sprint (e.g. 5
// becomes "5"). It returns "" when nothing matches.
func universalConfigString(config map[string]any, keys ...string) string {
	for _, key := range keys {
		text := strings.TrimSpace(fmt.Sprint(config[key]))
		if text == "" || text == "<nil>" {
			continue
		}
		return text
	}
	return ""
}
|
|
|
|
// universalDurationConfig returns the first positive duration found in config
// under keys, or fallback when none is usable.
//
// Value handling:
//   - Numbers (int, int32, int64, float32, float64, json.Number) and plain
//     numeric strings such as "1500" are milliseconds, matching the "...Ms"
//     key names callers use.
//   - Strings with a unit ("2s", "500ms") are parsed with time.ParseDuration.
//   - Zero, negative, NaN and unparsable values are skipped and the next key
//     is tried.
//
// Fractional milliseconds are kept (0.5 gives 500µs) rather than truncated to
// zero before scaling. Oversized values clamp to the largest time.Duration
// instead of overflowing to a negative value. Either failure used to produce a
// non-positive duration, which makes time.NewTicker panic.
func universalDurationConfig(config map[string]any, fallback time.Duration, keys ...string) time.Duration {
	for _, key := range keys {
		if d, ok := universalParseDurationValue(config[key]); ok {
			return d
		}
	}
	return fallback
}

// universalParseDurationValue converts one config value into a positive duration.
func universalParseDurationValue(value any) (time.Duration, bool) {
	switch v := value.(type) {
	case int:
		return universalMillisInt(int64(v))
	case int32:
		return universalMillisInt(int64(v))
	case int64:
		return universalMillisInt(v)
	case float32:
		return universalMillisFloat(float64(v))
	case float64:
		return universalMillisFloat(v)
	case json.Number:
		return universalMillisNumber(string(v))
	case string:
		trimmed := strings.TrimSpace(v)
		if parsed, err := time.ParseDuration(trimmed); err == nil {
			return parsed, parsed > 0
		}
		return universalMillisNumber(trimmed)
	}
	return 0, false
}

// universalMillisNumber parses a decimal millisecond count such as "1500" or
// "0.5". It uses json.Number's integer and float parsers.
func universalMillisNumber(text string) (time.Duration, bool) {
	n := json.Number(text)
	if ms, err := n.Int64(); err == nil {
		return universalMillisInt(ms)
	}
	if ms, err := n.Float64(); err == nil {
		return universalMillisFloat(ms)
	}
	return 0, false
}

// universalMillisInt converts a positive integer millisecond count, clamping
// on overflow.
func universalMillisInt(ms int64) (time.Duration, bool) {
	const maxDuration = time.Duration(1<<63 - 1)
	if ms <= 0 {
		return 0, false
	}
	if ms > int64(maxDuration/time.Millisecond) {
		return maxDuration, true
	}
	return time.Duration(ms) * time.Millisecond, true
}

// universalMillisFloat converts a positive (possibly fractional) millisecond
// count, clamping on overflow.
func universalMillisFloat(ms float64) (time.Duration, bool) {
	const maxDuration = time.Duration(1<<63 - 1)
	if !(ms > 0) { // also rejects NaN
		return 0, false
	}
	if ms >= float64(maxDuration/time.Millisecond) {
		return maxDuration, true
	}
	d := time.Duration(ms * float64(time.Millisecond))
	if d <= 0 { // below one nanosecond
		return 0, false
	}
	return d, true
}
|
|
|
|
func resolveUniversalTaskURL(config map[string]any, upstreamTaskID string) string {
|
|
template := universalConfigString(config, "getTaskURL", "get_task_url")
|
|
out := strings.TrimSpace(template)
|
|
replacements := [][2]string{
|
|
{"${upstream_task_id}", upstreamTaskID},
|
|
{"{{upstream_task_id}}", upstreamTaskID},
|
|
{"{upstream_task_id}", upstreamTaskID},
|
|
{"${task_id}", upstreamTaskID},
|
|
{"{{task_id}}", upstreamTaskID},
|
|
{"{task_id}", upstreamTaskID},
|
|
{"${taskId}", upstreamTaskID},
|
|
{"${taskID}", upstreamTaskID},
|
|
{"{{taskId}}", upstreamTaskID},
|
|
{"{{taskID}}", upstreamTaskID},
|
|
{"{taskId}", upstreamTaskID},
|
|
{"{taskID}", upstreamTaskID},
|
|
}
|
|
for _, replacement := range replacements {
|
|
out = strings.ReplaceAll(out, replacement[0], replacement[1])
|
|
}
|
|
return out
|
|
}
|