feat(api): migrate media clients and universal scripts
This commit is contained in:
parent
11a2c13e4a
commit
af9b281d34
@ -13,6 +13,11 @@ require (
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.4 // indirect
|
||||
github.com/dop251/goja v0.0.0-20260311135729-065cd970411c // indirect
|
||||
github.com/dop251/goja_nodejs v0.0.0-20260212111938-1f56ff5bcf14 // indirect
|
||||
github.com/go-sourcemap/sourcemap v2.1.4+incompatible // indirect
|
||||
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
|
||||
@ -1,8 +1,18 @@
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
|
||||
github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dop251/goja v0.0.0-20260311135729-065cd970411c h1:OcLmPfx1T1RmZVHHFwWMPaZDdRf0DBMZOFMVWJa7Pdk=
|
||||
github.com/dop251/goja v0.0.0-20260311135729-065cd970411c/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4=
|
||||
github.com/dop251/goja_nodejs v0.0.0-20260212111938-1f56ff5bcf14 h1:3U8dTgyNBhEQ/GVw0jZW5q+93Zw2gAZPRWhJ9TwV3rM=
|
||||
github.com/dop251/goja_nodejs v0.0.0-20260212111938-1f56ff5bcf14/go.mod h1:Tb7Xxye4LX7cT3i8YLvmPMGCV92IOi4CDZvm/V8ylc0=
|
||||
github.com/go-sourcemap/sourcemap v2.1.4+incompatible h1:a+iTbH5auLKxaNwQFg0B+TCYl6lbukKPc7b5x0n1s6Q=
|
||||
github.com/go-sourcemap/sourcemap v2.1.4+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k=
|
||||
github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo=
|
||||
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0=
|
||||
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
|
||||
232
apps/api/internal/clients/media_clients.go
Normal file
232
apps/api/internal/clients/media_clients.go
Normal file
@ -0,0 +1,232 @@
|
||||
package clients
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type JimengClient struct{ HTTPClient *http.Client }
|
||||
type BlackforestClient struct{ HTTPClient *http.Client }
|
||||
type HunyuanImageClient struct{ HTTPClient *http.Client }
|
||||
type HunyuanVideoClient struct{ HTTPClient *http.Client }
|
||||
type MinimaxClient struct{ HTTPClient *http.Client }
|
||||
type MidjourneyClient struct{ HTTPClient *http.Client }
|
||||
type ViduClient struct{ HTTPClient *http.Client }
|
||||
type AliyunBailianClient struct{ HTTPClient *http.Client }
|
||||
type NewAPIClient struct{ HTTPClient *http.Client }
|
||||
|
||||
func (c JimengClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: jimengSpec()}.Run(ctx, request)
|
||||
}
|
||||
|
||||
func (c BlackforestClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: blackforestSpec()}.Run(ctx, request)
|
||||
}
|
||||
|
||||
func (c HunyuanImageClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: hunyuanImageSpec()}.Run(ctx, request)
|
||||
}
|
||||
|
||||
func (c HunyuanVideoClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: hunyuanVideoSpec()}.Run(ctx, request)
|
||||
}
|
||||
|
||||
func (c MinimaxClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: minimaxSpec()}.Run(ctx, request)
|
||||
}
|
||||
|
||||
func (c MidjourneyClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: midjourneySpec()}.Run(ctx, request)
|
||||
}
|
||||
|
||||
func (c ViduClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: viduSpec()}.Run(ctx, request)
|
||||
}
|
||||
|
||||
func (c AliyunBailianClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: aliyunBailianSpec()}.Run(ctx, request)
|
||||
}
|
||||
|
||||
func (c NewAPIClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||
return providerTaskClient{HTTPClient: c.HTTPClient, Spec: newAPISpec()}.Run(ctx, request)
|
||||
}
|
||||
|
||||
func jimengSpec() providerTaskSpec {
|
||||
return providerTaskSpec{
|
||||
Name: "jimeng",
|
||||
SubmitPath: func(request Request, _ map[string]any) string {
|
||||
return configuredPath(request, "?Action=CVSubmitTask&Version=2022-08-31", "submitPath", "submit_path")
|
||||
},
|
||||
PollPath: func(request Request, _ string, _ map[string]any) string {
|
||||
return configuredPath(request, "?Action=CVSync2AsyncGetResult&Version=2022-08-31", "pollPath", "poll_path")
|
||||
},
|
||||
Auth: "bearer",
|
||||
TaskIDPaths: []string{"data.task_id"},
|
||||
StatusPaths: []string{"data.status"},
|
||||
SuccessStatuses: []string{"done"},
|
||||
DefaultSubmitBody: func(request Request, body map[string]any) map[string]any {
|
||||
body["req_key"] = upstreamModelName(request.Candidate)
|
||||
if body["prompt"] == nil {
|
||||
body["prompt"] = mediaPromptText(body)
|
||||
}
|
||||
return body
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func blackforestSpec() providerTaskSpec {
|
||||
return providerTaskSpec{
|
||||
Name: "blackforest",
|
||||
SubmitPath: func(request Request, body map[string]any) string {
|
||||
return configuredPath(request, "/"+upstreamModelName(request.Candidate), "submitPath", "submit_path")
|
||||
},
|
||||
PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string { return upstreamTaskID },
|
||||
Auth: "x-key",
|
||||
TaskIDPaths: []string{"polling_url"},
|
||||
StatusPaths: []string{"status"},
|
||||
SuccessStatuses: []string{"ready"},
|
||||
FailureStatuses: []string{"error", "task not found"},
|
||||
}
|
||||
}
|
||||
|
||||
func hunyuanImageSpec() providerTaskSpec {
|
||||
return providerTaskSpec{
|
||||
Name: "tencent-hunyuan-image",
|
||||
SubmitPath: func(request Request, _ map[string]any) string {
|
||||
return configuredPath(request, "?Action=SubmitHunyuanImageJob&Version=2023-09-01", "submitPath", "submit_path")
|
||||
},
|
||||
PollPath: func(request Request, _ string, _ map[string]any) string {
|
||||
return configuredPath(request, "?Action=QueryHunyuanImageJob&Version=2023-09-01&JobId=${taskId}", "pollPath", "poll_path")
|
||||
},
|
||||
Auth: "bearer",
|
||||
TaskIDPaths: []string{"Response.JobId"},
|
||||
StatusPaths: []string{"Response.Status"},
|
||||
SuccessStatuses: []string{"done"},
|
||||
FailureStatuses: []string{"fail"},
|
||||
DefaultSubmitBody: func(request Request, body map[string]any) map[string]any {
|
||||
body["Prompt"] = mediaPromptText(body)
|
||||
body["Model"] = upstreamModelName(request.Candidate)
|
||||
return body
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func hunyuanVideoSpec() providerTaskSpec {
|
||||
return providerTaskSpec{
|
||||
Name: "tencent-hunyuan-video",
|
||||
SubmitPath: func(request Request, _ map[string]any) string {
|
||||
return configuredPath(request, "?Action=SubmitTextToVideoJob&Version=2024-01-01", "submitPath", "submit_path")
|
||||
},
|
||||
PollPath: func(request Request, _ string, _ map[string]any) string {
|
||||
return configuredPath(request, "?Action=QueryVideoJob&Version=2024-01-01&JobId=${taskId}", "pollPath", "poll_path")
|
||||
},
|
||||
Auth: "bearer",
|
||||
TaskIDPaths: []string{"Response.JobId"},
|
||||
StatusPaths: []string{"Response.Status"},
|
||||
SuccessStatuses: []string{"done"},
|
||||
FailureStatuses: []string{"fail"},
|
||||
DefaultSubmitBody: func(request Request, body map[string]any) map[string]any {
|
||||
body["Prompt"] = mediaPromptText(body)
|
||||
body["Model"] = upstreamModelName(request.Candidate)
|
||||
return body
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func minimaxSpec() providerTaskSpec {
|
||||
return providerTaskSpec{
|
||||
Name: "minimax",
|
||||
SubmitPath: func(Request, map[string]any) string { return "/video_generation" },
|
||||
PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string {
|
||||
return "/query/video_generation?task_id=" + upstreamTaskID
|
||||
},
|
||||
Auth: "bearer",
|
||||
TaskIDPaths: []string{"task_id"},
|
||||
StatusPaths: []string{"status"},
|
||||
SuccessStatuses: []string{"success"},
|
||||
FailureStatuses: []string{"failed", "expired"},
|
||||
}
|
||||
}
|
||||
|
||||
func midjourneySpec() providerTaskSpec {
|
||||
return providerTaskSpec{
|
||||
Name: "midjourney",
|
||||
SubmitPath: func(request Request, body map[string]any) string {
|
||||
return configuredPath(request, "/diffusion", "submitPath", "submit_path")
|
||||
},
|
||||
PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string { return "/job/" + upstreamTaskID },
|
||||
Auth: "bearer",
|
||||
TaskIDPaths: []string{"job_id", "id"},
|
||||
StatusPaths: []string{"status"},
|
||||
SuccessStatuses: []string{"success", "completed"},
|
||||
FailureStatuses: []string{"failed"},
|
||||
DefaultSubmitBody: func(request Request, body map[string]any) map[string]any {
|
||||
if body["prompt"] == nil && body["text"] == nil {
|
||||
body["prompt"] = mediaPromptText(body)
|
||||
}
|
||||
return body
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func viduSpec() providerTaskSpec {
|
||||
return providerTaskSpec{
|
||||
Name: "vidu",
|
||||
SubmitPath: func(request Request, body map[string]any) string {
|
||||
if path := configuredPath(request, "", "submitPath", "submit_path"); path != "" {
|
||||
return path
|
||||
}
|
||||
taskType := firstNonEmptyString(body["type"], body["task_type"], "text2video")
|
||||
if taskType == "multiframe" {
|
||||
return "/multiframe"
|
||||
}
|
||||
return "/" + taskType
|
||||
},
|
||||
PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string {
|
||||
return "/tasks/" + upstreamTaskID + "/creations"
|
||||
},
|
||||
Auth: "token",
|
||||
TaskIDPaths: []string{"task_id"},
|
||||
StatusPaths: []string{"state", "status"},
|
||||
SuccessStatuses: []string{"success", "succeeded"},
|
||||
FailureStatuses: []string{"failed"},
|
||||
}
|
||||
}
|
||||
|
||||
func aliyunBailianSpec() providerTaskSpec {
|
||||
return providerTaskSpec{
|
||||
Name: "aliyun-bailian",
|
||||
SubmitPath: func(Request, map[string]any) string { return "/services/aigc/video-generation/video-synthesis" },
|
||||
PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string { return "/tasks/" + upstreamTaskID },
|
||||
Auth: "bearer",
|
||||
TaskIDPaths: []string{"output.task_id"},
|
||||
StatusPaths: []string{"output.task_status"},
|
||||
SuccessStatuses: []string{"succeeded", "success"},
|
||||
FailureStatuses: []string{"failed"},
|
||||
}
|
||||
}
|
||||
|
||||
func newAPISpec() providerTaskSpec {
|
||||
return providerTaskSpec{
|
||||
Name: "newapi",
|
||||
SubmitPath: func(Request, map[string]any) string { return "/videos/generations" },
|
||||
PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string {
|
||||
return "/videos/generations/" + upstreamTaskID
|
||||
},
|
||||
Auth: "bearer",
|
||||
TaskIDPaths: []string{"task_id"},
|
||||
StatusPaths: []string{"status"},
|
||||
SuccessStatuses: []string{"success"},
|
||||
FailureStatuses: []string{"failure", "failed"},
|
||||
}
|
||||
}
|
||||
|
||||
func configuredPath(request Request, fallback string, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if value := strings.TrimSpace(stringFromAny(request.Candidate.PlatformConfig[key])); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
453
apps/api/internal/clients/provider_task.go
Normal file
453
apps/api/internal/clients/provider_task.go
Normal file
@ -0,0 +1,453 @@
|
||||
package clients
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type providerTaskSpec struct {
|
||||
Name string
|
||||
SubmitPath func(Request, map[string]any) string
|
||||
PollPath func(Request, string, map[string]any) string
|
||||
Auth string
|
||||
TaskIDPaths []string
|
||||
StatusPaths []string
|
||||
SuccessStatuses []string
|
||||
FailureStatuses []string
|
||||
ProcessStatuses []string
|
||||
DefaultSubmitBody func(Request, map[string]any) map[string]any
|
||||
}
|
||||
|
||||
type providerTaskClient struct {
|
||||
HTTPClient *http.Client
|
||||
Spec providerTaskSpec
|
||||
}
|
||||
|
||||
func (c providerTaskClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||
if request.Kind != "images.generations" && request.Kind != "images.edits" && request.Kind != "videos.generations" {
|
||||
return Response{}, &ClientError{Code: "unsupported_kind", Message: "unsupported " + c.Spec.Name + " request kind", Retryable: false}
|
||||
}
|
||||
startedAt := time.Now()
|
||||
payload := cloneBody(request.Body)
|
||||
if c.Spec.DefaultSubmitBody != nil {
|
||||
payload = c.Spec.DefaultSubmitBody(request, payload)
|
||||
} else {
|
||||
payload["model"] = upstreamModelName(request.Candidate)
|
||||
}
|
||||
|
||||
upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
|
||||
requestID := upstreamTaskID
|
||||
var submitResult map[string]any
|
||||
if upstreamTaskID == "" {
|
||||
result, id, err := c.submit(ctx, request, payload)
|
||||
if err != nil {
|
||||
return Response{}, annotateResponseError(err, id, startedAt, time.Now())
|
||||
}
|
||||
submitResult = result
|
||||
requestID = firstNonEmptyString(id, requestIDFromResult(result))
|
||||
if isProviderTaskFailure(c.Spec, result) {
|
||||
return Response{}, providerTaskFailure(c.Spec, result, requestID, startedAt)
|
||||
}
|
||||
if isProviderTaskSuccess(c.Spec, result) && hasProviderTaskResult(result) {
|
||||
return Response{
|
||||
Result: normalizeProviderTaskResult(request, c.Spec, result, ""),
|
||||
RequestID: requestID,
|
||||
Progress: providerProgress(request),
|
||||
ResponseStartedAt: startedAt,
|
||||
ResponseFinishedAt: time.Now(),
|
||||
ResponseDurationMS: responseDurationMS(startedAt, time.Now()),
|
||||
}, nil
|
||||
}
|
||||
upstreamTaskID = providerTaskID(c.Spec, result)
|
||||
if upstreamTaskID == "" {
|
||||
return Response{}, &ClientError{Code: "invalid_response", Message: c.Spec.Name + " task id is missing", RequestID: requestID, Retryable: false}
|
||||
}
|
||||
if request.OnRemoteTaskSubmitted != nil {
|
||||
if err := request.OnRemoteTaskSubmitted(upstreamTaskID, map[string]any{"payload": payload, "submit": submitResult}); err != nil {
|
||||
return Response{}, err
|
||||
}
|
||||
}
|
||||
} else if request.RemoteTaskPayload != nil {
|
||||
if existingPayload, ok := request.RemoteTaskPayload["payload"].(map[string]any); ok {
|
||||
payload = existingPayload
|
||||
}
|
||||
}
|
||||
|
||||
interval := providerPollInterval(request)
|
||||
timeout := providerPollTimeout(request)
|
||||
deadline := time.NewTimer(timeout)
|
||||
defer deadline.Stop()
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastResult map[string]any
|
||||
for {
|
||||
pollStarted := time.Now()
|
||||
result, pollRequestID, err := c.poll(ctx, request, upstreamTaskID, payload)
|
||||
pollFinished := time.Now()
|
||||
if err != nil {
|
||||
return Response{}, annotateResponseError(err, firstNonEmptyString(pollRequestID, requestID, upstreamTaskID), pollStarted, pollFinished)
|
||||
}
|
||||
lastResult = result
|
||||
requestID = firstNonEmptyString(pollRequestID, requestID, requestIDFromResult(result), upstreamTaskID)
|
||||
if isProviderTaskSuccess(c.Spec, result) {
|
||||
finishedAt := time.Now()
|
||||
return Response{
|
||||
Result: normalizeProviderTaskResult(request, c.Spec, result, upstreamTaskID),
|
||||
RequestID: requestID,
|
||||
Progress: append(providerProgress(request), Progress{Phase: "polling", Progress: 0.65, Message: "provider task polled", Payload: map[string]any{"upstreamTaskId": upstreamTaskID}}),
|
||||
ResponseStartedAt: startedAt,
|
||||
ResponseFinishedAt: finishedAt,
|
||||
ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
|
||||
}, nil
|
||||
}
|
||||
if isProviderTaskFailure(c.Spec, result) {
|
||||
return Response{}, providerTaskFailure(c.Spec, result, requestID, startedAt)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return Response{}, &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: requestID, Retryable: true}
|
||||
case <-deadline.C:
|
||||
return Response{}, &ClientError{Code: "timeout", Message: fmt.Sprintf("%s task %s did not finish before timeout; last status: %s", c.Spec.Name, upstreamTaskID, providerTaskStatus(c.Spec, lastResult)), RequestID: requestID, Retryable: true}
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c providerTaskClient) submit(ctx context.Context, request Request, payload map[string]any) (map[string]any, string, error) {
|
||||
path := c.Spec.SubmitPath(request, payload)
|
||||
return providerPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), providerURL(request.Candidate.BaseURL, path), payload, request.Candidate.Credentials, c.Spec.Auth)
|
||||
}
|
||||
|
||||
func (c providerTaskClient) poll(ctx context.Context, request Request, upstreamTaskID string, payload map[string]any) (map[string]any, string, error) {
|
||||
path := resolveProviderPathTemplate(c.Spec.PollPath(request, upstreamTaskID, payload), upstreamTaskID)
|
||||
url := path
|
||||
if !strings.HasPrefix(path, "http://") && !strings.HasPrefix(path, "https://") {
|
||||
url = providerURL(request.Candidate.BaseURL, path)
|
||||
}
|
||||
if c.Spec.Name == "jimeng" {
|
||||
body := map[string]any{"task_id": upstreamTaskID, "req_key": upstreamModelName(request.Candidate)}
|
||||
return providerPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), url, body, request.Candidate.Credentials, c.Spec.Auth)
|
||||
}
|
||||
return providerGetJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), url, request.Candidate.Credentials, c.Spec.Auth)
|
||||
}
|
||||
|
||||
func providerPostJSON(ctx context.Context, client *http.Client, url string, body map[string]any, credentials map[string]any, auth string) (map[string]any, string, error) {
|
||||
raw, _ := json.Marshal(body)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
applyProviderAuth(req, credentials, auth)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||
}
|
||||
requestID := requestIDFromHTTPResponse(resp)
|
||||
result, err := decodeHTTPResponse(resp)
|
||||
return result, requestID, err
|
||||
}
|
||||
|
||||
func providerGetJSON(ctx context.Context, client *http.Client, url string, credentials map[string]any, auth string) (map[string]any, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
applyProviderAuth(req, credentials, auth)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||
}
|
||||
requestID := requestIDFromHTTPResponse(resp)
|
||||
result, err := decodeHTTPResponse(resp)
|
||||
return result, requestID, err
|
||||
}
|
||||
|
||||
func applyProviderAuth(req *http.Request, credentials map[string]any, auth string) {
|
||||
apiKey := credential(credentials, "apiKey", "api_key", "key", "token")
|
||||
switch auth {
|
||||
case "token":
|
||||
if apiKey != "" {
|
||||
req.Header.Set("Authorization", "Token "+apiKey)
|
||||
}
|
||||
case "x-key":
|
||||
if apiKey != "" {
|
||||
req.Header.Set("x-key", apiKey)
|
||||
}
|
||||
case "none":
|
||||
default:
|
||||
if apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func providerURL(base string, path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
||||
return path
|
||||
}
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
if !strings.HasPrefix(path, "/") && !strings.HasPrefix(path, "?") {
|
||||
path = "/" + path
|
||||
}
|
||||
return joinURL(base, path)
|
||||
}
|
||||
|
||||
func resolveProviderPathTemplate(path string, upstreamTaskID string) string {
|
||||
replacements := [][2]string{
|
||||
{"${upstream_task_id}", upstreamTaskID},
|
||||
{"{{upstream_task_id}}", upstreamTaskID},
|
||||
{"{upstream_task_id}", upstreamTaskID},
|
||||
{"${task_id}", upstreamTaskID},
|
||||
{"{{task_id}}", upstreamTaskID},
|
||||
{"{task_id}", upstreamTaskID},
|
||||
{"${taskId}", upstreamTaskID},
|
||||
{"${taskID}", upstreamTaskID},
|
||||
{"{{taskId}}", upstreamTaskID},
|
||||
{"{{taskID}}", upstreamTaskID},
|
||||
{"{taskId}", upstreamTaskID},
|
||||
{"{taskID}", upstreamTaskID},
|
||||
}
|
||||
for _, replacement := range replacements {
|
||||
path = strings.ReplaceAll(path, replacement[0], replacement[1])
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func providerTaskID(spec providerTaskSpec, result map[string]any) string {
|
||||
paths := append([]string{}, spec.TaskIDPaths...)
|
||||
paths = append(paths, "task_id", "taskId", "id", "job_id", "Response.JobId", "output.task_id", "data.task_id", "polling_url")
|
||||
for _, path := range paths {
|
||||
if value := stringFromPathValue(valueAtPath(result, path)); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func providerTaskStatus(spec providerTaskSpec, result map[string]any) string {
|
||||
if result == nil {
|
||||
return ""
|
||||
}
|
||||
if value, ok := valueAtPath(result, "status").(float64); ok {
|
||||
if value == 2 {
|
||||
return "success"
|
||||
}
|
||||
if value == 3 {
|
||||
return "failed"
|
||||
}
|
||||
return "process"
|
||||
}
|
||||
paths := append([]string{}, spec.StatusPaths...)
|
||||
paths = append(paths, "status", "state", "task_status", "output.task_status", "Response.Status", "data.status")
|
||||
for _, path := range paths {
|
||||
if value := stringFromPathValue(valueAtPath(result, path)); value != "" {
|
||||
return strings.ToLower(value)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func stringFromPathValue(value any) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
text := strings.TrimSpace(fmt.Sprint(value))
|
||||
if text == "" || text == "<nil>" {
|
||||
return ""
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func isProviderTaskSuccess(spec providerTaskSpec, result map[string]any) bool {
|
||||
return containsStatus(append([]string{"success", "succeeded", "completed", "complete", "done", "ready", "succeed", "succeeded", "suceeded", "done", "done"}, spec.SuccessStatuses...), providerTaskStatus(spec, result))
|
||||
}
|
||||
|
||||
func isProviderTaskFailure(spec providerTaskSpec, result map[string]any) bool {
|
||||
return containsStatus(append([]string{"failed", "failure", "error", "cancelled", "canceled", "fail", "expired", "task not found"}, spec.FailureStatuses...), providerTaskStatus(spec, result))
|
||||
}
|
||||
|
||||
func containsStatus(values []string, status string) bool {
|
||||
status = strings.ToLower(strings.TrimSpace(status))
|
||||
for _, value := range values {
|
||||
if strings.ToLower(strings.TrimSpace(value)) == status {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hasProviderTaskResult(result map[string]any) bool {
|
||||
return result["data"] != nil || valueAtPath(result, "output.image_urls") != nil || valueAtPath(result, "output.video_url") != nil || valueAtPath(result, "Response.ResultVideoUrl") != nil || valueAtPath(result, "Response.ResultImages") != nil || result["urls"] != nil
|
||||
}
|
||||
|
||||
func normalizeProviderTaskResult(request Request, spec providerTaskSpec, result map[string]any, upstreamTaskID string) map[string]any {
|
||||
out := cloneMapAny(result)
|
||||
out["status"] = "success"
|
||||
if upstreamTaskID != "" {
|
||||
out["upstream_task_id"] = upstreamTaskID
|
||||
}
|
||||
if out["created"] == nil {
|
||||
out["created"] = time.Now().UnixMilli()
|
||||
}
|
||||
if out["model"] == nil {
|
||||
out["model"] = request.Model
|
||||
}
|
||||
if _, ok := out["data"].([]any); !ok {
|
||||
if out["data"] != nil {
|
||||
out["raw_data"] = out["data"]
|
||||
}
|
||||
out["data"] = providerTaskData(request, result)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func providerTaskData(request Request, result map[string]any) []any {
|
||||
fileType := "image"
|
||||
if request.Kind == "videos.generations" || strings.Contains(request.ModelType, "video") {
|
||||
fileType = "video"
|
||||
}
|
||||
urlValues := []any{}
|
||||
for _, path := range []string{
|
||||
"urls",
|
||||
"image_urls",
|
||||
"data.image_urls",
|
||||
"data.images",
|
||||
"output.image_urls",
|
||||
"output.video_url",
|
||||
"output.output",
|
||||
"data.output",
|
||||
"data.video_url",
|
||||
"video_url",
|
||||
"preview_url",
|
||||
"Response.ResultImages",
|
||||
"Response.ResultVideoUrl",
|
||||
} {
|
||||
appendURLValues(&urlValues, valueAtPath(result, path))
|
||||
}
|
||||
data := make([]any, 0, len(urlValues))
|
||||
for _, raw := range urlValues {
|
||||
if url := strings.TrimSpace(fmt.Sprint(raw)); url != "" {
|
||||
data = append(data, map[string]any{"type": fileType, "url": url})
|
||||
}
|
||||
}
|
||||
if len(data) == 0 {
|
||||
if base64Values := valueAtPath(result, "data.binary_data_base64"); base64Values != nil {
|
||||
values := []any{}
|
||||
appendURLValues(&values, base64Values)
|
||||
for _, raw := range values {
|
||||
if content := strings.TrimSpace(fmt.Sprint(raw)); content != "" {
|
||||
data = append(data, map[string]any{"type": fileType, "content": content, "uploaded": false})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func appendURLValues(out *[]any, value any) {
|
||||
switch typed := value.(type) {
|
||||
case nil:
|
||||
case string:
|
||||
*out = append(*out, typed)
|
||||
case []any:
|
||||
for _, item := range typed {
|
||||
appendURLValues(out, item)
|
||||
}
|
||||
case []string:
|
||||
for _, item := range typed {
|
||||
*out = append(*out, item)
|
||||
}
|
||||
case map[string]any:
|
||||
for _, key := range []string{"url", "image_url", "imageUrl", "video_url", "videoUrl", "content", "output"} {
|
||||
if item := strings.TrimSpace(fmt.Sprint(typed[key])); item != "" && item != "<nil>" {
|
||||
*out = append(*out, item)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func providerTaskFailure(spec providerTaskSpec, result map[string]any, requestID string, startedAt time.Time) error {
|
||||
message := firstNonEmptyString(valueAtPath(result, "message"), valueAtPath(result, "error.message"), valueAtPath(result, "error"), valueAtPath(result, "Response.ErrorMessage"), valueAtPath(result, "comment"), spec.Name+" task failed")
|
||||
return &ClientError{
|
||||
Code: firstNonEmptyString(valueAtPath(result, "code"), valueAtPath(result, "error_code"), valueAtPath(result, "Response.ErrorCode"), "provider_failed"),
|
||||
Message: message,
|
||||
RequestID: requestID,
|
||||
ResponseStartedAt: startedAt,
|
||||
ResponseFinishedAt: time.Now(),
|
||||
ResponseDurationMS: responseDurationMS(startedAt, time.Now()),
|
||||
Retryable: false,
|
||||
}
|
||||
}
|
||||
|
||||
func providerPollInterval(request Request) time.Duration {
|
||||
return durationFromConfig(request.Candidate.PlatformConfig, 2*time.Second, "pollIntervalMs", "poll_interval_ms")
|
||||
}
|
||||
|
||||
func providerPollTimeout(request Request) time.Duration {
|
||||
return durationFromConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms", "timeoutMs")
|
||||
}
|
||||
|
||||
func durationFromConfig(config map[string]any, fallback time.Duration, keys ...string) time.Duration {
|
||||
for _, key := range keys {
|
||||
switch value := config[key].(type) {
|
||||
case int:
|
||||
if value > 0 {
|
||||
return time.Duration(value) * time.Millisecond
|
||||
}
|
||||
case int64:
|
||||
if value > 0 {
|
||||
return time.Duration(value) * time.Millisecond
|
||||
}
|
||||
case float64:
|
||||
if value > 0 {
|
||||
return time.Duration(value) * time.Millisecond
|
||||
}
|
||||
case string:
|
||||
if parsed, err := time.ParseDuration(value); err == nil && parsed > 0 {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func valueAtPath(values map[string]any, path string) any {
|
||||
if values == nil || strings.TrimSpace(path) == "" {
|
||||
return nil
|
||||
}
|
||||
var current any = values
|
||||
for _, part := range strings.Split(path, ".") {
|
||||
object, ok := current.(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
current = object[part]
|
||||
}
|
||||
return current
|
||||
}
|
||||
|
||||
func mediaPromptText(body map[string]any) string {
|
||||
if prompt := strings.TrimSpace(stringFromAny(body["prompt"])); prompt != "" {
|
||||
return prompt
|
||||
}
|
||||
content, _ := body["content"].([]any)
|
||||
for _, item := range content {
|
||||
if part, ok := item.(map[string]any); ok && strings.TrimSpace(stringFromAny(part["type"])) == "text" {
|
||||
if text := strings.TrimSpace(stringFromAny(part["text"])); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
339
apps/api/internal/clients/provider_task_test.go
Normal file
339
apps/api/internal/clients/provider_task_test.go
Normal file
@ -0,0 +1,339 @@
|
||||
package clients
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
func TestProviderTaskClientsSubmitAndPoll(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
client Client
|
||||
provider string
|
||||
specType string
|
||||
submitMatch func(*http.Request) bool
|
||||
submitResponse string
|
||||
pollMatch func(*http.Request) bool
|
||||
pollResponse string
|
||||
authHeader string
|
||||
resultURL string
|
||||
}{
|
||||
{
|
||||
name: "jimeng",
|
||||
client: JimengClient{},
|
||||
provider: "jimeng",
|
||||
specType: "jimeng",
|
||||
submitMatch: func(r *http.Request) bool {
|
||||
return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "CVSubmitTask"
|
||||
},
|
||||
submitResponse: `{"code":10000,"data":{"task_id":"remote-1"}}`,
|
||||
pollMatch: func(r *http.Request) bool {
|
||||
return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "CVSync2AsyncGetResult"
|
||||
},
|
||||
pollResponse: `{"code":10000,"data":{"status":"done","video_url":"https://cdn.example/jimeng.mp4"}}`,
|
||||
authHeader: "Bearer test-key",
|
||||
resultURL: "https://cdn.example/jimeng.mp4",
|
||||
},
|
||||
{
|
||||
name: "blackforest",
|
||||
client: BlackforestClient{},
|
||||
provider: "blackforest",
|
||||
specType: "blackforest",
|
||||
submitMatch: func(r *http.Request) bool {
|
||||
return r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/provider-model")
|
||||
},
|
||||
submitResponse: `{"polling_url":"__SERVER__/poll/remote-1"}`,
|
||||
pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/poll/remote-1" },
|
||||
pollResponse: `{"status":"Ready","urls":["https://cdn.example/flux.png"]}`,
|
||||
authHeader: "test-key",
|
||||
resultURL: "https://cdn.example/flux.png",
|
||||
},
|
||||
{
|
||||
name: "hunyuan-image",
|
||||
client: HunyuanImageClient{},
|
||||
provider: "tencent-hunyuan-image",
|
||||
specType: "tencent-hunyuan-image",
|
||||
submitMatch: func(r *http.Request) bool {
|
||||
return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "SubmitHunyuanImageJob"
|
||||
},
|
||||
submitResponse: `{"Response":{"JobId":"remote-1"}}`,
|
||||
pollMatch: func(r *http.Request) bool {
|
||||
return r.Method == http.MethodGet && r.URL.Query().Get("Action") == "QueryHunyuanImageJob"
|
||||
},
|
||||
pollResponse: `{"Response":{"Status":"DONE","ResultImages":["https://cdn.example/hunyuan.png"]}}`,
|
||||
authHeader: "Bearer test-key",
|
||||
resultURL: "https://cdn.example/hunyuan.png",
|
||||
},
|
||||
{
|
||||
name: "hunyuan-video",
|
||||
client: HunyuanVideoClient{},
|
||||
provider: "tencent-hunyuan-video",
|
||||
specType: "tencent-hunyuan-video",
|
||||
submitMatch: func(r *http.Request) bool {
|
||||
return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "SubmitTextToVideoJob"
|
||||
},
|
||||
submitResponse: `{"Response":{"JobId":"remote-1"}}`,
|
||||
pollMatch: func(r *http.Request) bool {
|
||||
return r.Method == http.MethodGet && r.URL.Query().Get("Action") == "QueryVideoJob"
|
||||
},
|
||||
pollResponse: `{"Response":{"Status":"DONE","ResultVideoUrl":"https://cdn.example/hunyuan.mp4"}}`,
|
||||
authHeader: "Bearer test-key",
|
||||
resultURL: "https://cdn.example/hunyuan.mp4",
|
||||
},
|
||||
{
|
||||
name: "minimax",
|
||||
client: MinimaxClient{},
|
||||
provider: "minimax",
|
||||
specType: "minimax",
|
||||
submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/video_generation" },
|
||||
submitResponse: `{"task_id":123}`,
|
||||
pollMatch: func(r *http.Request) bool {
|
||||
return r.Method == http.MethodGet && r.URL.Path == "/query/video_generation" && r.URL.Query().Get("task_id") == "123"
|
||||
},
|
||||
pollResponse: `{"status":"Success","file_id":"file-1","video_url":"https://cdn.example/minimax.mp4"}`,
|
||||
authHeader: "Bearer test-key",
|
||||
resultURL: "https://cdn.example/minimax.mp4",
|
||||
},
|
||||
{
|
||||
name: "midjourney",
|
||||
client: MidjourneyClient{},
|
||||
provider: "midjourney",
|
||||
specType: "midjourney",
|
||||
submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/diffusion" },
|
||||
submitResponse: `{"job_id":"remote-1"}`,
|
||||
pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/job/remote-1" },
|
||||
pollResponse: `{"status":"completed","output":{"image_urls":["https://cdn.example/mj.png"]}}`,
|
||||
authHeader: "Bearer test-key",
|
||||
resultURL: "https://cdn.example/mj.png",
|
||||
},
|
||||
{
|
||||
name: "vidu",
|
||||
client: ViduClient{},
|
||||
provider: "vidu",
|
||||
specType: "vidu",
|
||||
submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/text2video" },
|
||||
submitResponse: `{"task_id":"remote-1"}`,
|
||||
pollMatch: func(r *http.Request) bool {
|
||||
return r.Method == http.MethodGet && r.URL.Path == "/tasks/remote-1/creations"
|
||||
},
|
||||
pollResponse: `{"state":"success","video_url":"https://cdn.example/vidu.mp4"}`,
|
||||
authHeader: "Token test-key",
|
||||
resultURL: "https://cdn.example/vidu.mp4",
|
||||
},
|
||||
{
|
||||
name: "aliyun-bailian",
|
||||
client: AliyunBailianClient{},
|
||||
provider: "aliyun-bailian",
|
||||
specType: "aliyun-bailian",
|
||||
submitMatch: func(r *http.Request) bool {
|
||||
return r.Method == http.MethodPost && r.URL.Path == "/services/aigc/video-generation/video-synthesis"
|
||||
},
|
||||
submitResponse: `{"output":{"task_id":"remote-1"}}`,
|
||||
pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/tasks/remote-1" },
|
||||
pollResponse: `{"output":{"task_status":"SUCCEEDED","video_url":"https://cdn.example/aliyun.mp4"}}`,
|
||||
authHeader: "Bearer test-key",
|
||||
resultURL: "https://cdn.example/aliyun.mp4",
|
||||
},
|
||||
{
|
||||
name: "newapi",
|
||||
client: NewAPIClient{},
|
||||
provider: "newapi",
|
||||
specType: "newapi",
|
||||
submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/videos/generations" },
|
||||
submitResponse: `{"task_id":"remote-1"}`,
|
||||
pollMatch: func(r *http.Request) bool {
|
||||
return r.Method == http.MethodGet && r.URL.Path == "/videos/generations/remote-1"
|
||||
},
|
||||
pollResponse: `{"status":"SUCCESS","data":{"output":"https://cdn.example/newapi.mp4"}}`,
|
||||
authHeader: "Bearer test-key",
|
||||
resultURL: "https://cdn.example/newapi.mp4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if tc.authHeader == "test-key" {
|
||||
if r.Header.Get("x-key") != tc.authHeader {
|
||||
t.Fatalf("unexpected x-key header: %q", r.Header.Get("x-key"))
|
||||
}
|
||||
} else if r.Header.Get("Authorization") != tc.authHeader {
|
||||
t.Fatalf("unexpected auth header: %q", r.Header.Get("Authorization"))
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("x-request-id", "req-"+tc.name)
|
||||
switch {
|
||||
case tc.submitMatch(r):
|
||||
_, _ = w.Write([]byte(strings.ReplaceAll(tc.submitResponse, "__SERVER__", "http://"+r.Host)))
|
||||
case tc.pollMatch(r):
|
||||
_, _ = w.Write([]byte(tc.pollResponse))
|
||||
default:
|
||||
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String())
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
var submittedRemoteTaskID string
|
||||
request := Request{
|
||||
Kind: "videos.generations",
|
||||
ModelType: "video_generate",
|
||||
Model: "alias-model",
|
||||
Body: map[string]any{"model": "alias-model", "prompt": "hello"},
|
||||
Candidate: store.RuntimeModelCandidate{
|
||||
Provider: tc.provider,
|
||||
SpecType: tc.specType,
|
||||
BaseURL: server.URL,
|
||||
Credentials: map[string]any{"apiKey": "test-key"},
|
||||
PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
|
||||
ModelName: "alias-model",
|
||||
ProviderModelName: "provider-model",
|
||||
ModelType: "video_generate",
|
||||
ClientID: tc.name + ":test",
|
||||
},
|
||||
OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error {
|
||||
submittedRemoteTaskID = remoteTaskID
|
||||
if payload["payload"] == nil || payload["submit"] == nil {
|
||||
t.Fatalf("missing remote payload: %#v", payload)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
response, err := tc.client.Run(context.Background(), request)
|
||||
if err != nil {
|
||||
t.Fatalf("run failed: %v", err)
|
||||
}
|
||||
data, ok := response.Result["data"].([]any)
|
||||
if !ok || len(data) == 0 {
|
||||
t.Fatalf("missing data: %#v", response.Result)
|
||||
}
|
||||
first, _ := data[0].(map[string]any)
|
||||
if first["url"] != tc.resultURL {
|
||||
t.Fatalf("unexpected result url: %#v", response.Result)
|
||||
}
|
||||
if response.RequestID != "req-"+tc.name {
|
||||
t.Fatalf("unexpected request id: %q", response.RequestID)
|
||||
}
|
||||
if submittedRemoteTaskID == "" {
|
||||
t.Fatalf("expected remote task submission")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderTaskClientFailureAndRetryableErrors(t *testing.T) {
|
||||
t.Run("poll failure", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("x-request-id", "req-failed")
|
||||
switch r.URL.Path {
|
||||
case "/videos/generations":
|
||||
_, _ = w.Write([]byte(`{"task_id":"remote-1"}`))
|
||||
case "/videos/generations/remote-1":
|
||||
_, _ = w.Write([]byte(`{"status":"failed","code":"UPSTREAM_FAILED","message":"provider rejected"}`))
|
||||
default:
|
||||
t.Fatalf("unexpected request: %s", r.URL.String())
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := (NewAPIClient{}).Run(context.Background(), Request{
|
||||
Kind: "videos.generations",
|
||||
ModelType: "video_generate",
|
||||
Model: "alias-model",
|
||||
Body: map[string]any{"model": "alias-model", "prompt": "hello"},
|
||||
Candidate: store.RuntimeModelCandidate{
|
||||
Provider: "newapi",
|
||||
SpecType: "newapi",
|
||||
BaseURL: server.URL,
|
||||
Credentials: map[string]any{"apiKey": "test-key"},
|
||||
PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
|
||||
ModelName: "alias-model",
|
||||
ProviderModelName: "provider-model",
|
||||
ModelType: "video_generate",
|
||||
},
|
||||
})
|
||||
var clientErr *ClientError
|
||||
if !errors.As(err, &clientErr) || clientErr.Code != "UPSTREAM_FAILED" || clientErr.Retryable {
|
||||
t.Fatalf("expected non-retryable upstream failure, got %#v", err)
|
||||
}
|
||||
if clientErr.RequestID != "req-failed" {
|
||||
t.Fatalf("unexpected request id: %q", clientErr.RequestID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("submit rate limit", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("x-request-id", "req-rate-limit")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":{"message":"slow down"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := (NewAPIClient{}).Run(context.Background(), Request{
|
||||
Kind: "videos.generations",
|
||||
ModelType: "video_generate",
|
||||
Model: "alias-model",
|
||||
Body: map[string]any{"model": "alias-model", "prompt": "hello"},
|
||||
Candidate: store.RuntimeModelCandidate{
|
||||
Provider: "newapi",
|
||||
SpecType: "newapi",
|
||||
BaseURL: server.URL,
|
||||
Credentials: map[string]any{"apiKey": "test-key"},
|
||||
PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
|
||||
ModelName: "alias-model",
|
||||
ProviderModelName: "provider-model",
|
||||
ModelType: "video_generate",
|
||||
},
|
||||
})
|
||||
var clientErr *ClientError
|
||||
if !errors.As(err, &clientErr) || !clientErr.Retryable || clientErr.RequestID != "req-rate-limit" {
|
||||
t.Fatalf("expected retryable rate limit with request id, got %#v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestProviderTaskClientResumeSkipsSubmit(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodPost {
|
||||
t.Fatal("submit should not run for resumed remote task")
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"status":"success","data":{"output":"https://cdn.example/resume.mp4"}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
response, err := (NewAPIClient{}).Run(context.Background(), Request{
|
||||
Kind: "videos.generations",
|
||||
ModelType: "video_generate",
|
||||
Model: "alias-model",
|
||||
Body: map[string]any{"model": "alias-model", "prompt": "hello"},
|
||||
RemoteTaskID: "remote-1",
|
||||
RemoteTaskPayload: map[string]any{"payload": map[string]any{"prompt": "old"}},
|
||||
Candidate: store.RuntimeModelCandidate{
|
||||
Provider: "newapi",
|
||||
SpecType: "newapi",
|
||||
BaseURL: server.URL,
|
||||
Credentials: map[string]any{"apiKey": "test-key"},
|
||||
PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
|
||||
ModelName: "alias-model",
|
||||
ProviderModelName: "provider-model",
|
||||
ModelType: "video_generate",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("run failed: %v", err)
|
||||
}
|
||||
data := response.Result["data"].([]any)
|
||||
first := data[0].(map[string]any)
|
||||
if first["url"] != "https://cdn.example/resume.mp4" || response.Result["upstream_task_id"] != "remote-1" {
|
||||
t.Fatalf("unexpected response: %#v", response.Result)
|
||||
}
|
||||
}
|
||||
481
apps/api/internal/clients/universal.go
Normal file
481
apps/api/internal/clients/universal.go
Normal file
@ -0,0 +1,481 @@
|
||||
package clients
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
|
||||
)
|
||||
|
||||
type UniversalClient struct {
|
||||
HTTPClient *http.Client
|
||||
ScriptExecutor *scriptengine.Executor
|
||||
}
|
||||
|
||||
func (c UniversalClient) Run(ctx context.Context, request Request) (Response, error) {
|
||||
executor := c.ScriptExecutor
|
||||
if executor == nil {
|
||||
executor = &scriptengine.Executor{}
|
||||
}
|
||||
startedAt := time.Now()
|
||||
modelType := strings.TrimSpace(request.ModelType)
|
||||
if modelType == "" {
|
||||
modelType = strings.TrimSpace(request.Candidate.ModelType)
|
||||
}
|
||||
|
||||
payload := cloneBody(request.Body)
|
||||
upstreamTaskID := strings.TrimSpace(request.RemoteTaskID)
|
||||
submitRequestID := upstreamTaskID
|
||||
var submitResult map[string]any
|
||||
|
||||
if upstreamTaskID == "" {
|
||||
var err error
|
||||
payload, err = c.universalGetParams(ctx, executor, request, modelType)
|
||||
if err != nil {
|
||||
return Response{}, err
|
||||
}
|
||||
submitResult, submitRequestID, err = c.universalSubmit(ctx, executor, request, modelType, payload)
|
||||
if err != nil {
|
||||
return Response{}, annotateResponseError(err, submitRequestID, startedAt, time.Now())
|
||||
}
|
||||
if isUniversalSuccess(submitResult) && submitResult["data"] != nil {
|
||||
return Response{
|
||||
Result: normalizeUniversalResult(request, submitResult, ""),
|
||||
RequestID: firstNonEmptyString(submitRequestID, requestIDFromResult(submitResult)),
|
||||
Progress: providerProgress(request),
|
||||
ResponseStartedAt: startedAt,
|
||||
ResponseFinishedAt: time.Now(),
|
||||
ResponseDurationMS: responseDurationMS(startedAt, time.Now()),
|
||||
}, nil
|
||||
}
|
||||
if isUniversalFailure(submitResult) {
|
||||
return Response{}, universalFailureError(submitResult, firstNonEmptyString(submitRequestID, requestIDFromResult(submitResult)), startedAt)
|
||||
}
|
||||
upstreamTaskID = universalTaskID(submitResult)
|
||||
if upstreamTaskID == "" {
|
||||
return Response{}, &ClientError{Code: "invalid_response", Message: "universal task id is missing", RequestID: submitRequestID, Retryable: false}
|
||||
}
|
||||
if request.OnRemoteTaskSubmitted != nil {
|
||||
if err := request.OnRemoteTaskSubmitted(upstreamTaskID, map[string]any{"payload": payload, "submit": submitResult}); err != nil {
|
||||
return Response{}, err
|
||||
}
|
||||
}
|
||||
} else if request.RemoteTaskPayload != nil {
|
||||
if existingPayload, ok := request.RemoteTaskPayload["payload"].(map[string]any); ok {
|
||||
payload = existingPayload
|
||||
}
|
||||
}
|
||||
|
||||
result, requestID, err := c.universalPollUntilDone(ctx, executor, request, modelType, upstreamTaskID, payload, firstNonEmptyString(submitRequestID, upstreamTaskID), startedAt)
|
||||
if err != nil {
|
||||
return Response{}, err
|
||||
}
|
||||
finishedAt := time.Now()
|
||||
return Response{
|
||||
Result: normalizeUniversalResult(request, result, upstreamTaskID),
|
||||
RequestID: firstNonEmptyString(requestID, submitRequestID, requestIDFromResult(result), upstreamTaskID),
|
||||
Progress: universalProgress(request, upstreamTaskID),
|
||||
ResponseStartedAt: startedAt,
|
||||
ResponseFinishedAt: finishedAt,
|
||||
ResponseDurationMS: responseDurationMS(startedAt, finishedAt),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c UniversalClient) universalGetParams(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string) (map[string]any, error) {
|
||||
if scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customGetParamsScript", "custom_get_params_script"); scriptText != "" {
|
||||
scriptContext := universalScriptContext(request, modelType, nil)
|
||||
out, err := executor.Execute(ctx, scriptengine.Options{
|
||||
Script: scriptText,
|
||||
Args: []any{cloneBody(request.Body), scriptContext},
|
||||
ContextData: scriptContext,
|
||||
ScriptName: "custom_get_params_script:" + modelType,
|
||||
PreferredEntryNames: []string{"getGenerateParams", "getParams", "main", "handler"},
|
||||
Timeout: 30 * time.Second,
|
||||
HTTPClient: httpClient(request.HTTPClient, c.HTTPClient),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, universalScriptError(err)
|
||||
}
|
||||
if params, ok := out.(map[string]any); ok && params != nil {
|
||||
if params["_originalParams"] == nil {
|
||||
params["_originalParams"] = cloneBody(request.Body)
|
||||
}
|
||||
return params, nil
|
||||
}
|
||||
return nil, &ClientError{Code: "invalid_response", Message: "custom get params script must return an object", Retryable: false}
|
||||
}
|
||||
body := universalDefaultPayload(request)
|
||||
body["_originalParams"] = cloneBody(request.Body)
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (c UniversalClient) universalSubmit(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, payload map[string]any) (map[string]any, string, error) {
|
||||
if scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customSubmitScript", "custom_submit_script"); scriptText != "" {
|
||||
scriptContext := universalScriptContext(request, modelType, payload)
|
||||
out, err := executor.Execute(ctx, scriptengine.Options{
|
||||
Script: scriptText,
|
||||
Args: []any{cloneBody(payload), scriptContext},
|
||||
ContextData: scriptContext,
|
||||
ScriptName: "custom_submit_script:" + modelType,
|
||||
PreferredEntryNames: []string{"submitTask", "submitParams", "submit", "main", "handler"},
|
||||
Timeout: 30 * time.Second,
|
||||
HTTPClient: httpClient(request.HTTPClient, c.HTTPClient),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", universalScriptError(err)
|
||||
}
|
||||
result, ok := out.(map[string]any)
|
||||
if !ok || result == nil {
|
||||
return nil, "", &ClientError{Code: "invalid_response", Message: "custom submit script must return an object", Retryable: false}
|
||||
}
|
||||
return result, requestIDFromResult(result), nil
|
||||
}
|
||||
endpoint := universalSubmitEndpoint(request)
|
||||
result, requestID, err := universalPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), request.Candidate.BaseURL, endpoint, universalStripPrivatePayload(payload), request.Candidate.Credentials)
|
||||
return result, requestID, err
|
||||
}
|
||||
|
||||
func (c UniversalClient) universalPollUntilDone(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, upstreamTaskID string, payload map[string]any, requestID string, startedAt time.Time) (map[string]any, string, error) {
|
||||
interval := universalDurationConfig(request.Candidate.PlatformConfig, 2*time.Second, "pollIntervalMs", "poll_interval_ms")
|
||||
timeout := universalDurationConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms", "timeoutMs")
|
||||
deadline := time.NewTimer(timeout)
|
||||
defer deadline.Stop()
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastResult map[string]any
|
||||
for {
|
||||
pollStarted := time.Now()
|
||||
result, pollRequestID, err := c.universalPoll(ctx, executor, request, modelType, upstreamTaskID, payload)
|
||||
pollFinished := time.Now()
|
||||
if err != nil {
|
||||
return nil, "", annotateResponseError(err, firstNonEmptyString(pollRequestID, requestID, upstreamTaskID), pollStarted, pollFinished)
|
||||
}
|
||||
lastResult = result
|
||||
requestID = firstNonEmptyString(pollRequestID, requestID, requestIDFromResult(result), upstreamTaskID)
|
||||
if isUniversalSuccess(result) {
|
||||
return result, requestID, nil
|
||||
}
|
||||
if isUniversalFailure(result) {
|
||||
return nil, "", universalFailureError(result, requestID, startedAt)
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, "", &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: requestID, Retryable: true}
|
||||
case <-deadline.C:
|
||||
return nil, "", &ClientError{Code: "timeout", Message: fmt.Sprintf("universal task %s did not finish before timeout; last status: %s", upstreamTaskID, universalStatus(lastResult)), RequestID: requestID, Retryable: true}
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c UniversalClient) universalPoll(ctx context.Context, executor *scriptengine.Executor, request Request, modelType string, upstreamTaskID string, payload map[string]any) (map[string]any, string, error) {
|
||||
if scriptText := universalSceneScript(request.Candidate.PlatformConfig, modelType, "customPollScript", "custom_poll_script"); scriptText != "" {
|
||||
scriptContext := universalScriptContext(request, modelType, payload)
|
||||
out, err := executor.Execute(ctx, scriptengine.Options{
|
||||
Script: scriptText,
|
||||
Args: []any{upstreamTaskID, scriptContext},
|
||||
ContextData: scriptContext,
|
||||
ScriptName: "custom_poll_script:" + modelType,
|
||||
PreferredEntryNames: []string{"pollTask", "poll", "main", "handler"},
|
||||
Timeout: 30 * time.Second,
|
||||
HTTPClient: httpClient(request.HTTPClient, c.HTTPClient),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, "", universalScriptError(err)
|
||||
}
|
||||
result, ok := out.(map[string]any)
|
||||
if !ok || result == nil {
|
||||
return nil, "", &ClientError{Code: "invalid_response", Message: "custom poll script must return an object", Retryable: false}
|
||||
}
|
||||
return result, requestIDFromResult(result), nil
|
||||
}
|
||||
pollURL := resolveUniversalTaskURL(request.Candidate.PlatformConfig, upstreamTaskID)
|
||||
if pollURL == "" {
|
||||
return nil, "", &ClientError{Code: "missing_configuration", Message: "universal getTaskURL is required", Retryable: false}
|
||||
}
|
||||
return universalGetJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), pollURL, request.Candidate.Credentials)
|
||||
}
|
||||
|
||||
func universalScriptContext(request Request, modelType string, payload map[string]any) map[string]any {
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(request.Candidate.BaseURL), "/")
|
||||
getTaskURL := universalConfigString(request.Candidate.PlatformConfig, "getTaskURL", "get_task_url")
|
||||
context := map[string]any{
|
||||
"__easyaiScriptContext": true,
|
||||
"baseURL": baseURL,
|
||||
"getTaskURL": getTaskURL,
|
||||
"authValues": cloneMapAny(request.Candidate.Credentials),
|
||||
"headers": map[string]any{},
|
||||
"payload": cloneMapAny(payload),
|
||||
"type": modelType,
|
||||
"options": map[string]any{
|
||||
"task_id": request.RemoteTaskID,
|
||||
"upstream_task_id": request.RemoteTaskID,
|
||||
"model": request.Model,
|
||||
"providerModelName": request.Candidate.ProviderModelName,
|
||||
"platformId": request.Candidate.PlatformID,
|
||||
"platformModelId": request.Candidate.PlatformModelID,
|
||||
"canonicalModelKey": request.Candidate.CanonicalModelKey,
|
||||
"modelType": modelType,
|
||||
"timeout": universalDurationConfig(request.Candidate.PlatformConfig, 10*time.Minute, "pollTimeoutMs", "poll_timeout_ms").Milliseconds(),
|
||||
},
|
||||
"env": cloneMapAny(request.Candidate.PlatformConfig),
|
||||
"candidate": universalCandidateSnapshot(request),
|
||||
}
|
||||
context["createRequestURL"] = func(path string, base ...string) string {
|
||||
selectedBase := baseURL
|
||||
if len(base) > 0 && strings.TrimSpace(base[0]) != "" {
|
||||
selectedBase = strings.TrimRight(strings.TrimSpace(base[0]), "/")
|
||||
}
|
||||
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
||||
return path
|
||||
}
|
||||
return selectedBase + "/" + strings.TrimLeft(path, "/")
|
||||
}
|
||||
context["creatRequestURL"] = context["createRequestURL"]
|
||||
context["resolveGetTaskURL"] = func(taskID string) string {
|
||||
return resolveUniversalTaskURL(request.Candidate.PlatformConfig, taskID)
|
||||
}
|
||||
return context
|
||||
}
|
||||
|
||||
func universalCandidateSnapshot(request Request) map[string]any {
|
||||
return map[string]any{
|
||||
"modelName": request.Candidate.ModelName,
|
||||
"modelAlias": request.Candidate.ModelAlias,
|
||||
"providerModelName": request.Candidate.ProviderModelName,
|
||||
"provider": request.Candidate.Provider,
|
||||
"platformId": request.Candidate.PlatformID,
|
||||
"platformModelId": request.Candidate.PlatformModelID,
|
||||
"capabilities": cloneMapAny(request.Candidate.Capabilities),
|
||||
}
|
||||
}
|
||||
|
||||
func universalDefaultPayload(request Request) map[string]any {
|
||||
body := cloneBody(request.Body)
|
||||
body["model"] = upstreamModelName(request.Candidate)
|
||||
if request.Kind == "images.generations" {
|
||||
if n := firstPresent(body["n"], body["numImages"]); n != nil {
|
||||
body["numImages"] = n
|
||||
}
|
||||
if aspectRatio := strings.TrimSpace(stringFromAny(body["aspect_ratio"])); aspectRatio != "" {
|
||||
body["aspectRatio"] = aspectRatio
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func universalSubmitEndpoint(request Request) string {
|
||||
if endpoint := universalConfigString(request.Candidate.PlatformConfig, "submitPath", "submit_path"); endpoint != "" {
|
||||
return endpoint
|
||||
}
|
||||
switch request.Kind {
|
||||
case "images.generations":
|
||||
return "/images/generations"
|
||||
case "images.edits":
|
||||
return "/images/edits"
|
||||
case "videos.generations":
|
||||
return "/video/generations"
|
||||
default:
|
||||
return "/" + strings.ReplaceAll(request.Kind, ".", "/")
|
||||
}
|
||||
}
|
||||
|
||||
func universalPostJSON(ctx context.Context, client *http.Client, baseURL string, endpoint string, body map[string]any, credentials map[string]any) (map[string]any, string, error) {
|
||||
raw, _ := json.Marshal(body)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, providerURL(baseURL, endpoint), bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if apiKey := credential(credentials, "apiKey", "api_key", "key", "token"); apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||
}
|
||||
requestID := requestIDFromHTTPResponse(resp)
|
||||
result, err := decodeHTTPResponse(resp)
|
||||
return result, requestID, err
|
||||
}
|
||||
|
||||
func universalGetJSON(ctx context.Context, client *http.Client, url string, credentials map[string]any) (map[string]any, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if apiKey := credential(credentials, "apiKey", "api_key", "key", "token"); apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
|
||||
}
|
||||
requestID := requestIDFromHTTPResponse(resp)
|
||||
result, err := decodeHTTPResponse(resp)
|
||||
return result, requestID, err
|
||||
}
|
||||
|
||||
func normalizeUniversalResult(request Request, result map[string]any, upstreamTaskID string) map[string]any {
|
||||
out := cloneMapAny(result)
|
||||
if out["created"] == nil {
|
||||
out["created"] = time.Now().UnixMilli()
|
||||
}
|
||||
if out["task_id"] == nil {
|
||||
out["task_id"] = upstreamTaskID
|
||||
}
|
||||
if out["upstream_task_id"] == nil {
|
||||
out["upstream_task_id"] = upstreamTaskID
|
||||
}
|
||||
if out["model"] == nil {
|
||||
out["model"] = request.Model
|
||||
}
|
||||
if out["status"] == nil {
|
||||
out["status"] = "success"
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func universalScriptError(err error) error {
|
||||
var scriptErr *scriptengine.Error
|
||||
if strings.TrimSpace(err.Error()) == "" {
|
||||
return &ClientError{Code: "script_error", Message: "script execution failed", Retryable: false}
|
||||
}
|
||||
if errors.As(err, &scriptErr) {
|
||||
return &ClientError{Code: scriptErr.ErrorCode(), Message: scriptErr.Error(), Retryable: scriptErr.ErrorCode() == "script_timeout"}
|
||||
}
|
||||
return &ClientError{Code: "script_error", Message: err.Error(), Retryable: false}
|
||||
}
|
||||
|
||||
func universalFailureError(result map[string]any, requestID string, startedAt time.Time) error {
|
||||
message := firstNonEmptyString(result["message"], result["error"], result["error_message"], "universal task failed")
|
||||
return &ClientError{
|
||||
Code: firstNonEmptyString(result["code"], result["error_code"], "provider_failed"),
|
||||
Message: message,
|
||||
RequestID: requestID,
|
||||
ResponseStartedAt: startedAt,
|
||||
ResponseFinishedAt: time.Now(),
|
||||
ResponseDurationMS: responseDurationMS(startedAt, time.Now()),
|
||||
Retryable: false,
|
||||
}
|
||||
}
|
||||
|
||||
func isUniversalSuccess(result map[string]any) bool {
|
||||
switch universalStatus(result) {
|
||||
case "success", "succeeded", "completed", "complete", "done":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isUniversalFailure(result map[string]any) bool {
|
||||
switch universalStatus(result) {
|
||||
case "failed", "failure", "error", "cancelled", "canceled":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func universalStatus(result map[string]any) string {
|
||||
return strings.ToLower(strings.TrimSpace(firstNonEmptyString(result["status"], result["state"], result["task_status"])))
|
||||
}
|
||||
|
||||
func universalTaskID(result map[string]any) string {
|
||||
return firstNonEmptyString(result["upstream_task_id"], result["task_id"], result["taskId"], result["id"])
|
||||
}
|
||||
|
||||
func universalProgress(request Request, upstreamTaskID string) []Progress {
|
||||
progress := providerProgress(request)
|
||||
progress = append(progress, Progress{Phase: "polling", Progress: 0.65, Message: "provider task polled", Payload: map[string]any{"upstreamTaskId": upstreamTaskID}})
|
||||
return progress
|
||||
}
|
||||
|
||||
func universalStripPrivatePayload(payload map[string]any) map[string]any {
|
||||
out := cloneMapAny(payload)
|
||||
for _, key := range []string{"_originalParams", "_resolution", "_duration"} {
|
||||
delete(out, key)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func universalSceneScript(config map[string]any, modelType string, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
value := config[key]
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(typed) != "" {
|
||||
return strings.TrimSpace(typed)
|
||||
}
|
||||
case map[string]any:
|
||||
if script := firstNonEmptyString(typed[modelType], typed["common"]); script != "" {
|
||||
return script
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func universalConfigString(config map[string]any, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if value := strings.TrimSpace(fmt.Sprint(config[key])); value != "" && value != "<nil>" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func universalDurationConfig(config map[string]any, fallback time.Duration, keys ...string) time.Duration {
|
||||
for _, key := range keys {
|
||||
switch value := config[key].(type) {
|
||||
case int:
|
||||
if value > 0 {
|
||||
return time.Duration(value) * time.Millisecond
|
||||
}
|
||||
case int64:
|
||||
if value > 0 {
|
||||
return time.Duration(value) * time.Millisecond
|
||||
}
|
||||
case float64:
|
||||
if value > 0 {
|
||||
return time.Duration(value) * time.Millisecond
|
||||
}
|
||||
case string:
|
||||
if parsed, err := time.ParseDuration(value); err == nil && parsed > 0 {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func resolveUniversalTaskURL(config map[string]any, upstreamTaskID string) string {
|
||||
template := universalConfigString(config, "getTaskURL", "get_task_url")
|
||||
out := strings.TrimSpace(template)
|
||||
replacements := [][2]string{
|
||||
{"${upstream_task_id}", upstreamTaskID},
|
||||
{"{{upstream_task_id}}", upstreamTaskID},
|
||||
{"{upstream_task_id}", upstreamTaskID},
|
||||
{"${task_id}", upstreamTaskID},
|
||||
{"{{task_id}}", upstreamTaskID},
|
||||
{"{task_id}", upstreamTaskID},
|
||||
{"${taskId}", upstreamTaskID},
|
||||
{"${taskID}", upstreamTaskID},
|
||||
{"{{taskId}}", upstreamTaskID},
|
||||
{"{{taskID}}", upstreamTaskID},
|
||||
{"{taskId}", upstreamTaskID},
|
||||
{"{taskID}", upstreamTaskID},
|
||||
}
|
||||
for _, replacement := range replacements {
|
||||
out = strings.ReplaceAll(out, replacement[0], replacement[1])
|
||||
}
|
||||
return out
|
||||
}
|
||||
132
apps/api/internal/clients/universal_test.go
Normal file
132
apps/api/internal/clients/universal_test.go
Normal file
@ -0,0 +1,132 @@
|
||||
package clients
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
func TestUniversalClientRunsCustomScripts(t *testing.T) {
|
||||
request := Request{
|
||||
Kind: "videos.generations",
|
||||
ModelType: "video_generate",
|
||||
Model: "custom-video",
|
||||
Body: map[string]any{"model": "custom-video", "prompt": "hello"},
|
||||
Candidate: testUniversalCandidate(map[string]any{
|
||||
"customGetParamsScript": map[string]any{
|
||||
"video_generate": `async function getGenerateParams(params, context) {
|
||||
return { prompt: params.prompt + "-payload", model: context.candidate.providerModelName };
|
||||
}`,
|
||||
},
|
||||
"customSubmitScript": map[string]any{
|
||||
"video_generate": `async function submitTask(payload) {
|
||||
return { status: "submitted", task_id: "task-" + payload.prompt };
|
||||
}`,
|
||||
},
|
||||
"customPollScript": map[string]any{
|
||||
"video_generate": `async function pollTask(taskId) {
|
||||
return { status: "success", upstream_task_id: taskId, data: [{ url: "https://cdn.example/video.mp4" }] };
|
||||
}`,
|
||||
},
|
||||
"pollIntervalMs": 1,
|
||||
"pollTimeoutMs": 1000,
|
||||
}),
|
||||
}
|
||||
var submitted string
|
||||
request.OnRemoteTaskSubmitted = func(remoteTaskID string, payload map[string]any) error {
|
||||
submitted = remoteTaskID
|
||||
return nil
|
||||
}
|
||||
|
||||
response, err := (UniversalClient{}).Run(context.Background(), request)
|
||||
if err != nil {
|
||||
t.Fatalf("run failed: %v", err)
|
||||
}
|
||||
if submitted != "task-hello-payload" {
|
||||
t.Fatalf("unexpected remote task id: %q", submitted)
|
||||
}
|
||||
if response.Result["upstream_task_id"] != "task-hello-payload" {
|
||||
t.Fatalf("unexpected result: %#v", response.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUniversalClientDefaultSubmitAndPoll(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "Bearer test-key" {
|
||||
t.Fatalf("missing authorization header")
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.URL.Path {
|
||||
case "/video/generations":
|
||||
_, _ = w.Write([]byte(`{"status":"submitted","task_id":"remote-1"}`))
|
||||
case "/tasks/remote-1":
|
||||
_, _ = w.Write([]byte(`{"status":"success","data":[{"url":"https://cdn.example/default.mp4"}]}`))
|
||||
default:
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
request := Request{
|
||||
Kind: "videos.generations",
|
||||
ModelType: "video_generate",
|
||||
Model: "default-video",
|
||||
Body: map[string]any{"model": "default-video", "prompt": "hello"},
|
||||
Candidate: testUniversalCandidate(map[string]any{
|
||||
"getTaskURL": server.URL + "/tasks/{{task_id}}",
|
||||
"pollIntervalMs": 1,
|
||||
"pollTimeoutMs": 1000,
|
||||
}),
|
||||
}
|
||||
request.Candidate.BaseURL = server.URL
|
||||
|
||||
response, err := (UniversalClient{}).Run(context.Background(), request)
|
||||
if err != nil {
|
||||
t.Fatalf("run failed: %v", err)
|
||||
}
|
||||
if response.Result["upstream_task_id"] != "remote-1" {
|
||||
t.Fatalf("unexpected result: %#v", response.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUniversalClientResumeSkipsSubmit(t *testing.T) {
|
||||
request := Request{
|
||||
Kind: "videos.generations",
|
||||
ModelType: "video_generate",
|
||||
Model: "resume-video",
|
||||
Body: map[string]any{"model": "resume-video", "prompt": "hello"},
|
||||
RemoteTaskID: "existing-1",
|
||||
RemoteTaskPayload: map[string]any{"payload": map[string]any{"prompt": "old"}},
|
||||
Candidate: testUniversalCandidate(map[string]any{
|
||||
"customSubmitScript": `async function submitTask() { throw new Error("submit should not run"); }`,
|
||||
"customPollScript": `async function pollTask(taskId) { return { status: "success", upstream_task_id: taskId, data: [{ url: "https://cdn.example/resume.mp4" }] }; }`,
|
||||
"pollIntervalMs": 1,
|
||||
"pollTimeoutMs": 1000,
|
||||
}),
|
||||
}
|
||||
|
||||
response, err := (UniversalClient{}).Run(context.Background(), request)
|
||||
if err != nil {
|
||||
t.Fatalf("run failed: %v", err)
|
||||
}
|
||||
if response.Result["upstream_task_id"] != "existing-1" {
|
||||
t.Fatalf("unexpected result: %#v", response.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func testUniversalCandidate(config map[string]any) store.RuntimeModelCandidate {
|
||||
return store.RuntimeModelCandidate{
|
||||
Provider: "universal",
|
||||
SpecType: "universal",
|
||||
BaseURL: "https://provider.example",
|
||||
Credentials: map[string]any{"apiKey": "test-key"},
|
||||
PlatformConfig: config,
|
||||
ModelName: "alias-model",
|
||||
ProviderModelName: "provider-model",
|
||||
ModelType: "video_generate",
|
||||
ClientID: "universal:test",
|
||||
}
|
||||
}
|
||||
196
apps/api/internal/runner/param_processor_script.go
Normal file
196
apps/api/internal/runner/param_processor_script.go
Normal file
@ -0,0 +1,196 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
func (s *Service) preprocessRequestWithScripts(ctx context.Context, kind string, body map[string]any, candidate store.RuntimeModelCandidate) parameterPreprocessResult {
|
||||
if platformConfigBool(candidate.PlatformConfig, "skipParamNormalization", "skip_param_normalization") {
|
||||
modelType := strings.TrimSpace(candidate.ModelType)
|
||||
if modelType == "" {
|
||||
modelType = modelTypeFromKind(kind, body)
|
||||
}
|
||||
input := cloneMap(body)
|
||||
return parameterPreprocessResult{
|
||||
Body: cloneMap(body),
|
||||
Log: parameterPreprocessingLog{
|
||||
ModelType: modelType,
|
||||
Input: input,
|
||||
Output: cloneMap(body),
|
||||
Changed: false,
|
||||
Changes: []parameterPreprocessChange{},
|
||||
Model: preprocessingModelSnapshot(candidate),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
result := preprocessRequestWithLog(kind, body, candidate)
|
||||
if result.Err != nil {
|
||||
return result
|
||||
}
|
||||
scriptText := platformConfigString(candidate.PlatformConfig, "customPreprocessScript", "custom_preprocess_script")
|
||||
if strings.TrimSpace(scriptText) == "" || s.scriptExecutor == nil {
|
||||
return result
|
||||
}
|
||||
|
||||
before := cloneMap(result.Body)
|
||||
scriptContext := s.scriptContext(candidate, result.Log.ModelType, nil, map[string]any{
|
||||
"modelCapability": effectiveModelCapability(candidate),
|
||||
"platformModel": result.Log.Model,
|
||||
"platform": candidate.PlatformConfig,
|
||||
})
|
||||
out, err := s.scriptExecutor.Execute(ctx, scriptengine.Options{
|
||||
Script: scriptText,
|
||||
Args: []any{cloneMap(result.Body), result.Log.ModelType, scriptContext},
|
||||
ContextData: scriptContext,
|
||||
ScriptName: "custom_preprocess_script:" + result.Log.ModelType,
|
||||
PreferredEntryNames: []string{"preprocessParams", "preprocess", "main", "handler"},
|
||||
Timeout: scriptengine.PreprocessTimeout,
|
||||
})
|
||||
if err != nil {
|
||||
result.Log.recordScriptChange("CustomPreprocessScript", "error", "$", before, result.Body, err.Error())
|
||||
result.Log.Output = cloneMap(result.Body)
|
||||
result.Log.Changed = len(result.Log.Changes) > 0
|
||||
result.Err = err
|
||||
return result
|
||||
}
|
||||
rewritten, ok := out.(map[string]any)
|
||||
if !ok || rewritten == nil {
|
||||
result.Log.Output = cloneMap(result.Body)
|
||||
result.Log.Changed = len(result.Log.Changes) > 0
|
||||
return result
|
||||
}
|
||||
merged := cloneMap(result.Body)
|
||||
for key, value := range rewritten {
|
||||
merged[key] = value
|
||||
}
|
||||
if !mapsEqual(before, merged) {
|
||||
result.Log.recordScriptChange("CustomPreprocessScript", "rewrite", "$", before, merged, "platform custom preprocess script returned parameter updates")
|
||||
}
|
||||
result.Body = merged
|
||||
result.Log.Output = cloneMap(merged)
|
||||
result.Log.Changed = len(result.Log.Changes) > 0
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *Service) scriptContext(candidate store.RuntimeModelCandidate, modelType string, payload map[string]any, extra map[string]any) map[string]any {
|
||||
getTaskURL := platformConfigString(candidate.PlatformConfig, "getTaskURL", "get_task_url")
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(candidate.BaseURL), "/")
|
||||
env := cloneMap(candidate.PlatformConfig)
|
||||
context := map[string]any{
|
||||
"__easyaiScriptContext": true,
|
||||
"baseURL": baseURL,
|
||||
"getTaskURL": getTaskURL,
|
||||
"authValues": cloneMap(candidate.Credentials),
|
||||
"headers": map[string]any{},
|
||||
"payload": cloneMap(payload),
|
||||
"type": modelType,
|
||||
"options": map[string]any{
|
||||
"model": candidate.ModelName,
|
||||
"providerModelName": candidate.ProviderModelName,
|
||||
"platformId": candidate.PlatformID,
|
||||
"platformModelId": candidate.PlatformModelID,
|
||||
"canonicalModelKey": candidate.CanonicalModelKey,
|
||||
"sourceProviderCode": candidate.Provider,
|
||||
},
|
||||
"env": env,
|
||||
"candidate": preprocessingModelSnapshot(candidate),
|
||||
}
|
||||
context["createRequestURL"] = func(path string, base ...string) string {
|
||||
selectedBase := baseURL
|
||||
if len(base) > 0 && strings.TrimSpace(base[0]) != "" {
|
||||
selectedBase = strings.TrimRight(strings.TrimSpace(base[0]), "/")
|
||||
}
|
||||
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
||||
return path
|
||||
}
|
||||
return selectedBase + "/" + strings.TrimLeft(path, "/")
|
||||
}
|
||||
context["creatRequestURL"] = context["createRequestURL"]
|
||||
context["resolveGetTaskURL"] = func(taskID string) string {
|
||||
return resolveTaskURLTemplate(getTaskURL, taskID, "")
|
||||
}
|
||||
for key, value := range extra {
|
||||
context[key] = value
|
||||
}
|
||||
return context
|
||||
}
|
||||
|
||||
func preprocessingModelSnapshot(candidate store.RuntimeModelCandidate) map[string]any {
|
||||
return map[string]any{
|
||||
"modelName": candidate.ModelName,
|
||||
"modelAlias": candidate.ModelAlias,
|
||||
"providerModelName": candidate.ProviderModelName,
|
||||
"provider": candidate.Provider,
|
||||
"platformId": candidate.PlatformID,
|
||||
"platformModelId": candidate.PlatformModelID,
|
||||
"capabilities": cloneMap(candidate.Capabilities),
|
||||
}
|
||||
}
|
||||
|
||||
func (log *parameterPreprocessingLog) recordScriptChange(processor string, action string, path string, before any, after any, reason string) {
|
||||
if log == nil {
|
||||
return
|
||||
}
|
||||
log.Changes = append(log.Changes, parameterPreprocessChange{
|
||||
Processor: processor,
|
||||
Action: action,
|
||||
Path: path,
|
||||
Before: cloneAny(before),
|
||||
After: cloneAny(after),
|
||||
Reason: reason,
|
||||
})
|
||||
}
|
||||
|
||||
func platformConfigString(config map[string]any, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if value := strings.TrimSpace(fmt.Sprint(config[key])); value != "" && value != "<nil>" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func platformConfigBool(config map[string]any, keys ...string) bool {
|
||||
for _, key := range keys {
|
||||
switch value := config[key].(type) {
|
||||
case bool:
|
||||
return value
|
||||
case string:
|
||||
return strings.EqualFold(strings.TrimSpace(value), "true")
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func resolveTaskURLTemplate(template string, upstreamTaskID string, taskID string) string {
|
||||
out := strings.TrimSpace(template)
|
||||
replacements := [][2]string{
|
||||
{"${upstream_task_id}", upstreamTaskID},
|
||||
{"{{upstream_task_id}}", upstreamTaskID},
|
||||
{"{upstream_task_id}", upstreamTaskID},
|
||||
{"${task_id}", taskID},
|
||||
{"{{task_id}}", taskID},
|
||||
{"{task_id}", taskID},
|
||||
{"${taskId}", upstreamTaskID},
|
||||
{"${taskID}", upstreamTaskID},
|
||||
{"{{taskId}}", upstreamTaskID},
|
||||
{"{{taskID}}", upstreamTaskID},
|
||||
{"{taskId}", upstreamTaskID},
|
||||
{"{taskID}", upstreamTaskID},
|
||||
}
|
||||
for _, replacement := range replacements {
|
||||
out = strings.ReplaceAll(out, replacement[0], replacement[1])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mapsEqual(left map[string]any, right map[string]any) bool {
|
||||
return reflect.DeepEqual(left, right)
|
||||
}
|
||||
64
apps/api/internal/runner/param_processor_script_test.go
Normal file
64
apps/api/internal/runner/param_processor_script_test.go
Normal file
@ -0,0 +1,64 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
func TestPreprocessRequestWithCustomScript(t *testing.T) {
|
||||
service := &Service{scriptExecutor: &scriptengine.Executor{}}
|
||||
candidate := store.RuntimeModelCandidate{
|
||||
Provider: "universal",
|
||||
ModelName: "image-model",
|
||||
ModelType: "image_generate",
|
||||
Capabilities: map[string]any{
|
||||
"image_generate": map[string]any{"max_output_images": 4},
|
||||
},
|
||||
PlatformConfig: map[string]any{
|
||||
"customPreprocessScript": `(params, type, context) => {
|
||||
return { prompt: params.prompt + "-" + type, n: 2, provider: context.candidate.provider };
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
result := service.preprocessRequestWithScripts(context.Background(), "images.generations", map[string]any{"prompt": "hello", "n": 8}, candidate)
|
||||
if result.Err != nil {
|
||||
t.Fatalf("unexpected preprocess error: %v", result.Err)
|
||||
}
|
||||
if result.Body["prompt"] != "hello-image_generate" || result.Body["n"].(float64) != 2 {
|
||||
t.Fatalf("unexpected body: %#v", result.Body)
|
||||
}
|
||||
if !result.Log.Changed || len(result.Log.Changes) == 0 {
|
||||
t.Fatalf("expected script change in log: %#v", result.Log)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreprocessRequestSkipParamNormalizationSkipsCustomScript(t *testing.T) {
|
||||
service := &Service{scriptExecutor: &scriptengine.Executor{}}
|
||||
candidate := store.RuntimeModelCandidate{
|
||||
ModelName: "image-model",
|
||||
ModelType: "image_generate",
|
||||
Provider: "universal",
|
||||
Capabilities: map[string]any{
|
||||
"image_generate": map[string]any{"max_output_images": 1},
|
||||
},
|
||||
PlatformConfig: map[string]any{
|
||||
"skipParamNormalization": true,
|
||||
"customPreprocessScript": `(params) => ({ prompt: "changed", n: 1 })`,
|
||||
},
|
||||
}
|
||||
|
||||
result := service.preprocessRequestWithScripts(context.Background(), "images.generations", map[string]any{"prompt": "hello", "n": 9}, candidate)
|
||||
if result.Err != nil {
|
||||
t.Fatalf("unexpected preprocess error: %v", result.Err)
|
||||
}
|
||||
if result.Body["prompt"] != "hello" || result.Body["n"].(int) != 9 {
|
||||
t.Fatalf("skip should keep raw body, got %#v", result.Body)
|
||||
}
|
||||
if result.Log.Changed || len(result.Log.Changes) != 0 {
|
||||
t.Fatalf("skip should not record changes: %#v", result.Log)
|
||||
}
|
||||
}
|
||||
@ -12,18 +12,20 @@ import (
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||
scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/riverqueue/river"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
cfg config.Config
|
||||
store *store.Store
|
||||
logger *slog.Logger
|
||||
clients map[string]clients.Client
|
||||
httpClients *httpClientCache
|
||||
riverClient *river.Client[pgx.Tx]
|
||||
cfg config.Config
|
||||
store *store.Store
|
||||
logger *slog.Logger
|
||||
clients map[string]clients.Client
|
||||
scriptExecutor *scriptengine.Executor
|
||||
httpClients *httpClientCache
|
||||
riverClient *river.Client[pgx.Tx]
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
@ -47,17 +49,29 @@ func (e *TaskQueuedError) Is(target error) bool {
|
||||
|
||||
func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service {
|
||||
httpClients := newHTTPClientCache()
|
||||
scriptExecutor := &scriptengine.Executor{Logger: logger}
|
||||
return &Service{
|
||||
cfg: cfg,
|
||||
store: db,
|
||||
logger: logger,
|
||||
cfg: cfg,
|
||||
store: db,
|
||||
logger: logger,
|
||||
scriptExecutor: scriptExecutor,
|
||||
clients: map[string]clients.Client{
|
||||
"openai": clients.OpenAIClient{HTTPClient: httpClients.none},
|
||||
"gemini": clients.GeminiClient{HTTPClient: httpClients.none},
|
||||
"volces": clients.VolcesClient{HTTPClient: httpClients.none},
|
||||
"keling": clients.KelingClient{HTTPClient: httpClients.none},
|
||||
"kling": clients.KelingClient{HTTPClient: httpClients.none},
|
||||
"simulation": clients.SimulationClient{},
|
||||
"openai": clients.OpenAIClient{HTTPClient: httpClients.none},
|
||||
"aliyun-bailian": clients.AliyunBailianClient{HTTPClient: httpClients.none},
|
||||
"blackforest": clients.BlackforestClient{HTTPClient: httpClients.none},
|
||||
"gemini": clients.GeminiClient{HTTPClient: httpClients.none},
|
||||
"jimeng": clients.JimengClient{HTTPClient: httpClients.none},
|
||||
"midjourney": clients.MidjourneyClient{HTTPClient: httpClients.none},
|
||||
"minimax": clients.MinimaxClient{HTTPClient: httpClients.none},
|
||||
"newapi": clients.NewAPIClient{HTTPClient: httpClients.none},
|
||||
"tencent-hunyuan-image": clients.HunyuanImageClient{HTTPClient: httpClients.none},
|
||||
"tencent-hunyuan-video": clients.HunyuanVideoClient{HTTPClient: httpClients.none},
|
||||
"vidu": clients.ViduClient{HTTPClient: httpClients.none},
|
||||
"volces": clients.VolcesClient{HTTPClient: httpClients.none},
|
||||
"keling": clients.KelingClient{HTTPClient: httpClients.none},
|
||||
"kling": clients.KelingClient{HTTPClient: httpClients.none},
|
||||
"universal": clients.UniversalClient{HTTPClient: httpClients.none, ScriptExecutor: scriptExecutor},
|
||||
"simulation": clients.SimulationClient{},
|
||||
},
|
||||
httpClients: httpClients,
|
||||
}
|
||||
@ -147,7 +161,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
||||
attemptNo := task.AttemptCount
|
||||
var firstPreprocessing parameterPreprocessingLog
|
||||
if len(candidates) > 0 {
|
||||
preprocessing := preprocessRequestWithLog(task.Kind, body, candidates[0])
|
||||
preprocessing := s.preprocessRequestWithScripts(ctx, task.Kind, body, candidates[0])
|
||||
firstCandidateBody = preprocessing.Body
|
||||
firstPreprocessing = preprocessing.Log
|
||||
normalizedModelType = candidates[0].ModelType
|
||||
@ -225,7 +239,7 @@ candidatesLoop:
|
||||
var candidateErr error
|
||||
for clientAttempt := 1; clientAttempt <= clientAttempts; clientAttempt++ {
|
||||
nextAttemptNo := attemptNo + 1
|
||||
preprocessing := preprocessRequestWithLog(task.Kind, body, candidate)
|
||||
preprocessing := s.preprocessRequestWithScripts(ctx, task.Kind, body, candidate)
|
||||
preprocessingLog := preprocessing.Log
|
||||
lastPreprocessing = &preprocessingLog
|
||||
if preprocessing.Err != nil {
|
||||
@ -1090,8 +1104,13 @@ func parameterPreprocessClientError(err error) *clients.ClientError {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
code := "invalid_parameter"
|
||||
var coded interface{ ErrorCode() string }
|
||||
if errors.As(err, &coded) && strings.TrimSpace(coded.ErrorCode()) != "" {
|
||||
code = coded.ErrorCode()
|
||||
}
|
||||
return &clients.ClientError{
|
||||
Code: "invalid_parameter",
|
||||
Code: code,
|
||||
Message: err.Error(),
|
||||
StatusCode: 400,
|
||||
Retryable: false,
|
||||
|
||||
530
apps/api/internal/script/executor.go
Normal file
530
apps/api/internal/script/executor.go
Normal file
@ -0,0 +1,530 @@
|
||||
package script
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
"github.com/dop251/goja_nodejs/eventloop"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultTimeout = 30 * time.Second
|
||||
PreprocessTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type Logger interface {
|
||||
Debug(msg string, args ...any)
|
||||
Info(msg string, args ...any)
|
||||
Warn(msg string, args ...any)
|
||||
Error(msg string, args ...any)
|
||||
}
|
||||
|
||||
type Executor struct {
|
||||
HTTPClient *http.Client
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Script string
|
||||
Args []any
|
||||
ContextData map[string]any
|
||||
ScriptName string
|
||||
PreferredEntryNames []string
|
||||
Timeout time.Duration
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Code string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if strings.TrimSpace(e.Message) != "" {
|
||||
return e.Message
|
||||
}
|
||||
return e.Code
|
||||
}
|
||||
|
||||
func (e *Error) ErrorCode() string {
|
||||
if e == nil || strings.TrimSpace(e.Code) == "" {
|
||||
return "script_error"
|
||||
}
|
||||
return e.Code
|
||||
}
|
||||
|
||||
type result struct {
|
||||
value any
|
||||
err error
|
||||
}
|
||||
|
||||
var (
|
||||
functionDeclarationPattern = regexp.MustCompile(`(?:^|\n)\s*(?:async\s+)?function\s+([A-Za-z_$][\w$]*)\s*\(`)
|
||||
assignedFunctionPattern = regexp.MustCompile(`(?:^|\n)\s*(?:const|let|var)\s+([A-Za-z_$][\w$]*)\s*=\s*(?:async\s*)?(?:function\b|\([^)]*\)\s*=>|[A-Za-z_$][\w$]*\s*=>)`)
|
||||
)
|
||||
|
||||
func (e Executor) Execute(ctx context.Context, opts Options) (any, error) {
|
||||
scriptText := strings.TrimSpace(opts.Script)
|
||||
if scriptText == "" {
|
||||
return nil, &Error{Code: "script_empty", Message: "script is empty"}
|
||||
}
|
||||
scriptName := strings.TrimSpace(opts.ScriptName)
|
||||
if scriptName == "" {
|
||||
scriptName = "script"
|
||||
}
|
||||
timeout := opts.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = DefaultTimeout
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
loop := eventloop.NewEventLoop(eventloop.EnableConsole(false))
|
||||
loop.Start()
|
||||
defer loop.Terminate()
|
||||
|
||||
resultCh := make(chan result, 1)
|
||||
var once sync.Once
|
||||
finish := func(value any, err error) {
|
||||
once.Do(func() {
|
||||
resultCh <- result{value: value, err: err}
|
||||
loop.StopNoWait()
|
||||
})
|
||||
}
|
||||
|
||||
ok := loop.RunOnLoop(func(vm *goja.Runtime) {
|
||||
e.installRuntime(ctx, loop, vm, opts.HTTPClient, scriptName)
|
||||
for key, value := range opts.ContextData {
|
||||
_ = vm.Set(key, value)
|
||||
}
|
||||
value, err := e.invoke(vm, scriptText, opts.Args, opts.PreferredEntryNames, scriptName)
|
||||
if err != nil {
|
||||
finish(nil, err)
|
||||
return
|
||||
}
|
||||
e.resolveValue(vm, value, finish)
|
||||
})
|
||||
if !ok {
|
||||
return nil, &Error{Code: "script_runtime_error", Message: "script event loop is not available"}
|
||||
}
|
||||
|
||||
select {
|
||||
case out := <-resultCh:
|
||||
if out.err != nil {
|
||||
return nil, out.err
|
||||
}
|
||||
return normalizeExport(out.value), nil
|
||||
case <-ctx.Done():
|
||||
loop.Terminate()
|
||||
code := "script_timeout"
|
||||
if errors.Is(ctx.Err(), context.Canceled) {
|
||||
code = "script_cancelled"
|
||||
}
|
||||
return nil, &Error{Code: code, Message: fmt.Sprintf("%s exceeded %s", scriptName, timeout)}
|
||||
}
|
||||
}
|
||||
|
||||
func (e Executor) invoke(vm *goja.Runtime, scriptText string, args []any, preferred []string, scriptName string) (goja.Value, error) {
|
||||
if fnValue, err := vm.RunString("(" + scriptText + ")"); err == nil {
|
||||
if fn, ok := goja.AssertFunction(fnValue); ok {
|
||||
return fn(goja.Undefined(), values(vm, args)...)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := vm.RunString(scriptText); err != nil {
|
||||
return nil, &Error{Code: "script_compile_error", Message: err.Error()}
|
||||
}
|
||||
|
||||
for _, name := range entryCandidates(scriptText, preferred) {
|
||||
fnValue, err := vm.RunString(fmt.Sprintf("(typeof %s === 'function' ? %s : undefined)", name, name))
|
||||
if err != nil || goja.IsUndefined(fnValue) || goja.IsNull(fnValue) {
|
||||
continue
|
||||
}
|
||||
fn, ok := goja.AssertFunction(fnValue)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
return fn(goja.Undefined(), values(vm, args)...)
|
||||
}
|
||||
|
||||
return nil, &Error{Code: "script_entry_missing", Message: fmt.Sprintf("%s must expose an executable function", scriptName)}
|
||||
}
|
||||
|
||||
func (e Executor) resolveValue(vm *goja.Runtime, value goja.Value, finish func(any, error)) {
|
||||
if value == nil {
|
||||
finish(nil, nil)
|
||||
return
|
||||
}
|
||||
if promise, ok := value.Export().(*goja.Promise); ok {
|
||||
switch promise.State() {
|
||||
case goja.PromiseStateFulfilled:
|
||||
finish(exportValue(promise.Result()), nil)
|
||||
case goja.PromiseStateRejected:
|
||||
finish(nil, &Error{Code: "script_error", Message: stringify(promise.Result())})
|
||||
default:
|
||||
obj := value.ToObject(vm)
|
||||
thenFn, ok := goja.AssertFunction(obj.Get("then"))
|
||||
if !ok {
|
||||
finish(nil, &Error{Code: "script_error", Message: "promise.then is not callable"})
|
||||
return
|
||||
}
|
||||
onResolve := func(call goja.FunctionCall) goja.Value {
|
||||
finish(exportValue(call.Argument(0)), nil)
|
||||
return goja.Undefined()
|
||||
}
|
||||
onReject := func(call goja.FunctionCall) goja.Value {
|
||||
finish(nil, &Error{Code: "script_error", Message: stringify(call.Argument(0))})
|
||||
return goja.Undefined()
|
||||
}
|
||||
_, _ = thenFn(obj, vm.ToValue(onResolve), vm.ToValue(onReject))
|
||||
}
|
||||
return
|
||||
}
|
||||
finish(exportValue(value), nil)
|
||||
}
|
||||
|
||||
func (e Executor) installRuntime(ctx context.Context, loop *eventloop.EventLoop, vm *goja.Runtime, client *http.Client, scriptName string) {
|
||||
vm.SetFieldNameMapper(goja.TagFieldNameMapper("json", true))
|
||||
e.installConsole(vm, scriptName)
|
||||
e.installHTTP(ctx, loop, vm, firstHTTPClient(client, e.HTTPClient), scriptName)
|
||||
_ = vm.Set("FormData", formDataConstructor(vm))
|
||||
_, _ = vm.RunString(`
|
||||
function __easyaiGotRequest(method, url, options) {
|
||||
return {
|
||||
json: function() { return __easyaiHTTP(method, url, options || {}).then(function(resp) { return resp.json(); }); },
|
||||
text: function() { return __easyaiHTTP(method, url, options || {}).then(function(resp) { return resp.text(); }); }
|
||||
};
|
||||
}
|
||||
var got = {
|
||||
get: function(url, options) { return __easyaiGotRequest("GET", url, options); },
|
||||
post: function(url, options) { return __easyaiGotRequest("POST", url, options); },
|
||||
put: function(url, options) { return __easyaiGotRequest("PUT", url, options); },
|
||||
patch: function(url, options) { return __easyaiGotRequest("PATCH", url, options); },
|
||||
delete: function(url, options) { return __easyaiGotRequest("DELETE", url, options); },
|
||||
extend: function() { return this; }
|
||||
};
|
||||
function fetch(url, options) {
|
||||
options = options || {};
|
||||
return __easyaiHTTP(options.method || "GET", url, options);
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
func (e Executor) installConsole(vm *goja.Runtime, scriptName string) {
|
||||
log := func(level string, args ...any) {
|
||||
if e.Logger == nil {
|
||||
return
|
||||
}
|
||||
values := make([]any, 0, len(args)+1)
|
||||
values = append(values, "script", scriptName)
|
||||
values = append(values, args...)
|
||||
switch level {
|
||||
case "error":
|
||||
e.Logger.Error("script console", values...)
|
||||
case "warn":
|
||||
e.Logger.Warn("script console", values...)
|
||||
case "info":
|
||||
e.Logger.Info("script console", values...)
|
||||
default:
|
||||
e.Logger.Debug("script console", values...)
|
||||
}
|
||||
}
|
||||
_ = vm.Set("console", map[string]any{
|
||||
"log": func(args ...any) { log("debug", args...) },
|
||||
"debug": func(args ...any) { log("debug", args...) },
|
||||
"info": func(args ...any) { log("info", args...) },
|
||||
"warn": func(args ...any) { log("warn", args...) },
|
||||
"error": func(args ...any) { log("error", args...) },
|
||||
})
|
||||
}
|
||||
|
||||
func (e Executor) installHTTP(ctx context.Context, loop *eventloop.EventLoop, vm *goja.Runtime, client *http.Client, scriptName string) {
|
||||
_ = vm.Set("__easyaiHTTP", func(call goja.FunctionCall) goja.Value {
|
||||
method := strings.ToUpper(strings.TrimSpace(call.Argument(0).String()))
|
||||
if method == "" {
|
||||
method = http.MethodGet
|
||||
}
|
||||
url := strings.TrimSpace(call.Argument(1).String())
|
||||
options := exportMap(call.Argument(2))
|
||||
promise, resolve, reject := vm.NewPromise()
|
||||
go func() {
|
||||
response, err := doHTTPRequest(ctx, client, method, url, options)
|
||||
loop.RunOnLoop(func(runtime *goja.Runtime) {
|
||||
if err != nil {
|
||||
_ = reject(err.Error())
|
||||
return
|
||||
}
|
||||
_ = resolve(httpResponseObject(runtime, response))
|
||||
})
|
||||
}()
|
||||
return vm.ToValue(promise)
|
||||
})
|
||||
}
|
||||
|
||||
func doHTTPRequest(ctx context.Context, client *http.Client, method string, url string, options map[string]any) (httpScriptResponse, error) {
|
||||
if strings.TrimSpace(url) == "" {
|
||||
return httpScriptResponse{}, errors.New("url is required")
|
||||
}
|
||||
var body io.Reader
|
||||
headers := map[string]string{}
|
||||
if rawHeaders, ok := options["headers"].(map[string]any); ok {
|
||||
for key, value := range rawHeaders {
|
||||
if text := strings.TrimSpace(fmt.Sprint(value)); text != "" {
|
||||
headers[key] = text
|
||||
}
|
||||
}
|
||||
}
|
||||
if jsonBody, ok := options["json"]; ok {
|
||||
raw, err := json.Marshal(jsonBody)
|
||||
if err != nil {
|
||||
return httpScriptResponse{}, err
|
||||
}
|
||||
body = bytes.NewReader(raw)
|
||||
if _, ok := headers["Content-Type"]; !ok {
|
||||
headers["Content-Type"] = "application/json"
|
||||
}
|
||||
} else if rawBody, ok := options["body"]; ok {
|
||||
body, headers = requestBody(rawBody, headers)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
if err != nil {
|
||||
return httpScriptResponse{}, err
|
||||
}
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return httpScriptResponse{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024*1024))
|
||||
out := httpScriptResponse{
|
||||
Status: resp.Status,
|
||||
StatusCode: resp.StatusCode,
|
||||
OK: resp.StatusCode >= 200 && resp.StatusCode < 300,
|
||||
Headers: map[string]any{},
|
||||
Body: string(raw),
|
||||
}
|
||||
for key, values := range resp.Header {
|
||||
if len(values) == 1 {
|
||||
out.Headers[key] = values[0]
|
||||
} else {
|
||||
out.Headers[key] = values
|
||||
}
|
||||
}
|
||||
if len(raw) > 0 {
|
||||
var parsed any
|
||||
if json.Unmarshal(raw, &parsed) == nil {
|
||||
out.JSON = parsed
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type httpScriptResponse struct {
|
||||
Status string
|
||||
StatusCode int
|
||||
OK bool
|
||||
Headers map[string]any
|
||||
Body string
|
||||
JSON any
|
||||
}
|
||||
|
||||
func httpResponseObject(vm *goja.Runtime, response httpScriptResponse) map[string]any {
|
||||
return map[string]any{
|
||||
"status": response.StatusCode,
|
||||
"statusCode": response.StatusCode,
|
||||
"ok": response.OK,
|
||||
"headers": response.Headers,
|
||||
"text": func() string {
|
||||
return response.Body
|
||||
},
|
||||
"json": func() any {
|
||||
if response.JSON != nil {
|
||||
return response.JSON
|
||||
}
|
||||
var parsed any
|
||||
if json.Unmarshal([]byte(response.Body), &parsed) == nil {
|
||||
return parsed
|
||||
}
|
||||
panic(vm.NewTypeError("response body is not valid JSON"))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func requestBody(value any, headers map[string]string) (io.Reader, map[string]string) {
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return strings.NewReader(typed), headers
|
||||
case []byte:
|
||||
return bytes.NewReader(typed), headers
|
||||
case map[string]any:
|
||||
if fields, ok := typed["__easyaiFormData"].([]any); ok {
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
for _, rawField := range fields {
|
||||
field, ok := rawField.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
_ = writer.WriteField(strings.TrimSpace(fmt.Sprint(field["name"])), fmt.Sprint(field["value"]))
|
||||
}
|
||||
_ = writer.Close()
|
||||
headers["Content-Type"] = writer.FormDataContentType()
|
||||
return &buf, headers
|
||||
}
|
||||
raw, _ := json.Marshal(typed)
|
||||
headers["Content-Type"] = "application/json"
|
||||
return bytes.NewReader(raw), headers
|
||||
default:
|
||||
raw, _ := json.Marshal(typed)
|
||||
headers["Content-Type"] = "application/json"
|
||||
return bytes.NewReader(raw), headers
|
||||
}
|
||||
}
|
||||
|
||||
func formDataConstructor(vm *goja.Runtime) func(goja.ConstructorCall) *goja.Object {
|
||||
return func(call goja.ConstructorCall) *goja.Object {
|
||||
obj := call.This
|
||||
_ = obj.Set("__easyaiFormData", []any{})
|
||||
_ = obj.Set("append", func(name string, value any) {
|
||||
fields := exportSlice(obj.Get("__easyaiFormData"))
|
||||
fields = append(fields, map[string]any{"name": name, "value": value})
|
||||
_ = obj.Set("__easyaiFormData", fields)
|
||||
})
|
||||
return obj
|
||||
}
|
||||
}
|
||||
|
||||
func entryCandidates(scriptText string, preferred []string) []string {
|
||||
values := make([]string, 0, len(preferred)+4)
|
||||
appendUnique := func(value string) {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
for _, existing := range values {
|
||||
if existing == value {
|
||||
return
|
||||
}
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
for _, value := range preferred {
|
||||
appendUnique(value)
|
||||
}
|
||||
for _, match := range functionDeclarationPattern.FindAllStringSubmatch(scriptText, -1) {
|
||||
appendUnique(match[1])
|
||||
}
|
||||
for _, match := range assignedFunctionPattern.FindAllStringSubmatch(scriptText, -1) {
|
||||
appendUnique(match[1])
|
||||
}
|
||||
appendUnique("main")
|
||||
appendUnique("handler")
|
||||
return values
|
||||
}
|
||||
|
||||
func values(vm *goja.Runtime, input []any) []goja.Value {
|
||||
out := make([]goja.Value, 0, len(input))
|
||||
for _, item := range input {
|
||||
out = append(out, toValue(vm, item))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toValue(vm *goja.Runtime, item any) goja.Value {
|
||||
if values, ok := item.(map[string]any); ok {
|
||||
copied := map[string]any{}
|
||||
for key, value := range values {
|
||||
if key == "__easyaiScriptContext" {
|
||||
continue
|
||||
}
|
||||
copied[key] = value
|
||||
}
|
||||
obj := vm.ToValue(copied).ToObject(vm)
|
||||
if marker, _ := values["__easyaiScriptContext"].(bool); marker {
|
||||
_ = obj.Set("got", vm.Get("got"))
|
||||
_ = obj.Set("fetch", vm.Get("fetch"))
|
||||
_ = obj.Set("FormData", vm.Get("FormData"))
|
||||
}
|
||||
return obj
|
||||
}
|
||||
return vm.ToValue(item)
|
||||
}
|
||||
|
||||
func exportValue(value goja.Value) any {
|
||||
if value == nil || goja.IsUndefined(value) || goja.IsNull(value) {
|
||||
return nil
|
||||
}
|
||||
return value.Export()
|
||||
}
|
||||
|
||||
func exportMap(value goja.Value) map[string]any {
|
||||
if value == nil || goja.IsUndefined(value) || goja.IsNull(value) {
|
||||
return map[string]any{}
|
||||
}
|
||||
if typed, ok := normalizeExport(value.Export()).(map[string]any); ok {
|
||||
return typed
|
||||
}
|
||||
return map[string]any{}
|
||||
}
|
||||
|
||||
func exportSlice(value goja.Value) []any {
|
||||
if value == nil || goja.IsUndefined(value) || goja.IsNull(value) {
|
||||
return []any{}
|
||||
}
|
||||
if typed, ok := normalizeExport(value.Export()).([]any); ok {
|
||||
return typed
|
||||
}
|
||||
return []any{}
|
||||
}
|
||||
|
||||
func normalizeExport(value any) any {
|
||||
raw, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return value
|
||||
}
|
||||
var out any
|
||||
if json.Unmarshal(raw, &out) != nil {
|
||||
return value
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func firstHTTPClient(values ...*http.Client) *http.Client {
|
||||
for _, value := range values {
|
||||
if value != nil {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return http.DefaultClient
|
||||
}
|
||||
|
||||
func stringify(value goja.Value) string {
|
||||
if value == nil || goja.IsUndefined(value) || goja.IsNull(value) {
|
||||
return "script rejected"
|
||||
}
|
||||
if exported, ok := normalizeExport(value.Export()).(map[string]any); ok {
|
||||
for _, key := range []string{"message", "error", "error_message"} {
|
||||
if message := strings.TrimSpace(fmt.Sprint(exported[key])); message != "" && message != "<nil>" {
|
||||
return message
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(value.String())
|
||||
}
|
||||
116
apps/api/internal/script/executor_test.go
Normal file
116
apps/api/internal/script/executor_test.go
Normal file
@ -0,0 +1,116 @@
|
||||
package script
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestExecutorRunsFunctionExpression(t *testing.T) {
|
||||
out, err := (Executor{}).Execute(context.Background(), Options{
|
||||
Script: `(params) => ({ prompt: params.prompt.toUpperCase(), n: 2 })`,
|
||||
Args: []any{map[string]any{"prompt": "hello"}},
|
||||
ScriptName: "custom_preprocess_script",
|
||||
Timeout: time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("execute failed: %v", err)
|
||||
}
|
||||
result := out.(map[string]any)
|
||||
if result["prompt"] != "HELLO" || result["n"].(float64) != 2 {
|
||||
t.Fatalf("unexpected result: %#v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutorSelectsPreferredEntry(t *testing.T) {
|
||||
out, err := (Executor{}).Execute(context.Background(), Options{
|
||||
Script: `
|
||||
function helper() { return { wrong: true }; }
|
||||
async function submitTask(payload, context) {
|
||||
return { status: "submitted", task_id: payload.id, baseURL: context.baseURL };
|
||||
}
|
||||
`,
|
||||
Args: []any{map[string]any{"id": "task-1"}, map[string]any{"baseURL": "https://example.test"}},
|
||||
ContextData: map[string]any{"baseURL": "https://example.test"},
|
||||
PreferredEntryNames: []string{"submitTask", "submit"},
|
||||
ScriptName: "custom_submit_script:video_generate",
|
||||
Timeout: time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("execute failed: %v", err)
|
||||
}
|
||||
result := out.(map[string]any)
|
||||
if result["task_id"] != "task-1" || result["baseURL"] != "https://example.test" {
|
||||
t.Fatalf("unexpected result: %#v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutorGotJSONHelper(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("unexpected method: %s", r.Method)
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer test" {
|
||||
t.Fatalf("missing authorization header")
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"status":"success","task_id":"remote-1"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
out, err := (Executor{}).Execute(context.Background(), Options{
|
||||
Script: `
|
||||
async function submitTask(payload, context) {
|
||||
return await got.post(context.baseURL + "/tasks", {
|
||||
headers: { Authorization: "Bearer " + context.authValues.apiKey },
|
||||
json: payload
|
||||
}).json();
|
||||
}
|
||||
`,
|
||||
Args: []any{map[string]any{"prompt": "hello"}, map[string]any{"baseURL": server.URL, "authValues": map[string]any{"apiKey": "test"}}},
|
||||
ContextData: map[string]any{"baseURL": server.URL, "authValues": map[string]any{"apiKey": "test"}},
|
||||
PreferredEntryNames: []string{"submitTask"},
|
||||
ScriptName: "custom_submit_script:image_generate",
|
||||
Timeout: 2 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("execute failed: %v", err)
|
||||
}
|
||||
result := out.(map[string]any)
|
||||
if result["task_id"] != "remote-1" {
|
||||
t.Fatalf("unexpected result: %#v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutorTimeout(t *testing.T) {
|
||||
_, err := (Executor{}).Execute(context.Background(), Options{
|
||||
Script: `async function main() { await new Promise((resolve) => setTimeout(resolve, 200)); return true; }`,
|
||||
ScriptName: "custom_poll_script",
|
||||
Timeout: 25 * time.Millisecond,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected timeout")
|
||||
}
|
||||
scriptErr, ok := err.(*Error)
|
||||
if !ok || scriptErr.Code != "script_timeout" {
|
||||
t.Fatalf("expected script_timeout, got %#v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecutorRejectedPromiseMessage(t *testing.T) {
|
||||
_, err := (Executor{}).Execute(context.Background(), Options{
|
||||
Script: `async function main() { throw new Error("boom"); }`,
|
||||
ScriptName: "custom_submit_script",
|
||||
Timeout: time.Second,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected rejection")
|
||||
}
|
||||
scriptErr, ok := err.(*Error)
|
||||
if !ok || scriptErr.Code != "script_error" || !strings.Contains(scriptErr.Message, "boom") {
|
||||
t.Fatalf("expected script_error with boom, got %#v", err)
|
||||
}
|
||||
}
|
||||
19
apps/api/migrations/0039_exclude_easyai_media_catalog.sql
Normal file
19
apps/api/migrations/0039_exclude_easyai_media_catalog.sql
Normal file
@ -0,0 +1,19 @@
|
||||
-- EasyAI/server-main is intentionally not migrated as an AI Gateway runtime
|
||||
-- provider. Keep its historical catalog rows for traceability, but hide them
|
||||
-- from fresh admin selection and mark the exclusion reason explicitly.
|
||||
UPDATE base_model_catalog
|
||||
SET status = 'deprecated',
|
||||
metadata = COALESCE(metadata, '{}'::jsonb) || jsonb_build_object(
|
||||
'selectable', false,
|
||||
'migrationExcludedReason', 'excluded from AI Gateway media runtime migration to avoid gateway-to-server-main loopback',
|
||||
'migrationExcludedAt', '0039_exclude_easyai_media_catalog'
|
||||
)
|
||||
WHERE provider_key = 'easyai'
|
||||
AND model_type ?| ARRAY[
|
||||
'image_generate',
|
||||
'image_edit',
|
||||
'video_generate',
|
||||
'image_to_video',
|
||||
'omni_video',
|
||||
'video_edit'
|
||||
];
|
||||
34
docs/media-client-migration.md
Normal file
34
docs/media-client-migration.md
Normal file
@ -0,0 +1,34 @@
|
||||
# Media Client Migration Status
|
||||
|
||||
This document tracks the server-main media runtime migration into the AI Gateway.
|
||||
|
||||
## Runtime Scope
|
||||
|
||||
- Included model types: `image_generate`, `image_edit`, `video_generate`, `image_to_video`, `omni_video`, `video_edit`.
|
||||
- Excluded provider: `easyai`, because routing AI Gateway media tasks back into server-main would create a loopback dependency.
|
||||
- Universal custom scripts are supported through `integration_platforms.config`:
|
||||
- `customPreprocessScript`
|
||||
- `customGetParamsScript`
|
||||
- `customSubmitScript`
|
||||
- `customPollScript`
|
||||
- `getTaskURL`
|
||||
- `skipParamNormalization`
|
||||
|
||||
## Migrated Clients
|
||||
|
||||
- `universal`: custom preprocess/get params/submit/poll scripts, default submit/poll, remote task resume.
|
||||
- `jimeng`: async submit/poll skeleton with Jimeng task id and status mapping.
|
||||
- `blackforest`: submit with `x-key`, `polling_url` polling, image result normalization.
|
||||
- `tencent-hunyuan-image`: Tencent-style `Response.JobId`/`Response.Status` image task mapping.
|
||||
- `tencent-hunyuan-video`: Tencent-style `Response.JobId`/`Response.Status` video task mapping.
|
||||
- `minimax`: video submit/query task mapping.
|
||||
- `midjourney`: diffusion submit, job polling, original and Aliyun-style status/result mapping.
|
||||
- `vidu`: Token auth, typed submit path, creations polling.
|
||||
- `aliyun-bailian`: video synthesis submit and task polling.
|
||||
- `newapi`: `/videos/generations` submit and task polling.
|
||||
|
||||
## Notes
|
||||
|
||||
- Provider-specific advanced parameter shaping remains isolated inside each client/spec.
|
||||
- Tencent and Jimeng production deployments should configure exact submit/poll paths and credentials in platform config when they differ from the default server-main-compatible paths.
|
||||
- Each migrated client has an `httptest` submit/poll coverage case in `internal/clients`.
|
||||
Loading…
Reference in New Issue
Block a user