easyai-ai-gateway/apps/api/internal/runner/param_processor_test.go

208 lines
7.1 KiB
Go

package runner
import (
"testing"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// TestParamProcessorOmniFiltersUnsupportedVideoAndAudioContent sends a request
// holding one text item, two video items and one audio item through
// preprocessRequestWithLog for an "omni_video" candidate whose capabilities
// declare max_videos=1 and input_audio=false. It asserts that only the text
// item and the "video_base" item remain, and that the preprocessing log keeps
// the input/output content and records a change for content[3] that cites
// capabilities.omni_video.input_audio.
func TestParamProcessorOmniFiltersUnsupportedVideoAndAudioContent(t *testing.T) {
	candidate := store.RuntimeModelCandidate{
		ModelType: "omni_video",
		Capabilities: map[string]any{
			"omni_video": map[string]any{
				"supported_modes": []any{"video_edit"},
				"max_videos":      1,
				"input_audio":     false,
				"max_audios":      0,
			},
		},
	}
	requestContent := []any{
		map[string]any{"type": "text", "text": "edit the source video"},
		map[string]any{
			"type":      "video_url",
			"role":      "video_base",
			"video_url": map[string]any{"url": "https://example.com/base.mp4", "refer_type": "base"},
		},
		map[string]any{
			"type":      "video_url",
			"role":      "reference_video",
			"video_url": map[string]any{"url": "https://example.com/ref.mp4", "refer_type": "feature"},
		},
		map[string]any{
			"type":      "audio_url",
			"role":      "reference_audio",
			"audio_url": map[string]any{"url": "https://example.com/ref.mp3"},
		},
	}
	body := map[string]any{
		"model":   "可灵O1",
		"prompt":  "edit the source video",
		"content": requestContent,
	}

	result := preprocessRequestWithLog("videos.generations", body, candidate)
	kept := contentItems(result.Body["content"])

	// The length guard must come first: the next check indexes kept[1].
	if len(kept) != 2 {
		t.Fatalf("expected text plus one video item, got %+v", kept)
	}
	second := kept[1]
	retainedOK := stringFromAny(second["role"]) == "video_base" && !isAudioContent(second)
	if !retainedOK {
		t.Fatalf("unexpected retained content: %+v", kept)
	}
	for i := range kept {
		item := kept[i]
		switch {
		case isAudioContent(item), stringFromAny(item["role"]) == "reference_video":
			t.Fatalf("unsupported content was not filtered: %+v", kept)
		}
	}

	if !(result.Log.Changed && len(result.Log.Changes) >= 2) {
		t.Fatalf("expected preprocessing log with filtered video and audio changes, got %+v", result.Log)
	}
	loggedInput, loggedOutput := result.Log.Input["content"], result.Log.Output["content"]
	if loggedInput == nil || loggedOutput == nil {
		t.Fatalf("preprocessing log should keep actual input and converted output: %+v", result.Log)
	}

	// The audio item sat at index 3 of the submitted content, so that is the
	// path the matching log entry is expected to carry.
	for _, entry := range result.Log.Changes {
		if entry.Path != "content[3]" {
			continue
		}
		if entry.CapabilityPath == "capabilities.omni_video.input_audio" {
			return
		}
	}
	t.Fatalf("expected audio filtering reason to reference omni_video.input_audio, got %+v", result.Log.Changes)
}
// TestParamProcessorOmniFiltersConvenienceReferenceFields submits a request
// that uses the top-level "reference_video" and "reference_audio" fields to an
// "omni_video" candidate with max_videos=0 and input_audio=false. It asserts
// that the processed content holds a single text item, that both top-level
// fields are nil in the processed body, and that at least one change is logged.
func TestParamProcessorOmniFiltersConvenienceReferenceFields(t *testing.T) {
	candidate := store.RuntimeModelCandidate{
		ModelType: "omni_video",
		Capabilities: map[string]any{
			"omni_video": map[string]any{
				"supported_modes": []any{"text_to_video"},
				"max_videos":      0,
				"input_audio":     false,
				"max_audios":      0,
			},
		},
	}
	body := map[string]any{
		"model":           "可灵V3多模态",
		"prompt":          "text only",
		"reference_video": "https://example.com/ref.mp4",
		"reference_audio": "https://example.com/ref.mp3",
	}

	result := preprocessRequestWithLog("videos.generations", body, candidate)
	processed := result.Body
	items := contentItems(processed["content"])

	// Check the length before touching items[0].
	if len(items) != 1 {
		t.Fatalf("expected only text content, got %+v", items)
	}
	if stringFromAny(items[0]["type"]) != "text" {
		t.Fatalf("expected only text content, got %+v", items)
	}

	rejectedKeys := []string{"reference_video", "reference_audio"}
	for i := 0; i < len(rejectedKeys); i++ {
		name := rejectedKeys[i]
		if processed[name] == nil {
			continue
		}
		t.Fatalf("%s should be removed when capability rejects it: %+v", name, processed)
	}

	if len(result.Log.Changes) == 0 {
		t.Fatalf("expected convenience-field filtering to be logged")
	}
}
// TestParamProcessorOmniCapabilityLogUsesActualCapabilityKey runs a text plus
// audio request against a candidate whose model type and capability key are
// both "omni" with input_audio=false. It asserts that the log contains a change
// for content[1] whose capability path is capabilities.omni.input_audio.
func TestParamProcessorOmniCapabilityLogUsesActualCapabilityKey(t *testing.T) {
	candidate := store.RuntimeModelCandidate{
		ModelType: "omni",
		Capabilities: map[string]any{
			"omni": map[string]any{
				"input_audio": false,
				"max_audios":  0,
			},
		},
	}
	body := map[string]any{
		"model": "Omni",
		"content": []any{
			map[string]any{"type": "text", "text": "animate"},
			map[string]any{
				"type":      "audio_url",
				"role":      "reference_audio",
				"audio_url": map[string]any{"url": "https://example.com/ref.mp3"},
			},
		},
	}

	result := preprocessRequestWithLog("videos.generations", body, candidate)

	matched := false
	for _, entry := range result.Log.Changes {
		if entry.Path == "content[1]" && entry.CapabilityPath == "capabilities.omni.input_audio" {
			matched = true
			break
		}
	}
	if !matched {
		t.Fatalf("expected log to reference capabilities.omni.input_audio, got %+v", result.Log.Changes)
	}
}
// TestParamProcessorVideoCapabilitiesNormalizeAndFilter submits a video request
// with duration 13, aspect ratio "4:3", audio flags, and first/last-frame plus
// audio content to an "image_to_video" candidate. It asserts that duration
// becomes 12, aspect_ratio becomes "16:9", the "audio" and "output_audio" flags
// are nil, no last_frame or audio content item remains, and the log has a
// "duration" change citing capabilities.image_to_video.duration_options.
func TestParamProcessorVideoCapabilitiesNormalizeAndFilter(t *testing.T) {
	candidate := store.RuntimeModelCandidate{
		ModelType: "image_to_video",
		Capabilities: map[string]any{
			"image_to_video": map[string]any{
				"aspect_ratio_allowed":      []any{"16:9", "1:1"},
				"duration_options":          []any{4, 8, 12},
				"input_first_last_frame":    false,
				"input_audio":               false,
				"output_audio":              false,
				"max_images_for_last_frame": 0,
			},
		},
	}
	requestContent := []any{
		map[string]any{"type": "text", "text": "animate it"},
		map[string]any{
			"type":      "image_url",
			"role":      "first_frame",
			"image_url": map[string]any{"url": "https://example.com/first.png"},
		},
		map[string]any{
			"type":      "image_url",
			"role":      "last_frame",
			"image_url": map[string]any{"url": "https://example.com/last.png"},
		},
		map[string]any{
			"type":      "audio_url",
			"role":      "reference_audio",
			"audio_url": map[string]any{"url": "https://example.com/ref.mp3"},
		},
	}
	body := map[string]any{
		"model":        "Seedance",
		"duration":     13,
		"aspect_ratio": "4:3",
		"resolution":   "1080p",
		"audio":        true,
		"output_audio": true,
		"content":      requestContent,
	}

	result := preprocessRequestWithLog("videos.generations", body, candidate)
	processed := result.Body

	// The snapped duration is accepted as either a float64 or an int.
	duration := processed["duration"]
	snapped := duration == float64(12) || duration == 12
	if !snapped {
		t.Fatalf("duration should be snapped to 12, got %+v", duration)
	}
	if ratio := processed["aspect_ratio"]; ratio != "16:9" {
		t.Fatalf("aspect_ratio should fall back to first allowed value, got %+v", ratio)
	}
	if !(processed["audio"] == nil && processed["output_audio"] == nil) {
		t.Fatalf("output audio flags should be removed: %+v", processed)
	}

	remaining := contentItems(processed["content"])
	for i := range remaining {
		item := remaining[i]
		switch {
		case stringFromAny(item["role"]) == "last_frame", isAudioContent(item):
			t.Fatalf("unsupported content remained: %+v", processed["content"])
		}
	}

	for _, entry := range result.Log.Changes {
		if entry.Path != "duration" {
			continue
		}
		if entry.CapabilityPath == "capabilities.image_to_video.duration_options" {
			return
		}
	}
	t.Fatalf("expected duration adjustment to reference duration_options, got %+v", result.Log.Changes)
}
// TestParamProcessorImageResolutionAndOutputCount passes an image request with
// size "2K" and n=8 through preprocessRequest for an "image_generate" candidate
// whose output_max_images_count is 4. It asserts that the processed body
// carries resolution "2K" and that n equals 4.
func TestParamProcessorImageResolutionAndOutputCount(t *testing.T) {
	candidate := store.RuntimeModelCandidate{
		ModelType: "image_generate",
		Capabilities: map[string]any{
			"image_generate": map[string]any{
				"output_multiple_images":  true,
				"output_max_images_count": 4,
			},
		},
	}
	body := map[string]any{
		"model":  "即梦V4.0",
		"prompt": "draw",
		"size":   "2K",
		"n":      8,
	}

	processed := preprocessRequest("images.generations", body, candidate)

	if resolution := processed["resolution"]; resolution != "2K" {
		t.Fatalf("size resolution should be copied to resolution, got %+v", processed)
	}
	if count := processed["n"]; count != 4 {
		t.Fatalf("image count should be capped to 4, got %+v", count)
	}
}