easyai-ai-gateway/apps/api/internal/clients/universal.go

482 lines
18 KiB
Go

package clients
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
)
// UniversalClient drives a generic submit-then-poll upstream flow whose
// parameter building, submission and polling can each be overridden by
// scripts stored in the candidate's platform configuration.
type UniversalClient struct {
	// HTTPClient is handed to httpClient together with the per-request
	// client (request.HTTPClient); which one is used is decided by that helper.
	HTTPClient *http.Client
	// ScriptExecutor runs the custom scripts. When nil, Run substitutes a
	// zero-value scriptengine.Executor.
	ScriptExecutor *scriptengine.Executor
}
// Run executes a universal (submit-then-poll) request against the candidate
// platform.
//
// For a fresh request (request.RemoteTaskID empty) it builds the upstream
// payload, submits it, and then does one of two things:
//   - If the submit response already reports success and carries "data", it
//     returns that response directly.
//   - Otherwise it extracts the upstream task id, reports it through
//     request.OnRemoteTaskSubmitted and starts polling.
//
// When request.RemoteTaskID is set, submission is skipped. Polling resumes
// with the map stored under request.RemoteTaskPayload["payload"], if present.
func (c UniversalClient) Run(ctx context.Context, request Request) (Response, error) {
	executor := c.ScriptExecutor
	if executor == nil {
		executor = &scriptengine.Executor{}
	}
	startedAt := time.Now()
	modelType := strings.TrimSpace(request.ModelType)
	if modelType == "" {
		modelType = strings.TrimSpace(request.Candidate.ModelType)
	}
	payload := cloneBody(request.Body)
	upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
	submitRequestID := upstreamTaskID
	if upstreamTaskID == "" {
		var (
			submitResult map[string]any
			err          error
		)
		payload, err = c.universalGetParams(ctx, executor, request, modelType)
		if err != nil {
			return Response{}, err
		}
		submitResult, submitRequestID, err = c.universalSubmit(ctx, executor, request, modelType, payload)
		if err != nil {
			return Response{}, annotateResponseError(err, submitRequestID, startedAt, time.Now())
		}
		resolvedRequestID := firstNonEmptyString(submitRequestID, requestIDFromResult(submitResult))
		if isUniversalSuccess(submitResult) && submitResult["data"] != nil {
			// The provider answered synchronously; there is nothing to poll.
			// Capture "now" once so the finished time and duration agree.
			finishedAt := time.Now()
			return Response{
				Result:             normalizeUniversalResult(request, submitResult, ""),
				RequestID:          resolvedRequestID,
				Progress:           providerProgress(request),
				ResponseStartedAt:  startedAt,
				ResponseFinishedAt: finishedAt,
				ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
			}, nil
		}
		if isUniversalFailure(submitResult) {
			return Response{}, universalFailureError(submitResult, resolvedRequestID, startedAt)
		}
		upstreamTaskID = universalTaskID(submitResult)
		if upstreamTaskID == "" {
			finishedAt := time.Now()
			return Response{}, &ClientError{
				Code:               "invalid_response",
				Message:            "universal task id is missing",
				RequestID:          resolvedRequestID,
				ResponseStartedAt:  startedAt,
				ResponseFinishedAt: finishedAt,
				ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
				Retryable:          false,
			}
		}
		if request.OnRemoteTaskSubmitted != nil {
			if err := request.OnRemoteTaskSubmitted(upstreamTaskID, map[string]any{"payload": payload, "submit": submitResult}); err != nil {
				return Response{}, err
			}
		}
	} else if existingPayload, ok := request.RemoteTaskPayload["payload"].(map[string]any); ok {
		// Resuming a known task: reuse the stored payload (presumably the one
		// recorded via OnRemoteTaskSubmitted) so poll scripts see the same data.
		// Indexing a nil map is safe, so no separate nil check is needed.
		payload = existingPayload
	}
	result, requestID, err := c.universalPollUntilDone(ctx, executor, request, modelType, upstreamTaskID, payload, firstNonEmptyString(submitRequestID, upstreamTaskID), startedAt)
	if err != nil {
		return Response{}, err
	}
	finishedAt := time.Now()
	return Response{
		Result:             normalizeUniversalResult(request, result, upstreamTaskID),
		RequestID:          firstNonEmptyString(requestID, submitRequestID, requestIDFromResult(result), upstreamTaskID),
		Progress:           universalProgress(request, upstreamTaskID),
		ResponseStartedAt:  startedAt,
		ResponseFinishedAt: finishedAt,
		ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
	}, nil
}
// universalGetParams builds the payload that will be submitted upstream.
//
// Without a customGetParamsScript / custom_get_params_script for modelType,
// the default payload is used. Otherwise the script receives a copy of the
// request body plus the script context and must return an object. In both
// cases a copy of the untouched request body is stored under
// "_originalParams", unless a script already set that key itself.
func (c UniversalClient) universalGetParams(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string) (map[string]any, error) {
	scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customGetParamsScript", "custom_get_params_script")
	if scriptText == "" {
		defaults := universalDefaultPayload(request)
		defaults["_originalParams"] = cloneBody(request.Body)
		return defaults, nil
	}

	scriptCtx := universalScriptContext(request, modelType, nil)
	out, err := executor.Execute(ctx, scriptengine.Options{
		Script:              scriptText,
		Args:                []any{cloneBody(request.Body), scriptCtx},
		ContextData:         scriptCtx,
		ScriptName:          "custom_get_params_script:" + modelType,
		PreferredEntryNames: []string{"getGenerateParams", "getParams", "main", "handler"},
		Timeout:             30 * time.Second,
		HTTPClient:          httpClient(request.HTTPClient, c.HTTPClient),
	})
	if err != nil {
		return nil, universalScriptError(err)
	}

	params, isObject := out.(map[string]any)
	if !isObject || params == nil {
		return nil, &ClientError{Code: "invalid_response", Message: "custom get params script must return an object", Retryable: false}
	}
	if params["_originalParams"] == nil {
		params["_originalParams"] = cloneBody(request.Body)
	}
	return params, nil
}
// universalSubmit sends payload upstream and returns the decoded response plus
// the upstream request id.
//
// A customSubmitScript / custom_submit_script takes precedence. It must
// return an object, and the request id is read from that object. Without a
// script, the payload (minus the keys removed by
// universalStripPrivatePayload) is POSTed as JSON to the submit endpoint.
func (c UniversalClient) universalSubmit(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, payload map[string]any) (map[string]any, string, error) {
	scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customSubmitScript", "custom_submit_script")
	if scriptText == "" {
		endpoint := universalSubmitEndpoint(request)
		return universalPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), request.Candidate.BaseURL, endpoint, universalStripPrivatePayload(payload), request.Candidate.Credentials)
	}

	scriptCtx := universalScriptContext(request, modelType, payload)
	out, execErr := executor.Execute(ctx, scriptengine.Options{
		Script:              scriptText,
		Args:                []any{cloneBody(payload), scriptCtx},
		ContextData:         scriptCtx,
		ScriptName:          "custom_submit_script:" + modelType,
		PreferredEntryNames: []string{"submitTask", "submitParams", "submit", "main", "handler"},
		Timeout:             30 * time.Second,
		HTTPClient:          httpClient(request.HTTPClient, c.HTTPClient),
	})
	if execErr != nil {
		return nil, "", universalScriptError(execErr)
	}

	submitted, isObject := out.(map[string]any)
	if isObject && submitted != nil {
		return submitted, requestIDFromResult(submitted), nil
	}
	return nil, "", &ClientError{Code: "invalid_response", Message: "custom submit script must return an object", Retryable: false}
}
// universalPollUntilDone polls upstreamTaskID until it reports success or
// failure, the poll timeout elapses, or ctx ends.
//
// Poll settings come from the platform config:
//   - Interval: pollIntervalMs / poll_interval_ms, default 2s.
//   - Timeout: pollTimeoutMs / poll_timeout_ms / timeoutMs, default 10m.
//
// The timeout bounds each in-flight poll as well as the wait between polls,
// so a hung upstream call cannot run past it.
//
// requestID is the best request id known so far and is refined from every
// poll response. On success the final poll result and request id are
// returned. Otherwise a *ClientError is returned: "cancelled" when ctx ended,
// "timeout" when the poll timeout elapsed, or the provider failure / poll
// error.
func (c UniversalClient) universalPollUntilDone(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, upstreamTaskID string, payload map[string]any, requestID string, startedAt time.Time) (map[string]any, string, error) {
	interval := universalDurationConfig(request.Candidate.PlatformConfig, 2*time.Second, "pollIntervalMs", "poll_interval_ms")
	timeout := universalDurationConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms", "timeoutMs")
	// pollCtx carries the poll timeout into every upstream call, not just the
	// sleep between calls.
	pollCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	var lastResult map[string]any
	// stopped turns the end of pollCtx into the caller-facing error. If the
	// caller's ctx ended, report "cancelled"; otherwise our own poll timeout
	// was reached.
	stopped := func() error {
		if err := ctx.Err(); err != nil {
			return &ClientError{Code: "cancelled", Message: err.Error(), RequestID: requestID, Retryable: true}
		}
		return &ClientError{Code: "timeout", Message: fmt.Sprintf("universal task %s did not finish before timeout; last status: %s", upstreamTaskID, universalStatus(lastResult)), RequestID: requestID, Retryable: true}
	}
	for {
		pollStarted := time.Now()
		result, pollRequestID, err := c.universalPoll(pollCtx, executor, request, modelType, upstreamTaskID, payload)
		pollFinished := time.Now()
		if err != nil {
			if pollCtx.Err() != nil {
				// pollCtx ended while the poll was in flight. Report the
				// cancellation/timeout instead of the transport error it caused.
				return nil, "", stopped()
			}
			return nil, "", annotateResponseError(err, firstNonEmptyString(pollRequestID, requestID, upstreamTaskID), pollStarted, pollFinished)
		}
		lastResult = result
		requestID = firstNonEmptyString(pollRequestID, requestID, requestIDFromResult(result), upstreamTaskID)
		if isUniversalSuccess(result) {
			return result, requestID, nil
		}
		if isUniversalFailure(result) {
			return nil, "", universalFailureError(result, requestID, startedAt)
		}
		select {
		case <-pollCtx.Done():
			return nil, "", stopped()
		case <-ticker.C:
		}
	}
}
// universalPoll performs a single status check for upstreamTaskID.
//
// With no customPollScript / custom_poll_script configured, the getTaskURL
// template is expanded with the task id and fetched with GET. A missing
// template is reported as a configuration error. Otherwise the script is
// called with the task id and the script context and must return an object.
func (c UniversalClient) universalPoll(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, upstreamTaskID string, payload map[string]any) (map[string]any, string, error) {
	scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customPollScript", "custom_poll_script")
	if scriptText == "" {
		pollURL := resolveUniversalTaskURL(request.Candidate.PlatformConfig, upstreamTaskID)
		if pollURL != "" {
			return universalGetJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), pollURL, request.Candidate.Credentials)
		}
		return nil, "", &ClientError{Code: "missing_configuration", Message: "universal getTaskURL is required", Retryable: false}
	}

	scriptCtx := universalScriptContext(request, modelType, payload)
	out, execErr := executor.Execute(ctx, scriptengine.Options{
		Script:              scriptText,
		Args:                []any{upstreamTaskID, scriptCtx},
		ContextData:         scriptCtx,
		ScriptName:          "custom_poll_script:" + modelType,
		PreferredEntryNames: []string{"pollTask", "poll", "main", "handler"},
		Timeout:             30 * time.Second,
		HTTPClient:          httpClient(request.HTTPClient, c.HTTPClient),
	})
	if execErr != nil {
		return nil, "", universalScriptError(execErr)
	}

	polled, isObject := out.(map[string]any)
	if isObject && polled != nil {
		return polled, requestIDFromResult(polled), nil
	}
	return nil, "", &ClientError{Code: "invalid_response", Message: "custom poll script must return an object", Retryable: false}
}
// universalScriptContext builds the context object handed to custom scripts.
//
// The context exposes:
//   - the trimmed base URL and the raw getTaskURL template;
//   - cloneMapAny copies of the credentials ("authValues"), the payload and
//     the platform config ("env");
//   - request metadata under "options" and a candidate snapshot;
//   - helper functions: createRequestURL (also exposed as "creatRequestURL")
//     and resolveGetTaskURL.
//
// options.timeout is the poll timeout in milliseconds. It is read with the
// same keys and default as universalPollUntilDone, so scripts see the
// timeout that is actually enforced.
func universalScriptContext(request Request, modelType string, payload map[string]any) map[string]any {
	config := request.Candidate.PlatformConfig
	baseURL := strings.TrimRight(strings.TrimSpace(request.Candidate.BaseURL), "/")
	getTaskURL := universalConfigString(config, "getTaskURL", "get_task_url")
	pollTimeout := universalDurationConfig(config, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms", "timeoutMs")

	// createRequestURL joins path onto base[0] (when non-blank) or baseURL.
	// Absolute http(s) URLs are returned unchanged.
	createRequestURL := func(path string, base ...string) string {
		selectedBase := baseURL
		if len(base) > 0 && strings.TrimSpace(base[0]) != "" {
			selectedBase = strings.TrimRight(strings.TrimSpace(base[0]), "/")
		}
		if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
			return path
		}
		return selectedBase + "/" + strings.TrimLeft(path, "/")
	}

	// Named scriptCtx rather than "context" so the context package is not shadowed.
	scriptCtx := map[string]any{
		"__easyaiScriptContext": true,
		"baseURL":               baseURL,
		"getTaskURL":            getTaskURL,
		"authValues":            cloneMapAny(request.Candidate.Credentials),
		"headers":               map[string]any{},
		"payload":               cloneMapAny(payload),
		"type":                  modelType,
		"options": map[string]any{
			"task_id":           request.RemoteTaskID,
			"upstream_task_id":  request.RemoteTaskID,
			"model":             request.Model,
			"providerModelName": request.Candidate.ProviderModelName,
			"platformId":        request.Candidate.PlatformID,
			"platformModelId":   request.Candidate.PlatformModelID,
			"canonicalModelKey": request.Candidate.CanonicalModelKey,
			"modelType":         modelType,
			"timeout":           pollTimeout.Milliseconds(),
		},
		"env":              cloneMapAny(config),
		"candidate":        universalCandidateSnapshot(request),
		"createRequestURL": createRequestURL,
		// Misspelled alias, presumably kept for scripts that use this name.
		"creatRequestURL": createRequestURL,
		"resolveGetTaskURL": func(taskID string) string {
			return resolveUniversalTaskURL(config, taskID)
		},
	}
	return scriptCtx
}
// universalCandidateSnapshot returns the subset of candidate metadata exposed
// to scripts. Capabilities are passed through cloneMapAny rather than handed
// over directly.
func universalCandidateSnapshot(request Request) map[string]any {
	candidate := request.Candidate
	snapshot := make(map[string]any, 7)
	snapshot["modelName"] = candidate.ModelName
	snapshot["modelAlias"] = candidate.ModelAlias
	snapshot["providerModelName"] = candidate.ProviderModelName
	snapshot["provider"] = candidate.Provider
	snapshot["platformId"] = candidate.PlatformID
	snapshot["platformModelId"] = candidate.PlatformModelID
	snapshot["capabilities"] = cloneMapAny(candidate.Capabilities)
	return snapshot
}
// universalDefaultPayload builds the payload used when no custom get-params
// script is configured. It starts from a copy of the request body and sets
// "model" to the upstream model name.
//
// For images.generations it also:
//   - copies "n" (or an existing "numImages") into "numImages";
//   - copies a non-blank, trimmed "aspect_ratio" into "aspectRatio".
func universalDefaultPayload(request Request) map[string]any {
	payload := cloneBody(request.Body)
	payload["model"] = upstreamModelName(request.Candidate)
	if request.Kind != "images.generations" {
		return payload
	}
	if count := firstPresent(payload["n"], payload["numImages"]); count != nil {
		payload["numImages"] = count
	}
	aspectRatio := strings.TrimSpace(stringFromAny(payload["aspect_ratio"]))
	if aspectRatio != "" {
		payload["aspectRatio"] = aspectRatio
	}
	return payload
}
// universalSubmitEndpoint returns the path that submissions are POSTed to.
// A configured submitPath / submit_path wins. Otherwise the path is derived
// from request.Kind by turning dots into slashes (images.generations ->
// /images/generations, images.edits -> /images/edits). The one exception is
// videos.generations, which maps to the singular /video/generations.
func universalSubmitEndpoint(request Request) string {
	configured := universalConfigString(request.Candidate.PlatformConfig, "submitPath", "submit_path")
	if configured != "" {
		return configured
	}
	if request.Kind == "videos.generations" {
		return "/video/generations"
	}
	return "/" + strings.ReplaceAll(request.Kind, ".", "/")
}
// universalPostJSON POSTs body as JSON to providerURL(baseURL, endpoint) and
// decodes the response with decodeHTTPResponse. A credential found under
// apiKey, api_key, key or token is sent as a Bearer token.
//
// The returned request id comes from the HTTP response. Transport failures
// become retryable "network" ClientErrors. A body that cannot be JSON-encoded
// is rejected before any request is sent.
func universalPostJSON(ctx context.Context, client *http.Client, baseURL string, endpoint string, body map[string]any, credentials map[string]any) (map[string]any, string, error) {
	raw, err := json.Marshal(body)
	if err != nil {
		// Payloads can come from scripts and may hold values JSON cannot
		// encode (functions, NaN, ...). Fail instead of POSTing an empty body.
		return nil, "", &ClientError{Code: "invalid_request", Message: fmt.Sprintf("encode universal payload: %v", err), Retryable: false}
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, providerURL(baseURL, endpoint), bytes.NewReader(raw))
	if err != nil {
		return nil, "", err
	}
	req.Header.Set("Content-Type", "application/json")
	if apiKey := credential(credentials, "apiKey", "api_key", "key", "token"); apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
	}
	requestID := requestIDFromHTTPResponse(resp)
	result, err := decodeHTTPResponse(resp)
	return result, requestID, err
}
// universalGetJSON issues a GET to url and decodes the response with
// decodeHTTPResponse. A credential found under apiKey, api_key, key or token
// is sent as a Bearer token. Transport failures become retryable "network"
// ClientErrors, and the request id is taken from the HTTP response.
func universalGetJSON(ctx context.Context, client *http.Client, url string, credentials map[string]any) (map[string]any, string, error) {
	req, buildErr := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if buildErr != nil {
		return nil, "", buildErr
	}
	apiKey := credential(credentials, "apiKey", "api_key", "key", "token")
	if apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
	resp, doErr := client.Do(req)
	if doErr != nil {
		return nil, "", &ClientError{Code: "network", Message: doErr.Error(), Retryable: true}
	}
	upstreamRequestID := requestIDFromHTTPResponse(resp)
	decoded, decodeErr := decodeHTTPResponse(resp)
	return decoded, upstreamRequestID, decodeErr
}
// normalizeUniversalResult returns a cloneMapAny copy of result in which any
// missing (nil) field gets a default:
//   - "created": current Unix milliseconds
//   - "task_id" and "upstream_task_id": upstreamTaskID
//   - "model": request.Model
//   - "status": "success"
func normalizeUniversalResult(request Request, result map[string]any, upstreamTaskID string) map[string]any {
	out := cloneMapAny(result)
	fallbacks := [...]struct {
		key   string
		value any
	}{
		{"created", time.Now().UnixMilli()},
		{"task_id", upstreamTaskID},
		{"upstream_task_id", upstreamTaskID},
		{"model", request.Model},
		{"status", "success"},
	}
	for _, fallback := range fallbacks {
		if out[fallback.key] == nil {
			out[fallback.key] = fallback.value
		}
	}
	return out
}
// universalScriptError converts a script execution failure into a ClientError.
//
// Typed *scriptengine.Error values keep their error code, and only
// "script_timeout" is marked retryable. Any other error becomes a
// non-retryable "script_error".
//
// The typed check runs first, so a blank message can no longer hide the code.
// A blank message is replaced with "script execution failed". A nil err
// yields nil.
func universalScriptError(err error) error {
	if err == nil {
		return nil
	}
	var scriptErr *scriptengine.Error
	if errors.As(err, &scriptErr) {
		code := scriptErr.ErrorCode()
		message := scriptErr.Error()
		if strings.TrimSpace(message) == "" {
			message = "script execution failed"
		}
		return &ClientError{Code: code, Message: message, Retryable: code == "script_timeout"}
	}
	if strings.TrimSpace(err.Error()) == "" {
		return &ClientError{Code: "script_error", Message: "script execution failed", Retryable: false}
	}
	return &ClientError{Code: "script_error", Message: err.Error(), Retryable: false}
}
// universalFailureError builds the non-retryable ClientError returned when the
// upstream task reports a failure status.
//
// The code comes from "code" or "error_code" (default "provider_failed").
// The message comes from "message", "error" or "error_message" (default
// "universal task failed"). The clock is read once, so ResponseFinishedAt and
// ResponseDurationMS describe the same instant.
func universalFailureError(result map[string]any, requestID string, startedAt time.Time) error {
	finishedAt := time.Now()
	return &ClientError{
		Code:               firstNonEmptyString(result["code"], result["error_code"], "provider_failed"),
		Message:            firstNonEmptyString(result["message"], result["error"], result["error_message"], "universal task failed"),
		RequestID:          requestID,
		ResponseStartedAt:  startedAt,
		ResponseFinishedAt: finishedAt,
		ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
		Retryable:          false,
	}
}
// isUniversalSuccess reports whether the normalized task status is one of the
// recognized terminal-success values.
func isUniversalSuccess(result map[string]any) bool {
	status := universalStatus(result)
	return status == "success" ||
		status == "succeeded" ||
		status == "completed" ||
		status == "complete" ||
		status == "done"
}
// isUniversalFailure reports whether the normalized task status is one of the
// recognized terminal-failure values (cancellation counts as failure).
func isUniversalFailure(result map[string]any) bool {
	status := universalStatus(result)
	return status == "failed" ||
		status == "failure" ||
		status == "error" ||
		status == "cancelled" ||
		status == "canceled"
}
// universalStatus returns the task status in trimmed, lower-case form. It is
// taken from the first non-empty of "status", "state" and "task_status".
func universalStatus(result map[string]any) string {
	raw := firstNonEmptyString(result["status"], result["state"], result["task_status"])
	trimmed := strings.TrimSpace(raw)
	return strings.ToLower(trimmed)
}
// universalTaskID extracts the upstream task id from a submit response. Keys
// are checked in priority order: upstream_task_id, task_id, taskId, id.
func universalTaskID(result map[string]any) string {
	id := firstNonEmptyString(
		result["upstream_task_id"],
		result["task_id"],
		result["taskId"],
		result["id"],
	)
	return id
}
// universalProgress returns the provider progress entries followed by a
// "polling" entry that records the upstream task id.
func universalProgress(request Request, upstreamTaskID string) []Progress {
	steps := providerProgress(request)
	polled := Progress{
		Phase:    "polling",
		Progress: 0.65,
		Message:  "provider task polled",
		Payload:  map[string]any{"upstreamTaskId": upstreamTaskID},
	}
	return append(steps, polled)
}
// universalStripPrivatePayload returns a cloneMapAny copy of payload without
// the internal keys _originalParams, _resolution and _duration, which must
// not be sent upstream.
func universalStripPrivatePayload(payload map[string]any) map[string]any {
	public := cloneMapAny(payload)
	delete(public, "_originalParams")
	delete(public, "_resolution")
	delete(public, "_duration")
	return public
}
func universalSceneScript(config map[string]any, modelType string, keys ...string) string {
for _, key := range keys {
value := config[key]
switch typed := value.(type) {
case string:
if strings.TrimSpace(typed) != "" {
return strings.TrimSpace(typed)
}
case map[string]any:
if script := firstNonEmptyString(typed[modelType], typed["common"]); script != "" {
return script
}
}
}
return ""
}
// universalConfigString returns the first value among keys whose fmt.Sprint
// form is non-blank, trimmed. Missing and nil entries print as "<nil>" and are
// skipped, as are blank strings. Returns "" when no key matches.
func universalConfigString(config map[string]any, keys ...string) string {
	for _, key := range keys {
		text := strings.TrimSpace(fmt.Sprint(config[key]))
		switch text {
		case "", "<nil>":
			continue
		default:
			return text
		}
	}
	return ""
}
// universalDurationConfig reads a positive duration from the first of keys
// whose value converts to one, returning fallback otherwise.
//
// Numbers are interpreted as milliseconds, matching the "...Ms" key names.
// This covers int, int32, int64, float32, float64, json.Number, and numeric
// strings such as "1500". Fractional milliseconds are kept.
//
// Other strings are parsed with time.ParseDuration (e.g. "2s"). Zero,
// negative, NaN, unparsable and unsupported values are skipped. Every
// returned duration is > 0, because callers pass it to time.NewTicker, which
// panics on non-positive durations.
func universalDurationConfig(config map[string]any, fallback time.Duration, keys ...string) time.Duration {
	for _, key := range keys {
		if parsed, ok := universalDurationValue(config[key]); ok {
			return parsed
		}
	}
	return fallback
}

