easyai-ai-gateway/apps/api/internal/clients/provider_task_test.go

340 lines
13 KiB
Go

package clients
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// TestProviderTaskClientsSubmitAndPoll runs every task-based provider client
// through a submit-then-poll cycle against a fake upstream server.
//
// For each case the fake server checks the credential header, answers the
// request matched by submitMatch with submitResponse (any "__SERVER__" marker
// is replaced by the fake server's base URL), and answers the request matched
// by pollMatch with pollResponse. An authHeader of exactly "test-key" means
// the key is expected raw in the "x-key" header; any other value is expected
// in "Authorization".
//
// The test then asserts that the client surfaces resultURL as
// Result["data"][0]["url"], propagates the upstream "x-request-id" header as
// RequestID, and reports a non-empty remote task ID via OnRemoteTaskSubmitted.
func TestProviderTaskClientsSubmitAndPoll(t *testing.T) {
	cases := []struct {
		name           string
		client         Client
		provider       string
		specType       string
		submitMatch    func(*http.Request) bool
		submitResponse string
		pollMatch      func(*http.Request) bool
		pollResponse   string
		authHeader     string
		resultURL      string
	}{
		{
			name:     "jimeng",
			client:   JimengClient{},
			provider: "jimeng",
			specType: "jimeng",
			submitMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "CVSubmitTask"
			},
			submitResponse: `{"code":10000,"data":{"task_id":"remote-1"}}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "CVSync2AsyncGetResult"
			},
			pollResponse: `{"code":10000,"data":{"status":"done","video_url":"https://cdn.example/jimeng.mp4"}}`,
			authHeader:   "Bearer test-key",
			resultURL:    "https://cdn.example/jimeng.mp4",
		},
		{
			name:     "blackforest",
			client:   BlackforestClient{},
			provider: "blackforest",
			specType: "blackforest",
			submitMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/provider-model")
			},
			submitResponse: `{"polling_url":"__SERVER__/poll/remote-1"}`,
			pollMatch:      func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/poll/remote-1" },
			pollResponse:   `{"status":"Ready","urls":["https://cdn.example/flux.png"]}`,
			authHeader:     "test-key",
			resultURL:      "https://cdn.example/flux.png",
		},
		{
			name:     "hunyuan-image",
			client:   HunyuanImageClient{},
			provider: "tencent-hunyuan-image",
			specType: "tencent-hunyuan-image",
			submitMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "SubmitHunyuanImageJob"
			},
			submitResponse: `{"Response":{"JobId":"remote-1"}}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodGet && r.URL.Query().Get("Action") == "QueryHunyuanImageJob"
			},
			pollResponse: `{"Response":{"Status":"DONE","ResultImages":["https://cdn.example/hunyuan.png"]}}`,
			authHeader:   "Bearer test-key",
			resultURL:    "https://cdn.example/hunyuan.png",
		},
		{
			name:     "hunyuan-video",
			client:   HunyuanVideoClient{},
			provider: "tencent-hunyuan-video",
			specType: "tencent-hunyuan-video",
			submitMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "SubmitTextToVideoJob"
			},
			submitResponse: `{"Response":{"JobId":"remote-1"}}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodGet && r.URL.Query().Get("Action") == "QueryVideoJob"
			},
			pollResponse: `{"Response":{"Status":"DONE","ResultVideoUrl":"https://cdn.example/hunyuan.mp4"}}`,
			authHeader:   "Bearer test-key",
			resultURL:    "https://cdn.example/hunyuan.mp4",
		},
		{
			name:           "minimax",
			client:         MinimaxClient{},
			provider:       "minimax",
			specType:       "minimax",
			submitMatch:    func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/video_generation" },
			submitResponse: `{"task_id":123}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodGet && r.URL.Path == "/query/video_generation" && r.URL.Query().Get("task_id") == "123"
			},
			pollResponse: `{"status":"Success","file_id":"file-1","video_url":"https://cdn.example/minimax.mp4"}`,
			authHeader:   "Bearer test-key",
			resultURL:    "https://cdn.example/minimax.mp4",
		},
		{
			name:           "midjourney",
			client:         MidjourneyClient{},
			provider:       "midjourney",
			specType:       "midjourney",
			submitMatch:    func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/diffusion" },
			submitResponse: `{"job_id":"remote-1"}`,
			pollMatch:      func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/job/remote-1" },
			pollResponse:   `{"status":"completed","output":{"image_urls":["https://cdn.example/mj.png"]}}`,
			authHeader:     "Bearer test-key",
			resultURL:      "https://cdn.example/mj.png",
		},
		{
			name:           "vidu",
			client:         ViduClient{},
			provider:       "vidu",
			specType:       "vidu",
			submitMatch:    func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/text2video" },
			submitResponse: `{"task_id":"remote-1"}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodGet && r.URL.Path == "/tasks/remote-1/creations"
			},
			pollResponse: `{"state":"success","video_url":"https://cdn.example/vidu.mp4"}`,
			authHeader:   "Token test-key",
			resultURL:    "https://cdn.example/vidu.mp4",
		},
		{
			name:     "aliyun-bailian",
			client:   AliyunBailianClient{},
			provider: "aliyun-bailian",
			specType: "aliyun-bailian",
			submitMatch: func(r *http.Request) bool {
				return r.Method == http.MethodPost && r.URL.Path == "/services/aigc/video-generation/video-synthesis"
			},
			submitResponse: `{"output":{"task_id":"remote-1"}}`,
			pollMatch:      func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/tasks/remote-1" },
			pollResponse:   `{"output":{"task_status":"SUCCEEDED","video_url":"https://cdn.example/aliyun.mp4"}}`,
			authHeader:     "Bearer test-key",
			resultURL:      "https://cdn.example/aliyun.mp4",
		},
		{
			name:           "newapi",
			client:         NewAPIClient{},
			provider:       "newapi",
			specType:       "newapi",
			submitMatch:    func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/videos/generations" },
			submitResponse: `{"task_id":"remote-1"}`,
			pollMatch: func(r *http.Request) bool {
				return r.Method == http.MethodGet && r.URL.Path == "/videos/generations/remote-1"
			},
			pollResponse: `{"status":"SUCCESS","data":{"output":"https://cdn.example/newapi.mp4"}}`,
			authHeader:   "Bearer test-key",
			resultURL:    "https://cdn.example/newapi.mp4",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// The handler runs on a server goroutine, where t.Fatal/FailNow
				// must not be called. Record the failure with t.Errorf and fail
				// the HTTP call so Run returns an error on the test goroutine.
				if tc.authHeader == "test-key" {
					if got := r.Header.Get("x-key"); got != tc.authHeader {
						t.Errorf("unexpected x-key header: %q", got)
						http.Error(w, "unexpected x-key header", http.StatusUnauthorized)
						return
					}
				} else if got := r.Header.Get("Authorization"); got != tc.authHeader {
					t.Errorf("unexpected auth header: %q", got)
					http.Error(w, "unexpected auth header", http.StatusUnauthorized)
					return
				}
				w.Header().Set("Content-Type", "application/json")
				w.Header().Set("x-request-id", "req-"+tc.name)
				switch {
				case tc.submitMatch(r):
					// Point any absolute URL in the submit response (e.g. the
					// Blackforest polling_url) back at this fake server.
					_, _ = w.Write([]byte(strings.ReplaceAll(tc.submitResponse, "__SERVER__", "http://"+r.Host)))
				case tc.pollMatch(r):
					_, _ = w.Write([]byte(tc.pollResponse))
				default:
					t.Errorf("unexpected request: %s %s", r.Method, r.URL.String())
					http.Error(w, "unexpected request", http.StatusNotFound)
				}
			}))
			defer server.Close()

			var submittedRemoteTaskID string
			request := Request{
				Kind:      "videos.generations",
				ModelType: "video_generate",
				Model:     "alias-model",
				Body:      map[string]any{"model": "alias-model", "prompt": "hello"},
				Candidate: store.RuntimeModelCandidate{
					Provider:          tc.provider,
					SpecType:          tc.specType,
					BaseURL:           server.URL,
					Credentials:       map[string]any{"apiKey": "test-key"},
					PlatformConfig:    map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
					ModelName:         "alias-model",
					ProviderModelName: "provider-model",
					ModelType:         "video_generate",
					ClientID:          tc.name + ":test",
				},
				OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error {
					submittedRemoteTaskID = remoteTaskID
					if payload["payload"] == nil || payload["submit"] == nil {
						// Run may invoke this callback off the test goroutine,
						// so avoid t.Fatalf; report and abort Run via the error.
						t.Errorf("missing remote payload: %#v", payload)
						return errors.New("missing remote payload")
					}
					return nil
				},
			}

			response, err := tc.client.Run(context.Background(), request)
			if err != nil {
				t.Fatalf("run failed: %v", err)
			}
			data, ok := response.Result["data"].([]any)
			if !ok || len(data) == 0 {
				t.Fatalf("missing data: %#v", response.Result)
			}
			first, ok := data[0].(map[string]any)
			if !ok || first["url"] != tc.resultURL {
				t.Fatalf("unexpected result url: %#v", response.Result)
			}
			if response.RequestID != "req-"+tc.name {
				t.Fatalf("unexpected request id: %q", response.RequestID)
			}
			if submittedRemoteTaskID == "" {
				t.Fatalf("expected remote task submission")
			}
		})
	}
}
func TestProviderTaskClientFailureAndRetryableErrors(t *testing.T) {
t.Run("poll failure", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("x-request-id", "req-failed")
switch r.URL.Path {
case "/videos/generations":
_, _ = w.Write([]byte(`{"task_id":"remote-1"}`))
case "/videos/generations/remote-1":
_, _ = w.Write([]byte(`{"status":"failed","code":"UPSTREAM_FAILED","message":"provider rejected"}`))
default:
t.Fatalf("unexpected request: %s", r.URL.String())
}
}))
defer server.Close()
_, err := (NewAPIClient{}).Run(context.Background(), Request{
Kind: "videos.generations",
ModelType: "video_generate",
Model: "alias-model",
Body: map[string]any{"model": "alias-model", "prompt": "hello"},
Candidate: store.RuntimeModelCandidate{
Provider: "newapi",
SpecType: "newapi",
BaseURL: server.URL,
Credentials: map[string]any{"apiKey": "test-key"},
PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
ModelName: "alias-model",
ProviderModelName: "provider-model",
ModelType: "video_generate",
},
})
var clientErr *ClientError
if !errors.As(err, &clientErr) || clientErr.Code != "UPSTREAM_FAILED" || clientErr.Retryable {
t.Fatalf("expected non-retryable upstream failure, got %#v", err)
}
if clientErr.RequestID != "req-failed" {
t.Fatalf("unexpected request id: %q", clientErr.RequestID)
}
})
t.Run("submit rate limit", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("x-request-id", "req-rate-limit")
w.WriteHeader(http.StatusTooManyRequests)
_, _ = w.Write([]byte(`{"error":{"message":"slow down"}}`))
}))
defer server.Close()
_, err := (NewAPIClient{}).Run(context.Background(), Request{
Kind: "videos.generations",
ModelType: "video_generate",
Model: "alias-model",
Body: map[string]any{"model": "alias-model", "prompt": "hello"},
Candidate: store.RuntimeModelCandidate{
Provider: "newapi",
SpecType: "newapi",
BaseURL: server.URL,
Credentials: map[string]any{"apiKey": "test-key"},
PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
ModelName: "alias-model",
ProviderModelName: "provider-model",
ModelType: "video_generate",
},
})
var clientErr *ClientError
if !errors.As(err, &clientErr) || !clientErr.Retryable || clientErr.RequestID != "req-rate-limit" {
t.Fatalf("expected retryable rate limit with request id, got %#v", err)
}
})
}
// TestProviderTaskClientResumeSkipsSubmit verifies resuming a task. When the
// Request already carries RemoteTaskID (and a stored RemoteTaskPayload),
// NewAPIClient must only poll: no POST may reach the upstream. The result URL
// and upstream_task_id must come from the resumed remote task.
func TestProviderTaskClientResumeSkipsSubmit(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodPost {
			// Handlers run on a server goroutine where t.Fatal is not allowed;
			// record the failure and reject the submit instead.
			t.Error("submit should not run for resumed remote task")
			http.Error(w, "unexpected submit", http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"status":"success","data":{"output":"https://cdn.example/resume.mp4"}}`))
	}))
	defer server.Close()

	response, err := (NewAPIClient{}).Run(context.Background(), Request{
		Kind:              "videos.generations",
		ModelType:         "video_generate",
		Model:             "alias-model",
		Body:              map[string]any{"model": "alias-model", "prompt": "hello"},
		RemoteTaskID:      "remote-1",
		RemoteTaskPayload: map[string]any{"payload": map[string]any{"prompt": "old"}},
		Candidate: store.RuntimeModelCandidate{
			Provider:          "newapi",
			SpecType:          "newapi",
			BaseURL:           server.URL,
			Credentials:       map[string]any{"apiKey": "test-key"},
			PlatformConfig:    map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
			ModelName:         "alias-model",
			ProviderModelName: "provider-model",
			ModelType:         "video_generate",
		},
	})
	if err != nil {
		t.Fatalf("run failed: %v", err)
	}
	// Checked assertions: a malformed Result should fail the test with a
	// message, not panic.
	data, ok := response.Result["data"].([]any)
	if !ok || len(data) == 0 {
		t.Fatalf("missing data: %#v", response.Result)
	}
	first, ok := data[0].(map[string]any)
	if !ok || first["url"] != "https://cdn.example/resume.mp4" || response.Result["upstream_task_id"] != "remote-1" {
		t.Fatalf("unexpected response: %#v", response.Result)
	}
}