easyai-ai-gateway/apps/api/internal/clients/universal_test.go

133 lines
4.4 KiB
Go

package clients
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// TestUniversalClientRunsCustomScripts drives the client entirely through the
// per-model-type custom scripts stored in PlatformConfig:
//   - the params script appends "-payload" to the prompt,
//   - the submit script builds the remote task id from that prompt,
//   - the poll script returns the task id it is given as upstream_task_id.
//
// The test expects the derived id "task-hello-payload" both in the
// OnRemoteTaskSubmitted callback and in the final result.
func TestUniversalClientRunsCustomScripts(t *testing.T) {
	const (
		paramsScript = `async function getGenerateParams(params, context) {
return { prompt: params.prompt + "-payload", model: context.candidate.providerModelName };
}`
		submitScript = `async function submitTask(payload) {
return { status: "submitted", task_id: "task-" + payload.prompt };
}`
		pollScript = `async function pollTask(taskId) {
return { status: "success", upstream_task_id: taskId, data: [{ url: "https://cdn.example/video.mp4" }] };
}`
		wantTaskID = "task-hello-payload"
	)

	platformConfig := map[string]any{
		"customGetParamsScript": map[string]any{"video_generate": paramsScript},
		"customSubmitScript":    map[string]any{"video_generate": submitScript},
		"customPollScript":      map[string]any{"video_generate": pollScript},
		"pollIntervalMs":        1,
		"pollTimeoutMs":         1000,
	}

	// Captures the id the client reports after submission; the payload is not inspected.
	var gotRemoteTaskID string
	request := Request{
		Kind:      "videos.generations",
		ModelType: "video_generate",
		Model:     "custom-video",
		Body:      map[string]any{"model": "custom-video", "prompt": "hello"},
		Candidate: testUniversalCandidate(platformConfig),
		OnRemoteTaskSubmitted: func(remoteTaskID string, _ map[string]any) error {
			gotRemoteTaskID = remoteTaskID
			return nil
		},
	}

	response, err := UniversalClient{}.Run(context.Background(), request)
	switch {
	case err != nil:
		t.Fatalf("run failed: %v", err)
	case gotRemoteTaskID != wantTaskID:
		t.Fatalf("unexpected remote task id: %q", gotRemoteTaskID)
	case response.Result["upstream_task_id"] != wantTaskID:
		t.Fatalf("unexpected result: %#v", response.Result)
	}
}
// TestUniversalClientDefaultSubmitAndPoll exercises the built-in HTTP submit
// and poll path, with no custom scripts configured. A fake upstream server
// behaves as follows:
//   - it requires the "Bearer test-key" Authorization header,
//   - it answers a POST to /video/generations with remote task "remote-1",
//   - it reports that task as successful at the getTaskURL template path.
//
// The test expects "remote-1" to come back as upstream_task_id in the result.
func TestUniversalClientDefaultSubmitAndPoll(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on the server's goroutine, not the test goroutine.
		// t.Fatalf (FailNow) must only be called from the test goroutine, so
		// failures are reported with t.Errorf (safe from any goroutine) and the
		// client receives an error status instead.
		if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
			t.Errorf("missing authorization header: got %q", got)
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		switch r.URL.Path {
		case "/video/generations":
			_, _ = w.Write([]byte(`{"status":"submitted","task_id":"remote-1"}`))
		case "/tasks/remote-1":
			_, _ = w.Write([]byte(`{"status":"success","data":[{"url":"https://cdn.example/default.mp4"}]}`))
		default:
			t.Errorf("unexpected path: %s", r.URL.Path)
			http.NotFound(w, r)
		}
	}))
	// Close waits for in-flight handlers, so no t.Errorf runs after the test returns.
	defer server.Close()

	request := Request{
		Kind:      "videos.generations",
		ModelType: "video_generate",
		Model:     "default-video",
		Body:      map[string]any{"model": "default-video", "prompt": "hello"},
		Candidate: testUniversalCandidate(map[string]any{
			"getTaskURL":     server.URL + "/tasks/{{task_id}}",
			"pollIntervalMs": 1,
			"pollTimeoutMs":  1000,
		}),
	}
	// Point the default submit endpoint at the fake upstream.
	request.Candidate.BaseURL = server.URL

	response, err := (UniversalClient{}).Run(context.Background(), request)
	if err != nil {
		t.Fatalf("run failed: %v", err)
	}
	if response.Result["upstream_task_id"] != "remote-1" {
		t.Fatalf("unexpected result: %#v", response.Result)
	}
}
// TestUniversalClientResumeSkipsSubmit checks that a request which already
// carries a RemoteTaskID goes straight to polling. The submit script throws
// if it is ever invoked, and the poll script returns the task id it receives.
// The run is therefore expected to succeed with "existing-1" as
// upstream_task_id.
func TestUniversalClientResumeSkipsSubmit(t *testing.T) {
	const existingTaskID = "existing-1"

	platformConfig := map[string]any{
		"customSubmitScript": `async function submitTask() { throw new Error("submit should not run"); }`,
		"customPollScript":   `async function pollTask(taskId) { return { status: "success", upstream_task_id: taskId, data: [{ url: "https://cdn.example/resume.mp4" }] }; }`,
		"pollIntervalMs":     1,
		"pollTimeoutMs":      1000,
	}
	previousPayload := map[string]any{"payload": map[string]any{"prompt": "old"}}

	request := Request{
		Kind:              "videos.generations",
		ModelType:         "video_generate",
		Model:             "resume-video",
		Body:              map[string]any{"model": "resume-video", "prompt": "hello"},
		RemoteTaskID:      existingTaskID,
		RemoteTaskPayload: previousPayload,
		Candidate:         testUniversalCandidate(platformConfig),
	}

	response, err := UniversalClient{}.Run(context.Background(), request)
	if err != nil {
		t.Fatalf("run failed: %v", err)
	}
	if got := response.Result["upstream_task_id"]; got != existingTaskID {
		t.Fatalf("unexpected result: %#v", response.Result)
	}
}
// testUniversalCandidate builds a store.RuntimeModelCandidate for the
// "universal" provider and spec type. The fixed fields are:
//   - BaseURL "https://provider.example",
//   - API key "test-key",
//   - ModelName "alias-model" and ProviderModelName "provider-model",
//   - ModelType "video_generate" and ClientID "universal:test".
//
// The supplied config is used directly as PlatformConfig; it is not copied.
func testUniversalCandidate(config map[string]any) store.RuntimeModelCandidate {
	var candidate store.RuntimeModelCandidate
	candidate.Provider = "universal"
	candidate.SpecType = "universal"
	candidate.ClientID = "universal:test"
	candidate.BaseURL = "https://provider.example"
	candidate.Credentials = map[string]any{"apiKey": "test-key"}
	candidate.ModelName = "alias-model"
	candidate.ProviderModelName = "provider-model"
	candidate.ModelType = "video_generate"
	candidate.PlatformConfig = config
	return candidate
}