package clients

import (
	"context"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)

// newProviderTaskTestRequest builds the "videos.generations" Request shared by
// the provider task client tests. The candidate targets baseURL (an httptest
// server in these tests), uses the API key "test-key", and sets
// pollIntervalMs=1 / pollTimeoutMs=1000 so polling finishes quickly.
// Callers customise the returned value (ClientID, callbacks, resume fields).
func newProviderTaskTestRequest(provider, specType, baseURL string) Request {
	return Request{
		Kind:      "videos.generations",
		ModelType: "video_generate",
		Model:     "alias-model",
		Body:      map[string]any{"model": "alias-model", "prompt": "hello"},
		Candidate: store.RuntimeModelCandidate{
			Provider:          provider,
			SpecType:          specType,
			BaseURL:           baseURL,
			Credentials:       map[string]any{"apiKey": "test-key"},
			PlatformConfig:    map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
			ModelName:         "alias-model",
			ProviderModelName: "provider-model",
			ModelType:         "video_generate",
		},
	}
}

// TestProviderTaskClientsSubmitAndPoll drives each provider task client through
// one submit + poll round trip against a fake upstream. It checks the auth
// header the client sends, the result URL exposed in Result["data"][0]["url"],
// the propagated x-request-id, and that OnRemoteTaskSubmitted was invoked.
func TestProviderTaskClientsSubmitAndPoll(t *testing.T) {
	cases := []struct {
		name           string
		client         Client
		provider       string
		specType       string
		submitMatch    func(*http.Request) bool
		submitResponse string // "__SERVER__" is replaced with "http://<host>" of the fake server
		pollMatch      func(*http.Request) bool
		pollResponse   string
		authHeaderName string // header carrying the credential; empty means "Authorization"
		authHeader     string // expected value of that header
		resultURL      string
	}{
		{
			name:     "jimeng",
			client:   JimengClient{},
			provider: "jimeng",
			specType: "jimeng",
			submitMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "CVSubmitTask"
			},
			submitResponse: `{"code":10000,"data":{"task_id":"remote-1"}}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "CVSync2AsyncGetResult"
			},
			pollResponse: `{"code":10000,"data":{"status":"done","video_url":"https://cdn.example/jimeng.mp4"}}`,
			authHeader:   "Bearer test-key",
			resultURL:    "https://cdn.example/jimeng.mp4",
		},
		{
			name:     "blackforest",
			client:   BlackforestClient{},
			provider: "blackforest",
			specType: "blackforest",
			submitMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/provider-model")
			},
			submitResponse: `{"polling_url":"__SERVER__/poll/remote-1"}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodGet && r.URL.Path == "/poll/remote-1"
			},
			pollResponse:   `{"status":"Ready","urls":["https://cdn.example/flux.png"]}`,
			authHeaderName: "x-key",
			authHeader:     "test-key",
			resultURL:      "https://cdn.example/flux.png",
		},
		{
			name:     "hunyuan-image",
			client:   HunyuanImageClient{},
			provider: "tencent-hunyuan-image",
			specType: "tencent-hunyuan-image",
			submitMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "SubmitHunyuanImageJob"
			},
			submitResponse: `{"Response":{"JobId":"remote-1"}}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodGet && r.URL.Query().Get("Action") == "QueryHunyuanImageJob"
			},
			pollResponse: `{"Response":{"Status":"DONE","ResultImages":["https://cdn.example/hunyuan.png"]}}`,
			authHeader:   "Bearer test-key",
			resultURL:    "https://cdn.example/hunyuan.png",
		},
		{
			name:     "hunyuan-video",
			client:   HunyuanVideoClient{},
			provider: "tencent-hunyuan-video",
			specType: "tencent-hunyuan-video",
			submitMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "SubmitTextToVideoJob"
			},
			submitResponse: `{"Response":{"JobId":"remote-1"}}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodGet && r.URL.Query().Get("Action") == "QueryVideoJob"
			},
			pollResponse: `{"Response":{"Status":"DONE","ResultVideoUrl":"https://cdn.example/hunyuan.mp4"}}`,
			authHeader:   "Bearer test-key",
			resultURL:    "https://cdn.example/hunyuan.mp4",
		},
		{
			name:     "minimax",
			client:   MinimaxClient{},
			provider: "minimax",
			specType: "minimax",
			submitMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && r.URL.Path == "/video_generation"
			},
			submitResponse: `{"task_id":123}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodGet && r.URL.Path == "/query/video_generation" && r.URL.Query().Get("task_id") == "123"
			},
			pollResponse: `{"status":"Success","file_id":"file-1","video_url":"https://cdn.example/minimax.mp4"}`,
			authHeader:   "Bearer test-key",
			resultURL:    "https://cdn.example/minimax.mp4",
		},
		{
			name:     "midjourney",
			client:   MidjourneyClient{},
			provider: "midjourney",
			specType: "midjourney",
			submitMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && r.URL.Path == "/diffusion"
			},
			submitResponse: `{"job_id":"remote-1"}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodGet && r.URL.Path == "/job/remote-1"
			},
			pollResponse: `{"status":"completed","output":{"image_urls":["https://cdn.example/mj.png"]}}`,
			authHeader:   "Bearer test-key",
			resultURL:    "https://cdn.example/mj.png",
		},
		{
			name:     "vidu",
			client:   ViduClient{},
			provider: "vidu",
			specType: "vidu",
			submitMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && r.URL.Path == "/text2video"
			},
			submitResponse: `{"task_id":"remote-1"}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodGet && r.URL.Path == "/tasks/remote-1/creations"
			},
			pollResponse: `{"state":"success","video_url":"https://cdn.example/vidu.mp4"}`,
			authHeader:   "Token test-key",
			resultURL:    "https://cdn.example/vidu.mp4",
		},
		{
			name:     "aliyun-bailian",
			client:   AliyunBailianClient{},
			provider: "aliyun-bailian",
			specType: "aliyun-bailian",
			submitMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && r.URL.Path == "/services/aigc/video-generation/video-synthesis"
			},
			submitResponse: `{"output":{"task_id":"remote-1"}}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodGet && r.URL.Path == "/tasks/remote-1"
			},
			pollResponse: `{"output":{"task_status":"SUCCEEDED","video_url":"https://cdn.example/aliyun.mp4"}}`,
			authHeader:   "Bearer test-key",
			resultURL:    "https://cdn.example/aliyun.mp4",
		},
		{
			name:     "newapi",
			client:   NewAPIClient{},
			provider: "newapi",
			specType: "newapi",
			submitMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && r.URL.Path == "/videos/generations"
			},
			submitResponse: `{"task_id":"remote-1"}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodGet && r.URL.Path == "/videos/generations/remote-1"
			},
			pollResponse: `{"status":"SUCCESS","data":{"output":"https://cdn.example/newapi.mp4"}}`,
			authHeader:   "Bearer test-key",
			resultURL:    "https://cdn.example/newapi.mp4",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The handler runs on a server goroutine, where t.Fatal/FailNow must
			// not be called. Failures are reported with t.Errorf and answered with
			// an error status so that the client call fails on the test goroutine.
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				headerName := tc.authHeaderName
				if headerName == "" {
					headerName = "Authorization"
				}
				if got := r.Header.Get(headerName); got != tc.authHeader {
					t.Errorf("unexpected %s header: %q", headerName, got)
					http.Error(w, "unexpected auth header", http.StatusBadRequest)
					return
				}

				w.Header().Set("Content-Type", "application/json")
				w.Header().Set("x-request-id", "req-"+tc.name)
				switch {
				case tc.submitMatch(r):
					_, _ = w.Write([]byte(strings.ReplaceAll(tc.submitResponse, "__SERVER__", "http://"+r.Host)))
				case tc.pollMatch(r):
					_, _ = w.Write([]byte(tc.pollResponse))
				default:
					t.Errorf("unexpected request: %s %s", r.Method, r.URL.String())
					http.Error(w, "unexpected request", http.StatusBadRequest)
				}
			}))
			defer server.Close()

			var submittedRemoteTaskID string
			request := newProviderTaskTestRequest(tc.provider, tc.specType, server.URL)
			request.Candidate.ClientID = tc.name + ":test"
			request.OnRemoteTaskSubmitted = func(remoteTaskID string, payload map[string]any) error {
				submittedRemoteTaskID = remoteTaskID
				// Errorf (not Fatalf): this callback is not guaranteed to run on
				// the test goroutine.
				if payload["payload"] == nil || payload["submit"] == nil {
					t.Errorf("missing remote payload: %#v", payload)
				}
				return nil
			}

			response, err := tc.client.Run(context.Background(), request)
			if err != nil {
				t.Fatalf("run failed: %v", err)
			}
			data, ok := response.Result["data"].([]any)
			if !ok || len(data) == 0 {
				t.Fatalf("missing data: %#v", response.Result)
			}
			first, ok := data[0].(map[string]any)
			if !ok || first["url"] != tc.resultURL {
				t.Fatalf("unexpected result url: %#v", response.Result)
			}
			if response.RequestID != "req-"+tc.name {
				t.Fatalf("unexpected request id: %q", response.RequestID)
			}
			if submittedRemoteTaskID == "" {
				t.Fatalf("expected remote task submission")
			}
		})
	}
}

// TestProviderTaskClientFailureAndRetryableErrors checks how NewAPIClient maps
// upstream failures to *ClientError. A failed poll status must yield a
// non-retryable error that carries the upstream code. An HTTP 429 on submit
// must yield a retryable error. Both must carry the upstream x-request-id.
func TestProviderTaskClientFailureAndRetryableErrors(t *testing.T) {
	t.Run("poll failure", func(t *testing.T) {
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			w.Header().Set("x-request-id", "req-failed")
			switch r.URL.Path {
			case "/videos/generations":
				_, _ = w.Write([]byte(`{"task_id":"remote-1"}`))
			case "/videos/generations/remote-1":
				_, _ = w.Write([]byte(`{"status":"failed","code":"UPSTREAM_FAILED","message":"provider rejected"}`))
			default:
				// Server goroutine: report without FailNow.
				t.Errorf("unexpected request: %s", r.URL.String())
				http.Error(w, "unexpected request", http.StatusBadRequest)
			}
		}))
		defer server.Close()

		_, err := (NewAPIClient{}).Run(context.Background(), newProviderTaskTestRequest("newapi", "newapi", server.URL))
		var clientErr *ClientError
		if !errors.As(err, &clientErr) || clientErr.Code != "UPSTREAM_FAILED" || clientErr.Retryable {
			t.Fatalf("expected non-retryable upstream failure, got %#v", err)
		}
		if clientErr.RequestID != "req-failed" {
			t.Fatalf("unexpected request id: %q", clientErr.RequestID)
		}
	})

	t.Run("submit rate limit", func(t *testing.T) {
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.Header().Set("Content-Type", "application/json")
			w.Header().Set("x-request-id", "req-rate-limit")
			w.WriteHeader(http.StatusTooManyRequests)
			_, _ = w.Write([]byte(`{"error":{"message":"slow down"}}`))
		}))
		defer server.Close()

		_, err := (NewAPIClient{}).Run(context.Background(), newProviderTaskTestRequest("newapi", "newapi", server.URL))
		var clientErr *ClientError
		if !errors.As(err, &clientErr) || !clientErr.Retryable || clientErr.RequestID != "req-rate-limit" {
			t.Fatalf("expected retryable rate limit with request id, got %#v", err)
		}
	})
}

// TestProviderTaskClientResumeSkipsSubmit checks that when Request.RemoteTaskID
// is set, NewAPIClient sends no POST (submit) request and only polls. It also
// checks that the existing task ID is surfaced as Result["upstream_task_id"].
func TestProviderTaskClientResumeSkipsSubmit(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodPost {
			// Server goroutine: report without FailNow.
			t.Error("submit should not run for resumed remote task")
			http.Error(w, "unexpected submit", http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"status":"success","data":{"output":"https://cdn.example/resume.mp4"}}`))
	}))
	defer server.Close()

	request := newProviderTaskTestRequest("newapi", "newapi", server.URL)
	request.RemoteTaskID = "remote-1"
	request.RemoteTaskPayload = map[string]any{"payload": map[string]any{"prompt": "old"}}

	response, err := (NewAPIClient{}).Run(context.Background(), request)
	if err != nil {
		t.Fatalf("run failed: %v", err)
	}
	data, ok := response.Result["data"].([]any)
	if !ok || len(data) == 0 {
		t.Fatalf("missing data: %#v", response.Result)
	}
	first, ok := data[0].(map[string]any)
	if !ok || first["url"] != "https://cdn.example/resume.mp4" || response.Result["upstream_task_id"] != "remote-1" {
		t.Fatalf("unexpected response: %#v", response.Result)
	}
}