// universalDurationValue converts one config value into a positive Duration.
// ok is false for nil, non-positive, unparsable or unsupported values.
func universalDurationValue(raw any) (time.Duration, bool) {
	switch value := raw.(type) {
	case int:
		return universalMillisDuration(float64(value))
	case int32:
		return universalMillisDuration(float64(value))
	case int64:
		return universalMillisDuration(float64(value))
	case float32:
		return universalMillisDuration(float64(value))
	case float64:
		return universalMillisDuration(value)
	case json.Number:
		return universalDurationText(string(value))
	case string:
		return universalDurationText(value)
	default:
		return 0, false
	}
}

// universalDurationText parses a plain number as milliseconds and anything
// else with time.ParseDuration.
func universalDurationText(text string) (time.Duration, bool) {
	text = strings.TrimSpace(text)
	// json.Number.Float64 is strconv.ParseFloat(text, 64) and needs no extra import.
	if ms, err := json.Number(text).Float64(); err == nil {
		return universalMillisDuration(ms)
	}
	parsed, err := time.ParseDuration(text)
	return parsed, err == nil && parsed > 0
}

// universalMillisDuration converts a millisecond count to a Duration. It keeps
// sub-millisecond precision and saturates at the maximum Duration instead of
// overflowing. Inputs that are NaN or non-positive, or that round to zero
// nanoseconds, are rejected.
func universalMillisDuration(ms float64) (time.Duration, bool) {
	const maxDuration = time.Duration(1<<63 - 1)
	if !(ms > 0) { // also rejects NaN
		return 0, false
	}
	nanos := ms * float64(time.Millisecond)
	if nanos >= float64(maxDuration) {
		return maxDuration, true
	}
	parsed := time.Duration(nanos)
	return parsed, parsed > 0
}
// resolveUniversalTaskURL expands the getTaskURL / get_task_url template from
// config for upstreamTaskID. Returns "" when no template is configured.
//
// Placeholders may use the ${name}, {{name}} or {name} syntax. The names
// upstream_task_id and task_id are supported in all three syntaxes; taskId
// and taskID are also supported. Every placeholder is replaced by
// upstreamTaskID.
//
// Substitution is a single left-to-right pass (strings.Replacer), so
// placeholder-like text inside the task id itself is never expanded again.
func resolveUniversalTaskURL(config map[string]any, upstreamTaskID string) string {
	// universalConfigString already trims the template.
	template := universalConfigString(config, "getTaskURL", "get_task_url")
	if template == "" {
		return ""
	}
	// Longer forms are listed before the {name} form they contain, mirroring
	// the original precedence.
	replacer := strings.NewReplacer(
		"${upstream_task_id}", upstreamTaskID,
		"{{upstream_task_id}}", upstreamTaskID,
		"{upstream_task_id}", upstreamTaskID,
		"${task_id}", upstreamTaskID,
		"{{task_id}}", upstreamTaskID,
		"{task_id}", upstreamTaskID,
		"${taskId}", upstreamTaskID,
		"${taskID}", upstreamTaskID,
		"{{taskId}}", upstreamTaskID,
		"{{taskID}}", upstreamTaskID,
		"{taskId}", upstreamTaskID,
		"{taskID}", upstreamTaskID,
	)
	return replacer.Replace(template)
}