package clients import ( "context" "encoding/json" "errors" "net/http" "net/http/httptest" "strings" "testing" "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" ) func TestProviderTaskClientsSubmitAndPoll(t *testing.T) { cases := []struct { name string client Client provider string specType string submitMatch func(*http.Request) bool submitResponse string pollMatch func(*http.Request) bool pollResponse string authHeader string resultURL string }{ { name: "jimeng", client: JimengClient{}, provider: "jimeng", specType: "jimeng", submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "CVSubmitTask" }, submitResponse: `{"code":10000,"data":{"task_id":"remote-1"}}`, pollMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "CVSync2AsyncGetResult" }, pollResponse: `{"code":10000,"data":{"status":"done","video_url":"https://cdn.example/jimeng.mp4"}}`, authHeader: "Bearer test-key", resultURL: "https://cdn.example/jimeng.mp4", }, { name: "blackforest", client: BlackforestClient{}, provider: "blackforest", specType: "blackforest", submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/provider-model") }, submitResponse: `{"polling_url":"__SERVER__/poll/remote-1"}`, pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/poll/remote-1" }, pollResponse: `{"status":"Ready","urls":["https://cdn.example/flux.png"]}`, authHeader: "test-key", resultURL: "https://cdn.example/flux.png", }, { name: "hunyuan-image", client: HunyuanImageClient{}, provider: "tencent-hunyuan-image", specType: "tencent-hunyuan-image", submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "SubmitHunyuanImageJob" }, submitResponse: `{"Response":{"JobId":"remote-1"}}`, pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Query().Get("Action") == "QueryHunyuanImageJob" }, pollResponse: `{"Response":{"Status":"DONE","ResultImages":["https://cdn.example/hunyuan.png"]}}`, authHeader: "Bearer test-key", resultURL: "https://cdn.example/hunyuan.png", }, { name: "hunyuan-video", client: HunyuanVideoClient{}, provider: "tencent-hunyuan-video", specType: "tencent-hunyuan-video", submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "SubmitTextToVideoJob" }, submitResponse: `{"Response":{"JobId":"remote-1"}}`, pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Query().Get("Action") == "QueryVideoJob" }, pollResponse: `{"Response":{"Status":"DONE","ResultVideoUrl":"https://cdn.example/hunyuan.mp4"}}`, authHeader: "Bearer test-key", resultURL: "https://cdn.example/hunyuan.mp4", }, { name: "minimax", client: MinimaxClient{}, provider: "minimax", specType: "minimax", submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/video_generation" }, submitResponse: `{"task_id":123}`, pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/query/video_generation" && r.URL.Query().Get("task_id") == "123" }, pollResponse: `{"status":"Success","file_id":"file-1","video_url":"https://cdn.example/minimax.mp4"}`, authHeader: "Bearer test-key", resultURL: "https://cdn.example/minimax.mp4", }, { name: "midjourney", client: MidjourneyClient{}, provider: "midjourney", specType: "midjourney", submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/diffusion" }, submitResponse: `{"job_id":"remote-1"}`, pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/job/remote-1" }, pollResponse: `{"status":"completed","output":{"image_urls":["https://cdn.example/mj.png"]}}`, authHeader: "Bearer test-key", resultURL: "https://cdn.example/mj.png", }, { name: "vidu", client: ViduClient{}, provider: "vidu", specType: "vidu", submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/text2video" }, submitResponse: `{"task_id":"remote-1"}`, pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/tasks/remote-1/creations" }, pollResponse: `{"state":"success","video_url":"https://cdn.example/vidu.mp4"}`, authHeader: "Token test-key", resultURL: "https://cdn.example/vidu.mp4", }, { name: "aliyun-bailian", client: AliyunBailianClient{}, provider: "aliyun-bailian", specType: "aliyun-bailian", submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/services/aigc/video-generation/video-synthesis" }, submitResponse: `{"output":{"task_id":"remote-1"}}`, pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/tasks/remote-1" }, pollResponse: `{"output":{"task_status":"SUCCEEDED","video_url":"https://cdn.example/aliyun.mp4"}}`, authHeader: "Bearer test-key", resultURL: "https://cdn.example/aliyun.mp4", }, { name: "newapi", client: NewAPIClient{}, provider: "newapi", specType: "newapi", submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/videos/generations" }, submitResponse: `{"task_id":"remote-1"}`, pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/videos/generations/remote-1" }, pollResponse: `{"status":"SUCCESS","data":{"output":"https://cdn.example/newapi.mp4"}}`, authHeader: "Bearer test-key", resultURL: "https://cdn.example/newapi.mp4", }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if tc.authHeader == "test-key" { if r.Header.Get("x-key") != tc.authHeader { t.Fatalf("unexpected x-key header: %q", r.Header.Get("x-key")) } } else if r.Header.Get("Authorization") != tc.authHeader { t.Fatalf("unexpected auth header: %q", r.Header.Get("Authorization")) } w.Header().Set("Content-Type", "application/json") w.Header().Set("x-request-id", "req-"+tc.name) switch { case tc.submitMatch(r): _, _ = w.Write([]byte(strings.ReplaceAll(tc.submitResponse, "__SERVER__", "http://"+r.Host))) case tc.pollMatch(r): _, _ = w.Write([]byte(tc.pollResponse)) default: t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String()) } })) defer server.Close() var submittedRemoteTaskID string request := Request{ Kind: "videos.generations", ModelType: "video_generate", Model: "alias-model", Body: map[string]any{"model": "alias-model", "prompt": "hello"}, Candidate: store.RuntimeModelCandidate{ Provider: tc.provider, SpecType: tc.specType, BaseURL: server.URL, Credentials: map[string]any{"apiKey": "test-key"}, PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000}, ModelName: "alias-model", ProviderModelName: "provider-model", ModelType: "video_generate", ClientID: tc.name + ":test", }, OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error { submittedRemoteTaskID = remoteTaskID if payload["payload"] == nil || payload["submit"] == nil { t.Fatalf("missing remote payload: %#v", payload) } return nil }, } response, err := tc.client.Run(context.Background(), request) if err != nil { t.Fatalf("run failed: %v", err) } data, ok := response.Result["data"].([]any) if !ok || len(data) == 0 { t.Fatalf("missing data: %#v", response.Result) } first, _ := data[0].(map[string]any) if first["url"] != tc.resultURL { t.Fatalf("unexpected result url: %#v", response.Result) } if response.RequestID != "req-"+tc.name { t.Fatalf("unexpected request id: %q", response.RequestID) } if submittedRemoteTaskID == "" { t.Fatalf("expected remote task submission") } }) } } func TestMinimaxClientNormalizesHailuo23PayloadAndRetrievesFile(t *testing.T) { var submitted map[string]any server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer test-key" { t.Fatalf("unexpected auth header: %q", r.Header.Get("Authorization")) } w.Header().Set("Content-Type", "application/json") w.Header().Set("x-request-id", "req-minimax") switch { case r.Method == http.MethodPost && r.URL.Path == "/video_generation": if err := json.NewDecoder(r.Body).Decode(&submitted); err != nil { t.Fatalf("decode minimax submit request: %v", err) } _, _ = w.Write([]byte(`{"task_id":"mm-task","base_resp":{"status_code":0,"status_msg":"success"}}`)) case r.Method == http.MethodGet && r.URL.Path == "/query/video_generation" && r.URL.Query().Get("task_id") == "mm-task": _, _ = w.Write([]byte(`{"task_id":"mm-task","status":"Success","file_id":"file-1","base_resp":{"status_code":0,"status_msg":"success"}}`)) case r.Method == http.MethodGet && r.URL.Path == "/files/retrieve" && r.URL.Query().Get("file_id") == "file-1": _, _ = w.Write([]byte(`{"file":{"download_url":"https://cdn.example/minimax-file.mp4"},"base_resp":{"status_code":0,"status_msg":"success"}}`)) default: t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String()) } })) defer server.Close() response, err := (MinimaxClient{HTTPClient: server.Client()}).Run(context.Background(), Request{ Kind: "videos.generations", ModelType: "image_to_video", Model: "海螺2.3", Body: map[string]any{ "resolution": "720p", "duration": 6, "content": []any{ map[string]any{"type": "text", "text": "camera moves in"}, map[string]any{"type": "image_url", "role": "first_frame", "image_url": map[string]any{"url": "https://example.com/first.png"}}, }, }, Candidate: store.RuntimeModelCandidate{ Provider: "minimax", SpecType: "minimax", BaseURL: server.URL, Credentials: map[string]any{"apiKey": "test-key"}, PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000}, ModelName: "海螺2.3", ProviderModelName: "MiniMax-Hailuo-2.3", ModelType: "image_to_video", }, }) if err != nil { t.Fatalf("run minimax client: %v", err) } if submitted["model"] != "MiniMax-Hailuo-2.3" || submitted["prompt"] != "camera moves in" || submitted["first_frame_image"] != "https://example.com/first.png" { t.Fatalf("unexpected minimax submit payload: %+v", submitted) } if submitted["resolution"] != "768P" { t.Fatalf("hailuo 2.3 720p should be submitted as 768P, got %+v", submitted) } if _, ok := submitted["content"]; ok { t.Fatalf("minimax native request should not include generic content: %+v", submitted) } data, _ := response.Result["data"].([]any) if len(data) != 1 { t.Fatalf("unexpected minimax response data: %+v", response.Result) } first, _ := data[0].(map[string]any) if first["url"] != "https://cdn.example/minimax-file.mp4" { t.Fatalf("unexpected minimax video url: %+v", response.Result) } } func TestSunoClientSubmitsAndPollsAudioGeneration(t *testing.T) { var submitted map[string]any var submittedRemoteTaskID string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if got := r.Header.Get("Authorization"); got != "Bearer test-key" { t.Fatalf("unexpected auth header: %q", got) } w.Header().Set("Content-Type", "application/json") w.Header().Set("x-request-id", "req-suno") switch { case r.Method == http.MethodPost && r.URL.Path == "/generator/suno": if err := json.NewDecoder(r.Body).Decode(&submitted); err != nil { t.Fatalf("decode suno submit request: %v", err) } _, _ = w.Write([]byte(`{"code":200,"data":"suno-task"}`)) case r.Method == http.MethodGet && r.URL.Path == "/v2/sunoinfo" && r.URL.Query().Get("id") == "suno-task": _, _ = w.Write([]byte(`{"code":200,"data":{"status":"succeeded","result":[{"audio_url":"https://cdn.example/song.mp3"}]}}`)) default: t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String()) } })) defer server.Close() response, err := (SunoClient{HTTPClient: server.Client()}).Run(context.Background(), Request{ Kind: "song.generations", ModelType: "audio_generate", Model: "Suno V5", Body: map[string]any{ "prompt": "city lights", "tags": "pop", "negativeTags": "noise", }, Candidate: store.RuntimeModelCandidate{ Provider: "suno", SpecType: "suno", BaseURL: server.URL, Credentials: map[string]any{"apiKey": "test-key"}, PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000}, ProviderModelName: "chirp-v5-0", ModelType: "audio_generate", }, OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error { submittedRemoteTaskID = remoteTaskID if payload["payload"] == nil || payload["submit"] == nil { t.Fatalf("missing remote payload: %#v", payload) } return nil }, }) if err != nil { t.Fatalf("run suno client: %v", err) } if submittedRemoteTaskID != "suno-task" { t.Fatalf("unexpected remote task id: %q", submittedRemoteTaskID) } if submitted["task"] != "create" || submitted["model"] != "v50" || submitted["prompt"] != "city lights" { t.Fatalf("unexpected suno submit payload: %+v", submitted) } if submitted["customMode"] != false || submitted["makeInstrumental"] != false { t.Fatalf("suno defaults should match main-server style payload: %+v", submitted) } data, _ := response.Result["data"].([]any) if len(data) != 1 { t.Fatalf("unexpected suno response: %+v", response.Result) } first, _ := data[0].(map[string]any) if first["type"] != "audio" || first["url"] != "https://cdn.example/song.mp3" { t.Fatalf("unexpected suno normalized audio item: %+v", first) } if response.RequestID != "req-suno" { t.Fatalf("unexpected request id: %q", response.RequestID) } } func TestProviderTaskClientFailureAndRetryableErrors(t *testing.T) { t.Run("submit business failure", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("x-request-id", "req-minimax-invalid") if r.Method != http.MethodPost || r.URL.Path != "/video_generation" { t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String()) } _, _ = w.Write([]byte(`{"base_resp":{"status_code":1008,"status_msg":"invalid resolution"}}`)) })) defer server.Close() _, err := (MinimaxClient{HTTPClient: server.Client()}).Run(context.Background(), Request{ Kind: "videos.generations", ModelType: "video_generate", Model: "海螺2.3", Body: map[string]any{"model": "海螺2.3", "prompt": "hello"}, Candidate: store.RuntimeModelCandidate{ Provider: "minimax", SpecType: "minimax", BaseURL: server.URL, Credentials: map[string]any{"apiKey": "test-key"}, PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000}, ModelName: "海螺2.3", ProviderModelName: "MiniMax-Hailuo-2.3", ModelType: "video_generate", }, }) var clientErr *ClientError if !errors.As(err, &clientErr) || clientErr.Code != "1008" || clientErr.Message != "invalid resolution" || clientErr.RequestID != "req-minimax-invalid" { t.Fatalf("expected minimax business failure, got %#v", err) } }) t.Run("poll failure", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("x-request-id", "req-failed") switch r.URL.Path { case "/videos/generations": _, _ = w.Write([]byte(`{"task_id":"remote-1"}`)) case "/videos/generations/remote-1": _, _ = w.Write([]byte(`{"status":"failed","code":"UPSTREAM_FAILED","message":"provider rejected"}`)) default: t.Fatalf("unexpected request: %s", r.URL.String()) } })) defer server.Close() _, err := (NewAPIClient{}).Run(context.Background(), Request{ Kind: "videos.generations", ModelType: "video_generate", Model: "alias-model", Body: map[string]any{"model": "alias-model", "prompt": "hello"}, Candidate: store.RuntimeModelCandidate{ Provider: "newapi", SpecType: "newapi", BaseURL: server.URL, Credentials: map[string]any{"apiKey": "test-key"}, PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000}, ModelName: "alias-model", ProviderModelName: "provider-model", ModelType: "video_generate", }, }) var clientErr *ClientError if !errors.As(err, &clientErr) || clientErr.Code != "UPSTREAM_FAILED" || clientErr.Retryable { t.Fatalf("expected non-retryable upstream failure, got %#v", err) } if clientErr.RequestID != "req-failed" { t.Fatalf("unexpected request id: %q", clientErr.RequestID) } }) t.Run("submit rate limit", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("x-request-id", "req-rate-limit") w.WriteHeader(http.StatusTooManyRequests) _, _ = w.Write([]byte(`{"error":{"message":"slow down"}}`)) })) defer server.Close() _, err := (NewAPIClient{}).Run(context.Background(), Request{ Kind: "videos.generations", ModelType: "video_generate", Model: "alias-model", Body: map[string]any{"model": "alias-model", "prompt": "hello"}, Candidate: store.RuntimeModelCandidate{ Provider: "newapi", SpecType: "newapi", BaseURL: server.URL, Credentials: map[string]any{"apiKey": "test-key"}, PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000}, ModelName: "alias-model", ProviderModelName: "provider-model", ModelType: "video_generate", }, }) var clientErr *ClientError if !errors.As(err, &clientErr) || !clientErr.Retryable || clientErr.RequestID != "req-rate-limit" { t.Fatalf("expected retryable rate limit with request id, got %#v", err) } }) } func TestProviderTaskClientResumeSkipsSubmit(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { t.Fatal("submit should not run for resumed remote task") } w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"status":"success","data":{"output":"https://cdn.example/resume.mp4"}}`)) })) defer server.Close() response, err := (NewAPIClient{}).Run(context.Background(), Request{ Kind: "videos.generations", ModelType: "video_generate", Model: "alias-model", Body: map[string]any{"model": "alias-model", "prompt": "hello"}, RemoteTaskID: "remote-1", RemoteTaskPayload: map[string]any{"payload": map[string]any{"prompt": "old"}}, Candidate: store.RuntimeModelCandidate{ Provider: "newapi", SpecType: "newapi", BaseURL: server.URL, Credentials: map[string]any{"apiKey": "test-key"}, PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000}, ModelName: "alias-model", ProviderModelName: "provider-model", ModelType: "video_generate", }, }) if err != nil { t.Fatalf("run failed: %v", err) } data := response.Result["data"].([]any) first := data[0].(map[string]any) if first["url"] != "https://cdn.example/resume.mp4" || response.Result["upstream_task_id"] != "remote-1" { t.Fatalf("unexpected response: %#v", response.Result) } }