65 lines
2.2 KiB
Go
65 lines
2.2 KiB
Go
package runner
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
scriptengine "github.com/easyai/easyai-ai-gateway/apps/api/internal/script"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
)
|
|
|
|
func TestPreprocessRequestWithCustomScript(t *testing.T) {
|
|
service := &Service{scriptExecutor: &scriptengine.Executor{}}
|
|
candidate := store.RuntimeModelCandidate{
|
|
Provider: "universal",
|
|
ModelName: "image-model",
|
|
ModelType: "image_generate",
|
|
Capabilities: map[string]any{
|
|
"image_generate": map[string]any{"max_output_images": 4},
|
|
},
|
|
PlatformConfig: map[string]any{
|
|
"customPreprocessScript": `(params, type, context) => {
|
|
return { prompt: params.prompt + "-" + type, n: 2, provider: context.candidate.provider };
|
|
}`,
|
|
},
|
|
}
|
|
|
|
result := service.preprocessRequestWithScripts(context.Background(), "images.generations", map[string]any{"prompt": "hello", "n": 8}, candidate)
|
|
if result.Err != nil {
|
|
t.Fatalf("unexpected preprocess error: %v", result.Err)
|
|
}
|
|
if result.Body["prompt"] != "hello-image_generate" || result.Body["n"].(float64) != 2 {
|
|
t.Fatalf("unexpected body: %#v", result.Body)
|
|
}
|
|
if !result.Log.Changed || len(result.Log.Changes) == 0 {
|
|
t.Fatalf("expected script change in log: %#v", result.Log)
|
|
}
|
|
}
|
|
|
|
func TestPreprocessRequestSkipParamNormalizationSkipsCustomScript(t *testing.T) {
|
|
service := &Service{scriptExecutor: &scriptengine.Executor{}}
|
|
candidate := store.RuntimeModelCandidate{
|
|
ModelName: "image-model",
|
|
ModelType: "image_generate",
|
|
Provider: "universal",
|
|
Capabilities: map[string]any{
|
|
"image_generate": map[string]any{"max_output_images": 1},
|
|
},
|
|
PlatformConfig: map[string]any{
|
|
"skipParamNormalization": true,
|
|
"customPreprocessScript": `(params) => ({ prompt: "changed", n: 1 })`,
|
|
},
|
|
}
|
|
|
|
result := service.preprocessRequestWithScripts(context.Background(), "images.generations", map[string]any{"prompt": "hello", "n": 9}, candidate)
|
|
if result.Err != nil {
|
|
t.Fatalf("unexpected preprocess error: %v", result.Err)
|
|
}
|
|
if result.Body["prompt"] != "hello" || result.Body["n"].(int) != 9 {
|
|
t.Fatalf("skip should keep raw body, got %#v", result.Body)
|
|
}
|
|
if result.Log.Changed || len(result.Log.Changes) != 0 {
|
|
t.Fatalf("skip should not record changes: %#v", result.Log)
|
|
}
|
|
}
|