133 lines
4.4 KiB
Go
133 lines
4.4 KiB
Go
package clients
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
)
|
|
|
|
func TestUniversalClientRunsCustomScripts(t *testing.T) {
|
|
request := Request{
|
|
Kind: "videos.generations",
|
|
ModelType: "video_generate",
|
|
Model: "custom-video",
|
|
Body: map[string]any{"model": "custom-video", "prompt": "hello"},
|
|
Candidate: testUniversalCandidate(map[string]any{
|
|
"customGetParamsScript": map[string]any{
|
|
"video_generate": `async function getGenerateParams(params, context) {
|
|
return { prompt: params.prompt + "-payload", model: context.candidate.providerModelName };
|
|
}`,
|
|
},
|
|
"customSubmitScript": map[string]any{
|
|
"video_generate": `async function submitTask(payload) {
|
|
return { status: "submitted", task_id: "task-" + payload.prompt };
|
|
}`,
|
|
},
|
|
"customPollScript": map[string]any{
|
|
"video_generate": `async function pollTask(taskId) {
|
|
return { status: "success", upstream_task_id: taskId, data: [{ url: "https://cdn.example/video.mp4" }] };
|
|
}`,
|
|
},
|
|
"pollIntervalMs": 1,
|
|
"pollTimeoutMs": 1000,
|
|
}),
|
|
}
|
|
var submitted string
|
|
request.OnRemoteTaskSubmitted = func(remoteTaskID string, payload map[string]any) error {
|
|
submitted = remoteTaskID
|
|
return nil
|
|
}
|
|
|
|
response, err := (UniversalClient{}).Run(context.Background(), request)
|
|
if err != nil {
|
|
t.Fatalf("run failed: %v", err)
|
|
}
|
|
if submitted != "task-hello-payload" {
|
|
t.Fatalf("unexpected remote task id: %q", submitted)
|
|
}
|
|
if response.Result["upstream_task_id"] != "task-hello-payload" {
|
|
t.Fatalf("unexpected result: %#v", response.Result)
|
|
}
|
|
}
|
|
|
|
func TestUniversalClientDefaultSubmitAndPoll(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Header.Get("Authorization") != "Bearer test-key" {
|
|
t.Fatalf("missing authorization header")
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
switch r.URL.Path {
|
|
case "/video/generations":
|
|
_, _ = w.Write([]byte(`{"status":"submitted","task_id":"remote-1"}`))
|
|
case "/tasks/remote-1":
|
|
_, _ = w.Write([]byte(`{"status":"success","data":[{"url":"https://cdn.example/default.mp4"}]}`))
|
|
default:
|
|
t.Fatalf("unexpected path: %s", r.URL.Path)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
request := Request{
|
|
Kind: "videos.generations",
|
|
ModelType: "video_generate",
|
|
Model: "default-video",
|
|
Body: map[string]any{"model": "default-video", "prompt": "hello"},
|
|
Candidate: testUniversalCandidate(map[string]any{
|
|
"getTaskURL": server.URL + "/tasks/{{task_id}}",
|
|
"pollIntervalMs": 1,
|
|
"pollTimeoutMs": 1000,
|
|
}),
|
|
}
|
|
request.Candidate.BaseURL = server.URL
|
|
|
|
response, err := (UniversalClient{}).Run(context.Background(), request)
|
|
if err != nil {
|
|
t.Fatalf("run failed: %v", err)
|
|
}
|
|
if response.Result["upstream_task_id"] != "remote-1" {
|
|
t.Fatalf("unexpected result: %#v", response.Result)
|
|
}
|
|
}
|
|
|
|
func TestUniversalClientResumeSkipsSubmit(t *testing.T) {
|
|
request := Request{
|
|
Kind: "videos.generations",
|
|
ModelType: "video_generate",
|
|
Model: "resume-video",
|
|
Body: map[string]any{"model": "resume-video", "prompt": "hello"},
|
|
RemoteTaskID: "existing-1",
|
|
RemoteTaskPayload: map[string]any{"payload": map[string]any{"prompt": "old"}},
|
|
Candidate: testUniversalCandidate(map[string]any{
|
|
"customSubmitScript": `async function submitTask() { throw new Error("submit should not run"); }`,
|
|
"customPollScript": `async function pollTask(taskId) { return { status: "success", upstream_task_id: taskId, data: [{ url: "https://cdn.example/resume.mp4" }] }; }`,
|
|
"pollIntervalMs": 1,
|
|
"pollTimeoutMs": 1000,
|
|
}),
|
|
}
|
|
|
|
response, err := (UniversalClient{}).Run(context.Background(), request)
|
|
if err != nil {
|
|
t.Fatalf("run failed: %v", err)
|
|
}
|
|
if response.Result["upstream_task_id"] != "existing-1" {
|
|
t.Fatalf("unexpected result: %#v", response.Result)
|
|
}
|
|
}
|
|
|
|
func testUniversalCandidate(config map[string]any) store.RuntimeModelCandidate {
|
|
return store.RuntimeModelCandidate{
|
|
Provider: "universal",
|
|
SpecType: "universal",
|
|
BaseURL: "https://provider.example",
|
|
Credentials: map[string]any{"apiKey": "test-key"},
|
|
PlatformConfig: config,
|
|
ModelName: "alias-model",
|
|
ProviderModelName: "provider-model",
|
|
ModelType: "video_generate",
|
|
ClientID: "universal:test",
|
|
}
|
|
}
|