easyai-ai-gateway/apps/api/internal/runner/param_processor_script_test.go

65 lines
2.2 KiB
Go

package runner
import (
"context"
"testing"
scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// TestPreprocessRequestWithCustomScript checks that a customPreprocessScript
// configured in the candidate's PlatformConfig is run for an
// "images.generations" request. The script is called with the request params,
// the candidate's model type and a context object. The object it returns is
// expected to become the processed body, and the change is expected to be
// recorded in the preprocess log.
func TestPreprocessRequestWithCustomScript(t *testing.T) {
	service := &Service{scriptExecutor: &scriptengine.Executor{}}
	candidate := store.RuntimeModelCandidate{
		Provider:  "universal",
		ModelName: "image-model",
		ModelType: "image_generate",
		Capabilities: map[string]any{
			"image_generate": map[string]any{"max_output_images": 4},
		},
		PlatformConfig: map[string]any{
			"customPreprocessScript": `(params, type, context) => {
return { prompt: params.prompt + "-" + type, n: 2, provider: context.candidate.provider };
}`,
		},
	}
	result := service.preprocessRequestWithScripts(context.Background(), "images.generations", map[string]any{"prompt": "hello", "n": 8}, candidate)
	if result.Err != nil {
		t.Fatalf("unexpected preprocess error: %v", result.Err)
	}
	// Use the comma-ok form so an unexpected numeric type fails the test with
	// the body dumped, instead of panicking inside the type assertion.
	// NOTE(review): float64 is the type the original test expected for numbers
	// coming back from the script — confirm against the script executor.
	n, ok := result.Body["n"].(float64)
	if result.Body["prompt"] != "hello-image_generate" || !ok || n != 2 {
		t.Fatalf("unexpected body: %#v", result.Body)
	}
	if !result.Log.Changed || len(result.Log.Changes) == 0 {
		t.Fatalf("expected script change in log: %#v", result.Log)
	}
}
// TestPreprocessRequestSkipParamNormalizationSkipsCustomScript checks that
// setting skipParamNormalization in PlatformConfig bypasses the configured
// customPreprocessScript. The request body is expected to come back exactly as
// sent, including "n": 9, which is above the capability's max_output_images
// of 1. No changes are expected in the preprocess log.
func TestPreprocessRequestSkipParamNormalizationSkipsCustomScript(t *testing.T) {
	service := &Service{scriptExecutor: &scriptengine.Executor{}}
	candidate := store.RuntimeModelCandidate{
		ModelName: "image-model",
		ModelType: "image_generate",
		Provider:  "universal",
		Capabilities: map[string]any{
			"image_generate": map[string]any{"max_output_images": 1},
		},
		PlatformConfig: map[string]any{
			"skipParamNormalization": true,
			"customPreprocessScript": `(params) => ({ prompt: "changed", n: 1 })`,
		},
	}
	result := service.preprocessRequestWithScripts(context.Background(), "images.generations", map[string]any{"prompt": "hello", "n": 9}, candidate)
	if result.Err != nil {
		t.Fatalf("unexpected preprocess error: %v", result.Err)
	}
	// The raw body should be passed through, so "n" keeps the Go int it was
	// built with. The comma-ok form turns a type mismatch into a reported
	// failure instead of a panic.
	n, ok := result.Body["n"].(int)
	if result.Body["prompt"] != "hello" || !ok || n != 9 {
		t.Fatalf("skip should keep raw body, got %#v", result.Body)
	}
	if result.Log.Changed || len(result.Log.Changes) != 0 {
		t.Fatalf("skip should not record changes: %#v", result.Log)
	}
}