340 lines
13 KiB
Go
340 lines
13 KiB
Go
package clients
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
)
|
|
|
|
func TestProviderTaskClientsSubmitAndPoll(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
client Client
|
|
provider string
|
|
specType string
|
|
submitMatch func(*http.Request) bool
|
|
submitResponse string
|
|
pollMatch func(*http.Request) bool
|
|
pollResponse string
|
|
authHeader string
|
|
resultURL string
|
|
}{
|
|
{
|
|
name: "jimeng",
|
|
client: JimengClient{},
|
|
provider: "jimeng",
|
|
specType: "jimeng",
|
|
submitMatch: func(r *http.Request) bool {
|
|
return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "CVSubmitTask"
|
|
},
|
|
submitResponse: `{"code":10000,"data":{"task_id":"remote-1"}}`,
|
|
pollMatch: func(r *http.Request) bool {
|
|
return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "CVSync2AsyncGetResult"
|
|
},
|
|
pollResponse: `{"code":10000,"data":{"status":"done","video_url":"https://cdn.example/jimeng.mp4"}}`,
|
|
authHeader: "Bearer test-key",
|
|
resultURL: "https://cdn.example/jimeng.mp4",
|
|
},
|
|
{
|
|
name: "blackforest",
|
|
client: BlackforestClient{},
|
|
provider: "blackforest",
|
|
specType: "blackforest",
|
|
submitMatch: func(r *http.Request) bool {
|
|
return r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/provider-model")
|
|
},
|
|
submitResponse: `{"polling_url":"__SERVER__/poll/remote-1"}`,
|
|
pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/poll/remote-1" },
|
|
pollResponse: `{"status":"Ready","urls":["https://cdn.example/flux.png"]}`,
|
|
authHeader: "test-key",
|
|
resultURL: "https://cdn.example/flux.png",
|
|
},
|
|
{
|
|
name: "hunyuan-image",
|
|
client: HunyuanImageClient{},
|
|
provider: "tencent-hunyuan-image",
|
|
specType: "tencent-hunyuan-image",
|
|
submitMatch: func(r *http.Request) bool {
|
|
return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "SubmitHunyuanImageJob"
|
|
},
|
|
submitResponse: `{"Response":{"JobId":"remote-1"}}`,
|
|
pollMatch: func(r *http.Request) bool {
|
|
return r.Method == http.MethodGet && r.URL.Query().Get("Action") == "QueryHunyuanImageJob"
|
|
},
|
|
pollResponse: `{"Response":{"Status":"DONE","ResultImages":["https://cdn.example/hunyuan.png"]}}`,
|
|
authHeader: "Bearer test-key",
|
|
resultURL: "https://cdn.example/hunyuan.png",
|
|
},
|
|
{
|
|
name: "hunyuan-video",
|
|
client: HunyuanVideoClient{},
|
|
provider: "tencent-hunyuan-video",
|
|
specType: "tencent-hunyuan-video",
|
|
submitMatch: func(r *http.Request) bool {
|
|
return r.Method == http.MethodPost && r.URL.Query().Get("Action") == "SubmitTextToVideoJob"
|
|
},
|
|
submitResponse: `{"Response":{"JobId":"remote-1"}}`,
|
|
pollMatch: func(r *http.Request) bool {
|
|
return r.Method == http.MethodGet && r.URL.Query().Get("Action") == "QueryVideoJob"
|
|
},
|
|
pollResponse: `{"Response":{"Status":"DONE","ResultVideoUrl":"https://cdn.example/hunyuan.mp4"}}`,
|
|
authHeader: "Bearer test-key",
|
|
resultURL: "https://cdn.example/hunyuan.mp4",
|
|
},
|
|
{
|
|
name: "minimax",
|
|
client: MinimaxClient{},
|
|
provider: "minimax",
|
|
specType: "minimax",
|
|
submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/video_generation" },
|
|
submitResponse: `{"task_id":123}`,
|
|
pollMatch: func(r *http.Request) bool {
|
|
return r.Method == http.MethodGet && r.URL.Path == "/query/video_generation" && r.URL.Query().Get("task_id") == "123"
|
|
},
|
|
pollResponse: `{"status":"Success","file_id":"file-1","video_url":"https://cdn.example/minimax.mp4"}`,
|
|
authHeader: "Bearer test-key",
|
|
resultURL: "https://cdn.example/minimax.mp4",
|
|
},
|
|
{
|
|
name: "midjourney",
|
|
client: MidjourneyClient{},
|
|
provider: "midjourney",
|
|
specType: "midjourney",
|
|
submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/diffusion" },
|
|
submitResponse: `{"job_id":"remote-1"}`,
|
|
pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/job/remote-1" },
|
|
pollResponse: `{"status":"completed","output":{"image_urls":["https://cdn.example/mj.png"]}}`,
|
|
authHeader: "Bearer test-key",
|
|
resultURL: "https://cdn.example/mj.png",
|
|
},
|
|
{
|
|
name: "vidu",
|
|
client: ViduClient{},
|
|
provider: "vidu",
|
|
specType: "vidu",
|
|
submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/text2video" },
|
|
submitResponse: `{"task_id":"remote-1"}`,
|
|
pollMatch: func(r *http.Request) bool {
|
|
return r.Method == http.MethodGet && r.URL.Path == "/tasks/remote-1/creations"
|
|
},
|
|
pollResponse: `{"state":"success","video_url":"https://cdn.example/vidu.mp4"}`,
|
|
authHeader: "Token test-key",
|
|
resultURL: "https://cdn.example/vidu.mp4",
|
|
},
|
|
{
|
|
name: "aliyun-bailian",
|
|
client: AliyunBailianClient{},
|
|
provider: "aliyun-bailian",
|
|
specType: "aliyun-bailian",
|
|
submitMatch: func(r *http.Request) bool {
|
|
return r.Method == http.MethodPost && r.URL.Path == "/services/aigc/video-generation/video-synthesis"
|
|
},
|
|
submitResponse: `{"output":{"task_id":"remote-1"}}`,
|
|
pollMatch: func(r *http.Request) bool { return r.Method == http.MethodGet && r.URL.Path == "/tasks/remote-1" },
|
|
pollResponse: `{"output":{"task_status":"SUCCEEDED","video_url":"https://cdn.example/aliyun.mp4"}}`,
|
|
authHeader: "Bearer test-key",
|
|
resultURL: "https://cdn.example/aliyun.mp4",
|
|
},
|
|
{
|
|
name: "newapi",
|
|
client: NewAPIClient{},
|
|
provider: "newapi",
|
|
specType: "newapi",
|
|
submitMatch: func(r *http.Request) bool { return r.Method == http.MethodPost && r.URL.Path == "/videos/generations" },
|
|
submitResponse: `{"task_id":"remote-1"}`,
|
|
pollMatch: func(r *http.Request) bool {
|
|
return r.Method == http.MethodGet && r.URL.Path == "/videos/generations/remote-1"
|
|
},
|
|
pollResponse: `{"status":"SUCCESS","data":{"output":"https://cdn.example/newapi.mp4"}}`,
|
|
authHeader: "Bearer test-key",
|
|
resultURL: "https://cdn.example/newapi.mp4",
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if tc.authHeader == "test-key" {
|
|
if r.Header.Get("x-key") != tc.authHeader {
|
|
t.Fatalf("unexpected x-key header: %q", r.Header.Get("x-key"))
|
|
}
|
|
} else if r.Header.Get("Authorization") != tc.authHeader {
|
|
t.Fatalf("unexpected auth header: %q", r.Header.Get("Authorization"))
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("x-request-id", "req-"+tc.name)
|
|
switch {
|
|
case tc.submitMatch(r):
|
|
_, _ = w.Write([]byte(strings.ReplaceAll(tc.submitResponse, "__SERVER__", "http://"+r.Host)))
|
|
case tc.pollMatch(r):
|
|
_, _ = w.Write([]byte(tc.pollResponse))
|
|
default:
|
|
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.String())
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
var submittedRemoteTaskID string
|
|
request := Request{
|
|
Kind: "videos.generations",
|
|
ModelType: "video_generate",
|
|
Model: "alias-model",
|
|
Body: map[string]any{"model": "alias-model", "prompt": "hello"},
|
|
Candidate: store.RuntimeModelCandidate{
|
|
Provider: tc.provider,
|
|
SpecType: tc.specType,
|
|
BaseURL: server.URL,
|
|
Credentials: map[string]any{"apiKey": "test-key"},
|
|
PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
|
|
ModelName: "alias-model",
|
|
ProviderModelName: "provider-model",
|
|
ModelType: "video_generate",
|
|
ClientID: tc.name + ":test",
|
|
},
|
|
OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error {
|
|
submittedRemoteTaskID = remoteTaskID
|
|
if payload["payload"] == nil || payload["submit"] == nil {
|
|
t.Fatalf("missing remote payload: %#v", payload)
|
|
}
|
|
return nil
|
|
},
|
|
}
|
|
|
|
response, err := tc.client.Run(context.Background(), request)
|
|
if err != nil {
|
|
t.Fatalf("run failed: %v", err)
|
|
}
|
|
data, ok := response.Result["data"].([]any)
|
|
if !ok || len(data) == 0 {
|
|
t.Fatalf("missing data: %#v", response.Result)
|
|
}
|
|
first, _ := data[0].(map[string]any)
|
|
if first["url"] != tc.resultURL {
|
|
t.Fatalf("unexpected result url: %#v", response.Result)
|
|
}
|
|
if response.RequestID != "req-"+tc.name {
|
|
t.Fatalf("unexpected request id: %q", response.RequestID)
|
|
}
|
|
if submittedRemoteTaskID == "" {
|
|
t.Fatalf("expected remote task submission")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProviderTaskClientFailureAndRetryableErrors(t *testing.T) {
|
|
t.Run("poll failure", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("x-request-id", "req-failed")
|
|
switch r.URL.Path {
|
|
case "/videos/generations":
|
|
_, _ = w.Write([]byte(`{"task_id":"remote-1"}`))
|
|
case "/videos/generations/remote-1":
|
|
_, _ = w.Write([]byte(`{"status":"failed","code":"UPSTREAM_FAILED","message":"provider rejected"}`))
|
|
default:
|
|
t.Fatalf("unexpected request: %s", r.URL.String())
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
_, err := (NewAPIClient{}).Run(context.Background(), Request{
|
|
Kind: "videos.generations",
|
|
ModelType: "video_generate",
|
|
Model: "alias-model",
|
|
Body: map[string]any{"model": "alias-model", "prompt": "hello"},
|
|
Candidate: store.RuntimeModelCandidate{
|
|
Provider: "newapi",
|
|
SpecType: "newapi",
|
|
BaseURL: server.URL,
|
|
Credentials: map[string]any{"apiKey": "test-key"},
|
|
PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
|
|
ModelName: "alias-model",
|
|
ProviderModelName: "provider-model",
|
|
ModelType: "video_generate",
|
|
},
|
|
})
|
|
var clientErr *ClientError
|
|
if !errors.As(err, &clientErr) || clientErr.Code != "UPSTREAM_FAILED" || clientErr.Retryable {
|
|
t.Fatalf("expected non-retryable upstream failure, got %#v", err)
|
|
}
|
|
if clientErr.RequestID != "req-failed" {
|
|
t.Fatalf("unexpected request id: %q", clientErr.RequestID)
|
|
}
|
|
})
|
|
|
|
t.Run("submit rate limit", func(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("x-request-id", "req-rate-limit")
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
_, _ = w.Write([]byte(`{"error":{"message":"slow down"}}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
_, err := (NewAPIClient{}).Run(context.Background(), Request{
|
|
Kind: "videos.generations",
|
|
ModelType: "video_generate",
|
|
Model: "alias-model",
|
|
Body: map[string]any{"model": "alias-model", "prompt": "hello"},
|
|
Candidate: store.RuntimeModelCandidate{
|
|
Provider: "newapi",
|
|
SpecType: "newapi",
|
|
BaseURL: server.URL,
|
|
Credentials: map[string]any{"apiKey": "test-key"},
|
|
PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
|
|
ModelName: "alias-model",
|
|
ProviderModelName: "provider-model",
|
|
ModelType: "video_generate",
|
|
},
|
|
})
|
|
var clientErr *ClientError
|
|
if !errors.As(err, &clientErr) || !clientErr.Retryable || clientErr.RequestID != "req-rate-limit" {
|
|
t.Fatalf("expected retryable rate limit with request id, got %#v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestProviderTaskClientResumeSkipsSubmit(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == http.MethodPost {
|
|
t.Fatal("submit should not run for resumed remote task")
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"status":"success","data":{"output":"https://cdn.example/resume.mp4"}}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
response, err := (NewAPIClient{}).Run(context.Background(), Request{
|
|
Kind: "videos.generations",
|
|
ModelType: "video_generate",
|
|
Model: "alias-model",
|
|
Body: map[string]any{"model": "alias-model", "prompt": "hello"},
|
|
RemoteTaskID: "remote-1",
|
|
RemoteTaskPayload: map[string]any{"payload": map[string]any{"prompt": "old"}},
|
|
Candidate: store.RuntimeModelCandidate{
|
|
Provider: "newapi",
|
|
SpecType: "newapi",
|
|
BaseURL: server.URL,
|
|
Credentials: map[string]any{"apiKey": "test-key"},
|
|
PlatformConfig: map[string]any{"pollIntervalMs": 1, "pollTimeoutMs": 1000},
|
|
ModelName: "alias-model",
|
|
ProviderModelName: "provider-model",
|
|
ModelType: "video_generate",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("run failed: %v", err)
|
|
}
|
|
data := response.Result["data"].([]any)
|
|
first := data[0].(map[string]any)
|
|
if first["url"] != "https://cdn.example/resume.mp4" || response.Result["upstream_task_id"] != "remote-1" {
|
|
t.Fatalf("unexpected response: %#v", response.Result)
|
|
}
|
|
}
|