From b9c9f457e9fac5403e7c44bae084385bcf3318da Mon Sep 17 00:00:00 2001 From: wangbo Date: Tue, 12 May 2026 13:54:51 +0800 Subject: [PATCH] feat: add parameter preprocessing audit trail --- .../httpapi/core_flow_integration_test.go | 29 +- apps/api/internal/httpapi/handlers.go | 20 + apps/api/internal/httpapi/server.go | 2 + apps/api/internal/runner/param_processor.go | 1450 +++++++++++++++++ .../internal/runner/param_processor_test.go | 207 +++ apps/api/internal/runner/pricing.go | 2 + apps/api/internal/runner/recording.go | 3 + apps/api/internal/runner/service.go | 136 +- apps/api/internal/store/postgres.go | 18 + apps/api/internal/store/runtime_types.go | 16 + apps/api/internal/store/tasks_runtime.go | 98 ++ .../0033_task_param_preprocessing_logs.sql | 30 + apps/web/src/App.tsx | 1 + apps/web/src/api.ts | 8 + apps/web/src/pages/WorkspacePage.tsx | 197 ++- apps/web/src/styles.css | 116 +- packages/contracts/src/index.ts | 18 + 17 files changed, 2323 insertions(+), 28 deletions(-) create mode 100644 apps/api/internal/runner/param_processor.go create mode 100644 apps/api/internal/runner/param_processor_test.go create mode 100644 apps/api/migrations/0033_task_param_preprocessing_logs.sql diff --git a/apps/api/internal/httpapi/core_flow_integration_test.go b/apps/api/internal/httpapi/core_flow_integration_test.go index 41bf15f..e5198f4 100644 --- a/apps/api/internal/httpapi/core_flow_integration_test.go +++ b/apps/api/internal/httpapi/core_flow_integration_test.go @@ -836,8 +836,10 @@ WHERE reference_type = 'gateway_task' } var imageToVideoTask struct { Task struct { - Status string `json:"status"` - ModelType string `json:"modelType"` + ID string `json:"id"` + Status string `json:"status"` + ModelType string `json:"modelType"` + Metrics map[string]any `json:"metrics"` } `json:"task"` } doJSON(t, server.URL, http.MethodPost, "/api/v1/videos/generations", apiKeyResponse.Secret, map[string]any{ @@ -851,6 +853,29 @@ WHERE reference_type = 'gateway_task' if 
imageToVideoTask.Task.Status != "succeeded" || imageToVideoTask.Task.ModelType != "image_to_video" { t.Fatalf("image-to-video request should use image_to_video model_type: %+v", imageToVideoTask.Task) } + if _, ok := imageToVideoTask.Task.Metrics["parameterPreprocessing"]; ok { + t.Fatalf("task metrics should not embed full parameter preprocessing log: %+v", imageToVideoTask.Task.Metrics) + } + if imageToVideoTask.Task.Metrics["parameterPreprocessingSummary"] == nil { + t.Fatalf("task metrics should keep lightweight preprocessing summary: %+v", imageToVideoTask.Task.Metrics) + } + var preprocessingDetail struct { + Items []map[string]any `json:"items"` + } + doJSON(t, server.URL, http.MethodGet, "/api/v1/tasks/"+imageToVideoTask.Task.ID+"/param-preprocessing", apiKeyResponse.Secret, nil, http.StatusOK, &preprocessingDetail) + if len(preprocessingDetail.Items) == 0 { + t.Fatalf("task preprocessing endpoint should expose persisted preprocessing logs: %+v", preprocessingDetail) + } + if preprocessingDetail.Items[0]["actualInput"] == nil || preprocessingDetail.Items[0]["convertedOutput"] == nil { + t.Fatalf("preprocessing log should store actual input and converted output: %+v", preprocessingDetail.Items) + } + var preprocessingRows int + if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_task_param_preprocessing_logs WHERE task_id = $1::uuid`, imageToVideoTask.Task.ID).Scan(&preprocessingRows); err != nil { + t.Fatalf("count preprocessing logs: %v", err) + } + if preprocessingRows == 0 { + t.Fatalf("expected preprocessing logs in dedicated table for task %s", imageToVideoTask.Task.ID) + } failoverModel := "phase1-failover-" + suffixText var failedPlatform struct { diff --git a/apps/api/internal/httpapi/handlers.go b/apps/api/internal/httpapi/handlers.go index 7cf2f00..5b23c71 100644 --- a/apps/api/internal/httpapi/handlers.go +++ b/apps/api/internal/httpapi/handlers.go @@ -795,6 +795,26 @@ func (s *Server) getTask(w http.ResponseWriter, r *http.Request) { 
writeError(w, http.StatusInternalServerError, "get task failed") } +func (s *Server) taskParamPreprocessing(w http.ResponseWriter, r *http.Request) { + task, err := s.store.GetTask(r.Context(), r.PathValue("taskID")) + if err != nil { + if store.IsNotFound(err) { + writeError(w, http.StatusNotFound, "task not found") + return + } + s.logger.Error("get task failed", "error", err) + writeError(w, http.StatusInternalServerError, "get task failed") + return + } + logs, err := s.store.ListTaskParamPreprocessingLogs(r.Context(), task.ID) + if err != nil { + s.logger.Error("list task parameter preprocessing logs failed", "taskID", task.ID, "error", err) + writeError(w, http.StatusInternalServerError, "list task parameter preprocessing logs failed") + return + } + writeJSON(w, http.StatusOK, map[string]any{"items": logs}) +} + func (s *Server) taskEvents(w http.ResponseWriter, r *http.Request) { task, err := s.store.GetTask(r.Context(), r.PathValue("taskID")) if err != nil { diff --git a/apps/api/internal/httpapi/server.go b/apps/api/internal/httpapi/server.go index 8eb5a95..ae89c66 100644 --- a/apps/api/internal/httpapi/server.go +++ b/apps/api/internal/httpapi/server.go @@ -87,6 +87,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor mux.Handle("GET /api/workspace/wallet/transactions", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listWalletTransactions))) mux.Handle("GET /api/workspace/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks))) mux.Handle("GET /api/workspace/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask))) + mux.Handle("GET /api/workspace/tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing))) mux.Handle("GET /api/workspace/tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents))) mux.Handle("GET 
// paramProcessContext carries the per-request state that is handed to every
// processor in the chain.
type paramProcessContext struct {
	// modelCapability is the capability map processors validate against. It is
	// filled from effectiveModelCapability(candidate) when the context is built.
	modelCapability map[string]any
	// candidate is the runtime model candidate the request is being prepared for.
	candidate store.RuntimeModelCandidate
	// log collects the audit entries appended by recordChange; recordChange is a
	// no-op when it is nil.
	log *parameterPreprocessingLog
	// aspectRatio mirrors the aspect_ratio value the aspect ratio processor last
	// settled on ("" after it removed the parameter).
	aspectRatio string
	// resolution is set by the resolution normalizer when it derives a
	// resolution from size; later processors use it as a fallback when
	// params["resolution"] is empty.
	resolution string
}
string, context *paramProcessContext) bool +} + +type ParamProcessorChain struct { + processors []paramProcessor +} + +type parameterPreprocessResult struct { + Body map[string]any + Log parameterPreprocessingLog +} + +type parameterPreprocessingLog struct { + ModelType string `json:"modelType"` + Input map[string]any `json:"actualInput"` + Output map[string]any `json:"convertedOutput"` + Changed bool `json:"changed"` + Changes []parameterPreprocessChange `json:"changes"` + Model map[string]any `json:"model,omitempty"` +} + +type parameterPreprocessChange struct { + Processor string `json:"processor"` + Action string `json:"action"` + Path string `json:"path"` + Before any `json:"before"` + After any `json:"after"` + Reason string `json:"reason"` + CapabilityPath string `json:"capabilityPath,omitempty"` + CapabilityValue any `json:"capabilityValue,omitempty"` +} + +func NewParamProcessorChain() ParamProcessorChain { + return ParamProcessorChain{ + processors: []paramProcessor{ + resolutionNormalizeProcessor{}, + aspectRatioProcessor{}, + contentFilterProcessor{}, + inputAudioProcessor{}, + durationProcessor{}, + audioProcessor{}, + imageCountProcessor{}, + }, + } +} + +func preprocessRequest(kind string, body map[string]any, candidate store.RuntimeModelCandidate) map[string]any { + return preprocessRequestWithLog(kind, body, candidate).Body +} + +func preprocessRequestWithLog(kind string, body map[string]any, candidate store.RuntimeModelCandidate) parameterPreprocessResult { + params := cloneMap(body) + modelType := strings.TrimSpace(candidate.ModelType) + if modelType == "" { + modelType = modelTypeFromKind(kind, params) + } + log := parameterPreprocessingLog{ + ModelType: modelType, + Input: cloneMap(params), + Changes: []parameterPreprocessChange{}, + Model: map[string]any{ + "modelName": candidate.ModelName, + "modelAlias": candidate.ModelAlias, + "providerModelName": candidate.ProviderModelName, + "provider": candidate.Provider, + "platformId": 
candidate.PlatformID, + "platformModelId": candidate.PlatformModelID, + }, + } + context := ¶mProcessContext{ + modelCapability: effectiveModelCapability(candidate), + candidate: candidate, + log: &log, + } + if kind == "videos.generations" { + ensureVideoContent(params, context) + } + chain := NewParamProcessorChain() + processed := chain.Process(params, modelType, context) + log.Output = cloneMap(processed) + log.Changed = len(log.Changes) > 0 + return parameterPreprocessResult{Body: processed, Log: log} +} + +func (chain ParamProcessorChain) Process(params map[string]any, modelType string, context *paramProcessContext) map[string]any { + if params == nil { + return map[string]any{} + } + for _, processor := range chain.processors { + if !processor.ShouldProcess(params, modelType, context) { + continue + } + if !processor.Process(params, modelType, context) { + break + } + } + return params +} + +func (context *paramProcessContext) recordChange(processor string, action string, path string, before any, after any, reason string, capabilityPath string, capabilityValue any) { + if context == nil || context.log == nil { + return + } + context.log.Changes = append(context.log.Changes, parameterPreprocessChange{ + Processor: processor, + Action: action, + Path: path, + Before: cloneAny(before), + After: cloneAny(after), + Reason: reason, + CapabilityPath: capabilityPath, + CapabilityValue: cloneAny(capabilityValue), + }) +} + +func parameterPreprocessingMetrics(log parameterPreprocessingLog) map[string]any { + return map[string]any{ + "parameterPreprocessingSummary": parameterPreprocessingSummary(log), + } +} + +func parameterPreprocessingSummary(log parameterPreprocessingLog) map[string]any { + summary := map[string]any{ + "modelType": log.ModelType, + "changed": log.Changed, + "changeCount": len(log.Changes), + } + if len(log.Changes) == 0 { + return summary + } + actions := make([]string, 0) + paths := make([]string, 0) + capabilityPaths := make([]string, 0) + for _, 
change := range log.Changes { + appendUniqueString(&actions, change.Action) + appendUniqueString(&paths, change.Path) + appendUniqueString(&capabilityPaths, change.CapabilityPath) + } + summary["actions"] = actions + summary["paths"] = paths + if len(capabilityPaths) > 0 { + summary["capabilityPaths"] = capabilityPaths + } + return summary +} + +type resolutionNormalizeProcessor struct{} + +func (resolutionNormalizeProcessor) Name() string { return "ResolutionNormalizeProcessor" } + +func (resolutionNormalizeProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { + if stringFromAny(params["resolution"]) != "" { + return false + } + size := stringFromAny(params["size"]) + if size == "" { + return false + } + return isImageResolution(modelType, size) || isVideoResolution(modelType, size) +} + +func (resolutionNormalizeProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { + size := stringFromAny(params["size"]) + if stringFromAny(params["resolution"]) == "" && (isImageResolution(modelType, size) || isVideoResolution(modelType, size)) { + _, capabilityValue := capabilityEvidence(context.modelCapability, modelType, "output_resolutions") + params["resolution"] = size + context.resolution = size + context.recordChange( + "ResolutionNormalizeProcessor", + "set", + "resolution", + nil, + size, + "size 使用分辨率格式,归一到 resolution 供后续能力校验和计费使用。", + capabilityPath(modelType, "output_resolutions"), + capabilityValue, + ) + } + return true +} + +type aspectRatioProcessor struct{} + +func (aspectRatioProcessor) Name() string { return "AspectRatioProcessor" } + +func (aspectRatioProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { + return modelType != "text_generate" && (stringFromAny(params["aspect_ratio"]) != "" || stringFromAny(params["size"]) != "") +} + +func (aspectRatioProcessor) Process(params map[string]any, modelType string, context 
*paramProcessContext) bool { + capability := capabilityForType(context.modelCapability, modelType) + if capability == nil { + return true + } + + aspectRatio := stringFromAny(params["aspect_ratio"]) + if isEmptyParamString(aspectRatio) { + before := params["aspect_ratio"] + delete(params, "aspect_ratio") + context.aspectRatio = "" + context.recordChange( + "AspectRatioProcessor", + "remove", + "aspect_ratio", + before, + nil, + "aspect_ratio 是空值字符串,不能作为有效比例传给上游。", + "", + nil, + ) + return true + } + + resolution := firstNonEmptyString(stringFromAny(params["resolution"]), context.resolution) + if resolution == "" { + if values := stringListFromAny(capability["output_resolutions"]); len(values) > 0 { + resolution = values[0] + } else if size := stringFromAny(params["size"]); strings.HasSuffix(size, "K") || strings.HasSuffix(size, "p") { + resolution = size + } + } + + allowed := aspectRatioAllowed(capability["aspect_ratio_allowed"], resolution) + if allowed != nil && len(allowed) == 1 && allowed[0] == "adaptive" { + before := params["aspect_ratio"] + params["aspect_ratio"] = "adaptive" + context.aspectRatio = "adaptive" + if before != "adaptive" { + context.recordChange( + "AspectRatioProcessor", + "adjust", + "aspect_ratio", + before, + "adaptive", + "模型当前分辨率只允许 adaptive 宽高比。", + capabilityPath(modelType, "aspect_ratio_allowed"), + capability["aspect_ratio_allowed"], + ) + } + return true + } + if allowed != nil && len(allowed) == 0 { + before := params["aspect_ratio"] + delete(params, "aspect_ratio") + context.aspectRatio = "" + context.recordChange( + "AspectRatioProcessor", + "remove", + "aspect_ratio", + before, + nil, + "模型能力配置不允许传入任何 aspect_ratio。", + capabilityPath(modelType, "aspect_ratio_allowed"), + capability["aspect_ratio_allowed"], + ) + return true + } + if aspectRatio == "" { + return true + } + if allowed == nil && validAspectRatio(aspectRatio) { + params["aspect_ratio"] = aspectRatio + context.aspectRatio = aspectRatio + return true + } + + 
// Process filters the request's content array down to the items the candidate
// model can accept and writes the filtered array back to params["content"].
// Every removed or adjusted item is recorded on the context's audit log.
// It always returns true so the chain continues with the next processor.
func (contentFilterProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool {
	content := contentItems(params["content"])
	if len(content) == 0 {
		return true
	}

	// Omni-style models are handled by a dedicated filter and skip every rule
	// below.
	if isOmniVideoLike(context) {
		filtered := filterUnsupportedOmniVideoContent(content, context)
		params["content"] = mapsToAnySlice(filtered)
		syncVideoConvenienceFields(params, filtered, context)
		return true
	}

	// May rewrite the role of items in content in place before the filters run.
	downgradeReferenceImageIfNeeded(params, content, modelType, context)
	// NOTE(review): for "video_generate" the call above may first downgrade
	// reference images to first_frame and this branch then removes them, so
	// both changes can end up in the log — confirm that is intended.
	if modelType == "video_generate" || modelType == "text_to_video" {
		next := make([]map[string]any, 0, len(content))
		for index, item := range content {
			if isImageContent(item) {
				// index is the item's position in the content slice as it is
				// at this point in the chain.
				context.recordChange(
					"ContentFilterProcessor",
					"remove",
					fmt.Sprintf("content[%d]", index),
					item,
					nil,
					"当前候选模型没有图像参考输入模式,需移除 image_url。",
					capabilityPath(modelType, ""),
					capabilityForType(context.modelCapability, modelType),
				)
				continue
			}
			next = append(next, item)
		}
		content = next
	}
	// NOTE(review): "omni_video" and "omni" presumably only reach this branch
	// when isOmniVideoLike returned false above, i.e. the model type was not
	// taken from candidate.ModelType — verify against modelTypeFromKind.
	if modelType == "image_to_video" || modelType == "omni_video" || modelType == "omni" {
		if !supportsFirstAndLastFrame(context.modelCapability, modelType) {
			next := make([]map[string]any, 0, len(content))
			for index, item := range content {
				if stringFromAny(item["role"]) == "last_frame" {
					context.recordChange(
						"ContentFilterProcessor",
						"remove",
						fmt.Sprintf("content[%d]", index),
						item,
						nil,
						"模型不支持首尾帧输入,已移除 last_frame。",
						capabilityPath(modelType, "input_first_last_frame"),
						map[string]any{
							"input_first_last_frame":    capabilityValue(context.modelCapability, modelType, "input_first_last_frame"),
							"max_images_for_last_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_last_frame"),
						},
					)
					continue
				}
				next = append(next, item)
			}
			content = next
			// Also drop the top-level shortcut fields so they do not carry the
			// last frame on their own.
			deleteFieldsWithLog(params, context, "ContentFilterProcessor", []string{"last_frame", "lastFrame"}, "模型不支持首尾帧输入,已移除快捷字段。", capabilityPath(modelType, "input_first_last_frame"), map[string]any{
				"input_first_last_frame":    capabilityValue(context.modelCapability, modelType, "input_first_last_frame"),
				"max_images_for_last_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_last_frame"),
			})
		}
	}
	params["content"] = mapsToAnySlice(content)
	return true
}
*paramProcessContext) bool { + if !isVideoModelType(modelType) { + return false + } + content := contentItems(params["content"]) + for _, item := range content { + if isAudioContent(item) { + return true + } + } + return false +} + +func (inputAudioProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { + content := contentItems(params["content"]) + if len(content) == 0 { + return true + } + supportsInputAudio := false + if len(context.modelCapability) > 0 { + if isOmniVideoLike(context) { + supportsInputAudio = supportsOmniAudioReference(context) + } else if capability := capabilityForType(context.modelCapability, modelType); capability != nil { + supportsInputAudio = boolFromAny(capability["input_audio"]) + } + } + if supportsInputAudio { + return true + } + next := make([]map[string]any, 0, len(content)) + for index, item := range content { + if isAudioContent(item) { + path, value := audioInputCapabilityEvidence(context, modelType) + context.recordChange( + "InputAudioProcessor", + "remove", + fmt.Sprintf("content[%d]", index), + item, + nil, + "模型能力未开启输入音频,已移除 audio_url。", + path, + value, + ) + continue + } + next = append(next, item) + } + params["content"] = mapsToAnySlice(next) + path, value := audioInputCapabilityEvidence(context, modelType) + deleteFieldsWithLog(params, context, "InputAudioProcessor", []string{"audio_url", "audioUrl", "reference_audio", "referenceAudio"}, "模型能力未开启输入音频,已移除音频参考快捷字段。", path, value) + return true +} + +type durationProcessor struct{} + +func (durationProcessor) Name() string { return "DurationProcessor" } + +func (durationProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { + return isVideoModelType(modelType) && params["duration"] != nil +} + +func (durationProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { + capability := capabilityForType(context.modelCapability, modelType) + if 
capability == nil { + return true + } + duration := floatFromAny(params["duration"]) + if duration <= 0 { + return true + } + resolution := firstNonEmptyString(stringFromAny(params["resolution"]), context.resolution) + modeKey := videoModeKey(params) + if options := scopedNumberList(capability["duration_options"], resolution, modeKey); len(options) > 0 { + normalized := closestNumber(duration, options) + params["duration"] = normalized + syncDurationSeconds(params) + if normalized != duration { + context.recordChange( + "DurationProcessor", + "adjust", + "duration", + duration, + normalized, + "duration 不在模型固定时长选项内,已调整为最近的允许值。", + capabilityPath(modelType, "duration_options"), + capability["duration_options"], + ) + } + return true + } + if minValue, maxValue, ok := scopedRange(capability["duration_range"], resolution, modeKey); ok { + step := durationStep(capability["duration_step"], resolution, modeKey) + normalized := normalizeDurationByRange(duration, minValue, maxValue, step) + params["duration"] = normalized + syncDurationSeconds(params) + if normalized != duration { + context.recordChange( + "DurationProcessor", + "adjust", + "duration", + duration, + normalized, + "duration 超出模型时长范围或步进配置,已按能力配置归一。", + capabilityPath(modelType, "duration_range"), + map[string]any{ + "duration_range": capability["duration_range"], + "duration_step": capability["duration_step"], + }, + ) + } + } + return true +} + +type audioProcessor struct{} + +func (audioProcessor) Name() string { return "AudioProcessor" } + +func (audioProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { + return isVideoModelType(modelType) && (params["audio"] != nil || params["output_audio"] != nil) +} + +func (audioProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { + capability := capabilityForType(context.modelCapability, modelType) + if capability == nil || !boolFromAny(capability["output_audio"]) { + for _, 
key := range []string{"audio", "output_audio"} { + if before, ok := params[key]; ok { + delete(params, key) + context.recordChange( + "AudioProcessor", + "remove", + key, + before, + nil, + "模型能力未开启输出音频,已移除音频输出参数。", + capabilityPath(modelType, "output_audio"), + capabilityValue(context.modelCapability, modelType, "output_audio"), + ) + } + } + } + return true +} + +type imageCountProcessor struct{} + +func (imageCountProcessor) Name() string { return "ImageCountProcessor" } + +func (imageCountProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { + return modelType == "image_generate" || modelType == "image_edit" +} + +func (imageCountProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { + capability := capabilityForType(context.modelCapability, modelType) + if capability == nil || !boolFromAny(capability["output_multiple_images"]) { + return true + } + maxCount := int(math.Round(floatFromAny(capability["output_max_images_count"]))) + if maxCount <= 0 { + return true + } + count := int(math.Round(floatFromAny(params["n"]))) + if count <= 0 { + count = int(math.Round(floatFromAny(params["batch_size"]))) + } + if count <= 0 { + count = 1 + } + if count > maxCount { + before := count + count = maxCount + context.recordChange( + "ImageCountProcessor", + "adjust", + "n", + before, + count, + "请求图片数量超过模型输出上限,已按 output_max_images_count 截断。", + capabilityPath(modelType, "output_max_images_count"), + capability["output_max_images_count"], + ) + } + params["n"] = count + return true +} + +func ensureVideoContent(params map[string]any, context *paramProcessContext) { + if len(contentItems(params["content"])) > 0 { + return + } + content := make([]map[string]any, 0) + if prompt := firstNonEmptyString(stringFromAny(params["prompt"]), stringFromAny(params["input"])); prompt != "" { + content = append(content, map[string]any{"type": "text", "text": prompt}) + } + appendURL := func(kind 
string, role string, url string) { + url = strings.TrimSpace(url) + if url == "" { + return + } + item := map[string]any{"type": kind, "role": role} + switch kind { + case "image_url": + item["image_url"] = map[string]any{"url": url} + case "video_url": + item["video_url"] = map[string]any{"url": url} + case "audio_url": + item["audio_url"] = map[string]any{"url": url} + } + content = append(content, item) + } + + firstFrame := firstNonEmptyStringValue(params, "first_frame", "firstFrame") + appendURL("image_url", "first_frame", firstFrame) + appendURL("image_url", "last_frame", firstNonEmptyStringValue(params, "last_frame", "lastFrame")) + imageURLs := firstNonEmptyStringListFromAny(params["image"], params["images"], params["image_url"], params["imageUrl"], params["image_urls"], params["imageUrls"]) + if firstFrame == "" && len(imageURLs) > 0 { + appendURL("image_url", "first_frame", imageURLs[0]) + imageURLs = imageURLs[1:] + } + for _, url := range imageURLs { + appendURL("image_url", "reference_image", url) + } + for _, url := range firstNonEmptyStringListFromAny(params["reference_image"], params["referenceImage"]) { + appendURL("image_url", "reference_image", url) + } + for _, url := range firstNonEmptyStringListFromAny(params["video"], params["video_url"], params["videoUrl"], params["reference_video"], params["referenceVideo"]) { + appendURL("video_url", "reference_video", url) + } + for _, url := range firstNonEmptyStringListFromAny(params["audio_url"], params["audioUrl"], params["reference_audio"], params["referenceAudio"]) { + appendURL("audio_url", "reference_audio", url) + } + if len(content) > 0 { + params["content"] = mapsToAnySlice(content) + context.recordChange( + "ContentBuildProcessor", + "set", + "content", + nil, + params["content"], + "将 prompt/first_frame/reference_* 等快捷字段转换为 content 数组,后续处理器可按模型能力逐项过滤。", + "", + nil, + ) + } +} + +func effectiveModelCapability(candidate store.RuntimeModelCandidate) map[string]any { + base := 
cloneMap(candidate.Capabilities) + for key, value := range candidate.CapabilityOverride { + if baseChild, ok := base[key].(map[string]any); ok { + if overrideChild, ok := value.(map[string]any); ok { + base[key] = mergeMap(baseChild, overrideChild) + continue + } + } + base[key] = cloneAny(value) + } + return base +} + +func filterUnsupportedOmniVideoContent(content []map[string]any, context *paramProcessContext) []map[string]any { + capability := omniVideoCapability(context) + maxVideos := math.Inf(1) + if capability != nil { + if value, ok := numericField(capability, "max_videos"); ok { + maxVideos = value + } + } + maxAudios := 0.0 + if capability != nil { + if value, ok := numericField(capability, "max_audios"); ok { + maxAudios = value + } else if supportsOmniAudioReference(context) { + maxAudios = math.Inf(1) + } + } + + videoCount := 0.0 + audioCount := 0.0 + out := make([]map[string]any, 0, len(content)) + for index, item := range content { + if isVideoContent(item) { + if !supportsOmniVideoReference(item, capability) { + path, value := omniCapabilityEvidence(context, "supported_modes") + context.recordChange( + "ContentFilterProcessor", + "remove", + fmt.Sprintf("content[%d]", index), + item, + nil, + "视频参考类型不在 omni_video.supported_modes 允许范围内。", + path, + value, + ) + continue + } + if videoCount >= maxVideos { + path, value := omniCapabilityEvidence(context, "max_videos") + context.recordChange( + "ContentFilterProcessor", + "remove", + fmt.Sprintf("content[%d]", index), + item, + nil, + "视频参考数量超过 omni_video.max_videos 限制。", + path, + value, + ) + continue + } + videoCount++ + out = append(out, item) + continue + } + if isAudioContent(item) { + if !supportsOmniAudioReference(context) { + path, value := omniCapabilityEvidence(context, "input_audio") + context.recordChange( + "ContentFilterProcessor", + "remove", + fmt.Sprintf("content[%d]", index), + item, + nil, + "模型能力不支持音频参考,已移除 audio_url。", + path, + mergeMetrics(map[string]any{"input_audio": value}, 
omniCapabilityBundle(context, "max_audios")), + ) + continue + } + if audioCount >= maxAudios { + path, value := omniCapabilityEvidence(context, "max_audios") + context.recordChange( + "ContentFilterProcessor", + "remove", + fmt.Sprintf("content[%d]", index), + item, + nil, + "音频参考数量超过 omni_video.max_audios 限制。", + path, + value, + ) + continue + } + audioCount++ + out = append(out, item) + continue + } + out = append(out, item) + } + return out +} + +func isOmniVideoLike(context *paramProcessContext) bool { + modelType := strings.TrimSpace(context.candidate.ModelType) + return modelType == "omni_video" || + modelType == "omni" || + context.modelCapability["omni_video"] != nil || + context.modelCapability["omni"] != nil +} + +func omniVideoCapability(context *paramProcessContext) map[string]any { + if capability := capabilityForType(context.modelCapability, "omni_video"); capability != nil { + return capability + } + return capabilityForType(context.modelCapability, "omni") +} + +func supportsOmniAudioReference(context *paramProcessContext) bool { + capability := omniVideoCapability(context) + return capability != nil && (boolFromAny(capability["input_audio"]) || floatFromAny(capability["max_audios"]) > 0) +} + +func supportsOmniVideoReference(item map[string]any, capability map[string]any) bool { + if capability == nil { + return true + } + if value, ok := numericField(capability, "max_videos"); ok && value == 0 { + return false + } + supportedModes := stringListFromAny(capability["supported_modes"]) + supportsReference := containsString(supportedModes, "video_reference") + supportsEdit := containsString(supportedModes, "video_edit") + video, _ := item["video_url"].(map[string]any) + referType := stringFromAny(video["refer_type"]) + isEditVideo := stringFromAny(item["role"]) == "video_base" || referType == "base" + isReferenceVideo := stringFromAny(item["role"]) == "video_feature" || + stringFromAny(item["role"]) == "reference_video" || + referType == "feature" + 
if isEditVideo { + return supportsEdit + } + if isReferenceVideo { + return supportsReference + } + return supportsReference || supportsEdit +} + +func downgradeReferenceImageIfNeeded(params map[string]any, content []map[string]any, modelType string, context *paramProcessContext) { + if modelType != "image_to_video" && modelType != "video_generate" && modelType != "video_edit" && modelType != "omni_video" && modelType != "omni" { + return + } + if supportsReferenceImage(context.modelCapability, modelType) { + return + } + count := 0 + for index, item := range content { + if stringFromAny(item["type"]) == "image_url" && stringFromAny(item["role"]) == "reference_image" { + before := cloneMap(item) + item["role"] = "first_frame" + context.recordChange( + "ContentFilterProcessor", + "adjust", + fmt.Sprintf("content[%d].role", index), + before, + item, + "模型不支持 reference_image,已降级为 first_frame。", + capabilityPath(modelType, "input_reference_generate_single"), + map[string]any{ + "input_reference_generate_single": capabilityValue(context.modelCapability, modelType, "input_reference_generate_single"), + "input_reference_generate_multiple": capabilityValue(context.modelCapability, modelType, "input_reference_generate_multiple"), + "max_images": capabilityValue(context.modelCapability, modelType, "max_images"), + }, + ) + count++ + } + } + if count > 0 { + appendParamWarning(params, "reference_image is unsupported by the selected model and was downgraded to first_frame") + } +} + +func supportsReferenceImage(modelCapability map[string]any, modelType string) bool { + candidates := []map[string]any{} + if capability := capabilityForType(modelCapability, modelType); capability != nil { + candidates = append(candidates, capability) + } + if modelType != "image_to_video" { + if capability := capabilityForType(modelCapability, "image_to_video"); capability != nil { + candidates = append(candidates, capability) + } + } + if len(candidates) == 0 { + return true + } + for _, 
capability := range candidates { + _, hasSingle := capability["input_reference_generate_single"] + _, hasMultiple := capability["input_reference_generate_multiple"] + if hasSingle || hasMultiple { + if boolFromAny(capability["input_reference_generate_single"]) || boolFromAny(capability["input_reference_generate_multiple"]) { + return true + } + continue + } + if value, ok := numericField(capability, "max_images"); ok { + if value > 1 { + return true + } + continue + } + } + return false +} + +func supportsFirstAndLastFrame(modelCapability map[string]any, modelType string) bool { + capability := capabilityForType(modelCapability, modelType) + if capability == nil { + return false + } + return boolFromAny(capability["input_first_last_frame"]) || floatFromAny(capability["max_images_for_last_frame"]) > 0 +} + +func validateAndAdjustAspectRatio(aspectRatio string, capability map[string]any, allowed []string) (string, bool) { + if !isMediaModelTypeWithAspectRatio(capability) { + return "", false + } + if ratioRange, ok := numberPair(capability["aspect_ratio_range"]); ok { + ratio, valid := aspectRatioNumber(aspectRatio) + if !valid || ratio < ratioRange[0] || ratio > ratioRange[1] { + return adjustAspectRatioToRange(aspectRatio, ratioRange[0], ratioRange[1], allowed), true + } + } + if allowed == nil { + return aspectRatio, true + } + if len(allowed) == 0 { + return "", false + } + if (aspectRatio == "adaptive" || aspectRatio == "keep_ratio") && !containsString(allowed, aspectRatio) { + return "", false + } + if containsString(allowed, aspectRatio) { + return aspectRatio, true + } + return allowed[0], true +} + +func isMediaModelTypeWithAspectRatio(capability map[string]any) bool { + return capability != nil +} + +func aspectRatioAllowed(value any, resolution string) []string { + switch typed := value.(type) { + case []any: + return stringListFromAny(typed) + case []string: + return typed + case map[string]any: + if resolution != "" { + if values := 
stringListFromAny(typed[resolution]); len(values) > 0 { + return values + } + } + return nil + default: + return nil + } +} + +func scopedNumberList(value any, scopes ...string) []float64 { + switch typed := value.(type) { + case []any: + out := make([]float64, 0, len(typed)) + for _, item := range typed { + if number := floatFromAny(item); number > 0 { + out = append(out, number) + } + } + return out + case []float64: + return typed + case []int: + out := make([]float64, 0, len(typed)) + for _, item := range typed { + out = append(out, float64(item)) + } + return out + case map[string]any: + for _, scope := range scopes { + if scope == "" { + continue + } + if values := scopedNumberList(typed[scope]); len(values) > 0 { + return values + } + } + for _, item := range typed { + if values := scopedNumberList(item); len(values) > 0 { + return values + } + } + } + return nil +} + +func scopedRange(value any, scopes ...string) (float64, float64, bool) { + if pair, ok := numberPair(value); ok { + return pair[0], pair[1], true + } + if typed, ok := value.(map[string]any); ok { + for _, scope := range scopes { + if scope == "" { + continue + } + if minValue, maxValue, ok := scopedRange(typed[scope]); ok { + return minValue, maxValue, true + } + } + for _, item := range typed { + if minValue, maxValue, ok := scopedRange(item); ok { + return minValue, maxValue, true + } + } + } + return 0, 0, false +} + +func durationStep(value any, scopes ...string) float64 { + if step := floatFromAny(value); step > 0 { + return step + } + if typed, ok := value.(map[string]any); ok { + for _, scope := range scopes { + if scope == "" { + continue + } + if step := durationStep(typed[scope]); step > 0 { + return step + } + } + for _, item := range typed { + if step := durationStep(item); step > 0 { + return step + } + } + } + return 0 +} + +func normalizeDurationByRange(target float64, minValue float64, maxValue float64, step float64) float64 { + clamped := math.Min(math.Max(target, minValue), 
maxValue) + if step <= 0 { + return clamped + } + snapped := math.Round((clamped-minValue)/step)*step + minValue + return math.Round(snapped*1_000_000) / 1_000_000 +} + +func closestNumber(target float64, values []float64) float64 { + if len(values) == 0 { + return target + } + closest := values[0] + minDiff := math.Abs(target - closest) + for _, value := range values[1:] { + diff := math.Abs(target - value) + if diff < minDiff { + minDiff = diff + closest = value + } + } + return closest +} + +func videoModeKey(params map[string]any) string { + content := contentItems(params["content"]) + hasFirstFrame := false + hasLastFrame := false + for _, item := range content { + switch stringFromAny(item["role"]) { + case "first_frame": + hasFirstFrame = true + case "last_frame": + hasLastFrame = true + } + } + switch { + case hasFirstFrame && hasLastFrame: + return "input_first_last_frame" + case hasFirstFrame: + return "input_first_frame" + case hasLastFrame: + return "input_last_frame" + default: + return "" + } +} + +func syncDurationSeconds(params map[string]any) { + if params["duration_seconds"] != nil { + params["duration_seconds"] = params["duration"] + } +} + +func syncVideoConvenienceFields(params map[string]any, content []map[string]any, context *paramProcessContext) { + hasVideo := false + hasAudio := false + for _, item := range content { + hasVideo = hasVideo || isVideoContent(item) + hasAudio = hasAudio || isAudioContent(item) + } + if !hasVideo { + path, value := omniCapabilityEvidence(context, "supported_modes") + deleteFieldsWithLog(params, context, "ContentFilterProcessor", []string{"video", "video_url", "videoUrl", "reference_video", "referenceVideo"}, "对应视频 content 已被模型能力过滤,移除视频参考快捷字段。", path, value) + } + if !hasAudio { + path, value := omniCapabilityEvidence(context, "input_audio") + deleteFieldsWithLog(params, context, "ContentFilterProcessor", []string{"audio_url", "audioUrl", "reference_audio", "referenceAudio"}, "对应音频 content 
已被模型能力过滤,移除音频参考快捷字段。", path, mergeMetrics(map[string]any{"input_audio": value}, omniCapabilityBundle(context, "max_audios"))) + } +} + +func deleteFieldsWithLog(params map[string]any, context *paramProcessContext, processor string, keys []string, reason string, capabilityPath string, capabilityValue any) { + for _, key := range keys { + if before, ok := params[key]; ok { + delete(params, key) + context.recordChange(processor, "remove", key, before, nil, reason, capabilityPath, capabilityValue) + } + } +} + +func appendParamWarning(params map[string]any, warning string) { + warnings, _ := params["_param_warnings"].([]any) + for _, item := range warnings { + if stringFromAny(item) == warning { + return + } + } + params["_param_warnings"] = append(warnings, warning) +} + +func filterContent(content []map[string]any, keep func(map[string]any) bool) []map[string]any { + out := make([]map[string]any, 0, len(content)) + for _, item := range content { + if keep(item) { + out = append(out, item) + } + } + return out +} + +func contentItems(value any) []map[string]any { + switch typed := value.(type) { + case []any: + out := make([]map[string]any, 0, len(typed)) + for _, item := range typed { + if object, ok := item.(map[string]any); ok { + out = append(out, cloneMap(object)) + } + } + return out + case []map[string]any: + out := make([]map[string]any, 0, len(typed)) + for _, item := range typed { + out = append(out, cloneMap(item)) + } + return out + default: + return nil + } +} + +func mapsToAnySlice(values []map[string]any) []any { + out := make([]any, 0, len(values)) + for _, value := range values { + out = append(out, value) + } + return out +} + +func isImageContent(item map[string]any) bool { + return stringFromAny(item["type"]) == "image_url" || item["image_url"] != nil +} + +func isVideoContent(item map[string]any) bool { + return stringFromAny(item["type"]) == "video_url" || item["video_url"] != nil +} + +func isAudioContent(item map[string]any) bool { + return 
stringFromAny(item["type"]) == "audio_url" || item["audio_url"] != nil +} + +func capabilityForType(capabilities map[string]any, modelType string) map[string]any { + if capabilities == nil { + return nil + } + if typed, ok := capabilities[modelType].(map[string]any); ok { + return typed + } + return nil +} + +func capabilityPath(modelType string, key string) string { + modelType = strings.TrimSpace(modelType) + if modelType == "" { + return "" + } + if strings.TrimSpace(key) == "" { + return "capabilities." + modelType + } + return "capabilities." + modelType + "." + key +} + +func capabilityValue(capabilities map[string]any, modelType string, key string) any { + capability := capabilityForType(capabilities, modelType) + if capability == nil { + return nil + } + return cloneAny(capability[key]) +} + +func capabilityEvidence(capabilities map[string]any, modelType string, key string) (string, any) { + return capabilityPath(modelType, key), capabilityValue(capabilities, modelType, key) +} + +func audioInputCapabilityEvidence(context *paramProcessContext, modelType string) (string, any) { + if isOmniVideoLike(context) { + path, value := omniCapabilityEvidence(context, "input_audio") + return path, mergeMetrics(map[string]any{"input_audio": value}, omniCapabilityBundle(context, "max_audios")) + } + return capabilityEvidence(context.modelCapability, modelType, "input_audio") +} + +func omniCapabilityType(context *paramProcessContext) string { + if context != nil && capabilityForType(context.modelCapability, "omni_video") != nil { + return "omni_video" + } + if context != nil && capabilityForType(context.modelCapability, "omni") != nil { + return "omni" + } + return "omni_video" +} + +func omniCapabilityEvidence(context *paramProcessContext, key string) (string, any) { + modelType := omniCapabilityType(context) + var capabilities map[string]any + if context != nil { + capabilities = context.modelCapability + } + return capabilityPath(modelType, key), 
capabilityValue(capabilities, modelType, key) +} + +func omniCapabilityBundle(context *paramProcessContext, keys ...string) map[string]any { + modelType := omniCapabilityType(context) + var capabilities map[string]any + if context != nil { + capabilities = context.modelCapability + } + out := map[string]any{} + for _, key := range keys { + out[key] = capabilityValue(capabilities, modelType, key) + } + return out +} + +func numericField(values map[string]any, key string) (float64, bool) { + if values == nil { + return 0, false + } + if _, ok := values[key]; !ok { + return 0, false + } + return floatFromAny(values[key]), true +} + +func boolFromAny(value any) bool { + typed, _ := value.(bool) + return typed +} + +func firstNonEmptyStringValue(values map[string]any, keys ...string) string { + for _, key := range keys { + if value := stringFromAny(values[key]); value != "" { + return value + } + } + return "" +} + +func firstNonEmptyStringListFromAny(values ...any) []string { + for _, value := range values { + items := stringListFromAny(value) + if len(items) > 0 { + return items + } + } + return nil +} + +func stringListFromAny(value any) []string { + switch typed := value.(type) { + case []string: + out := make([]string, 0, len(typed)) + for _, item := range typed { + if text := strings.TrimSpace(item); text != "" { + out = append(out, text) + } + } + return out + case []any: + out := make([]string, 0, len(typed)) + for _, item := range typed { + if text := stringFromAny(item); text != "" { + out = append(out, text) + } + } + return out + case string: + if strings.TrimSpace(typed) == "" { + return nil + } + return []string{strings.TrimSpace(typed)} + default: + return nil + } +} + +func containsString(values []string, target string) bool { + for _, value := range values { + if value == target { + return true + } + } + return false +} + +func appendUniqueString(values *[]string, value string) { + value = strings.TrimSpace(value) + if value == "" { + return + } + for 
_, existing := range *values { + if existing == value { + return + } + } + *values = append(*values, value) +} + +func numberPair(value any) ([2]float64, bool) { + switch typed := value.(type) { + case []any: + if len(typed) < 2 { + return [2]float64{}, false + } + return [2]float64{floatFromAny(typed[0]), floatFromAny(typed[1])}, true + case []float64: + if len(typed) < 2 { + return [2]float64{}, false + } + return [2]float64{typed[0], typed[1]}, true + case []int: + if len(typed) < 2 { + return [2]float64{}, false + } + return [2]float64{float64(typed[0]), float64(typed[1])}, true + default: + return [2]float64{}, false + } +} + +func validAspectRatio(value string) bool { + if value == "adaptive" || value == "keep_ratio" { + return true + } + _, ok := aspectRatioNumber(value) + return ok +} + +func aspectRatioNumber(value string) (float64, bool) { + parts := strings.Split(value, ":") + if len(parts) != 2 { + return 0, false + } + width := parsePositiveFloat(parts[0]) + height := parsePositiveFloat(parts[1]) + if width <= 0 || height <= 0 { + return 0, false + } + return width / height, true +} + +func adjustAspectRatioToRange(value string, minValue float64, maxValue float64, allowed []string) string { + current, ok := aspectRatioNumber(value) + if !ok { + if len(allowed) > 0 { + return allowed[0] + } + return "1:1" + } + if len(allowed) > 0 { + closest := "" + minDiff := math.Inf(1) + for _, candidate := range allowed { + ratio, ok := aspectRatioNumber(candidate) + if !ok || ratio < minValue || ratio > maxValue { + continue + } + diff := math.Abs(ratio - current) + if diff < minDiff { + minDiff = diff + closest = candidate + } + } + if closest != "" { + return closest + } + } + if current < minValue { + return ratioString(minValue) + } + return ratioString(maxValue) +} + +func ratioString(value float64) string { + if value <= 0 { + return "1:1" + } + return strings.TrimRight(strings.TrimRight(strconv.FormatFloat(value, 'f', 6, 64), "0"), ".") + ":1" +} + +func 
parsePositiveFloat(value string) float64 { + for _, r := range strings.TrimSpace(value) { + if r < '0' || r > '9' { + if r != '.' { + return 0 + } + } + } + out, _ := strconv.ParseFloat(strings.TrimSpace(value), 64) + return out +} + +func isEmptyParamString(value string) bool { + normalized := strings.ToLower(strings.TrimSpace(value)) + return normalized == "null" || normalized == "undefined" +} + +func isImageResolution(modelType string, value string) bool { + return (modelType == "image_generate" || modelType == "image_edit") && containsString([]string{"1K", "2K", "4K", "8K"}, value) +} + +func isVideoResolution(modelType string, value string) bool { + return isVideoModelType(modelType) && containsString([]string{"480p", "720p", "1080p", "1440p", "2160p"}, value) +} + +func isVideoModelType(modelType string) bool { + return modelType == "video_generate" || modelType == "text_to_video" || modelType == "image_to_video" || modelType == "video_edit" || modelType == "omni_video" || modelType == "omni" +} + +func cloneMap(values map[string]any) map[string]any { + out := map[string]any{} + for key, value := range values { + out[key] = cloneAny(value) + } + return out +} + +func cloneAny(value any) any { + switch typed := value.(type) { + case map[string]any: + return cloneMap(typed) + case []any: + out := make([]any, 0, len(typed)) + for _, item := range typed { + out = append(out, cloneAny(item)) + } + return out + case []map[string]any: + out := make([]any, 0, len(typed)) + for _, item := range typed { + out = append(out, cloneMap(item)) + } + return out + default: + return value + } +} diff --git a/apps/api/internal/runner/param_processor_test.go b/apps/api/internal/runner/param_processor_test.go new file mode 100644 index 0000000..235fa20 --- /dev/null +++ b/apps/api/internal/runner/param_processor_test.go @@ -0,0 +1,207 @@ +package runner + +import ( + "testing" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func 
TestParamProcessorOmniFiltersUnsupportedVideoAndAudioContent(t *testing.T) { + body := map[string]any{ + "model": "可灵O1", + "prompt": "edit the source video", + "content": []any{ + map[string]any{"type": "text", "text": "edit the source video"}, + map[string]any{"type": "video_url", "role": "video_base", "video_url": map[string]any{"url": "https://example.com/base.mp4", "refer_type": "base"}}, + map[string]any{"type": "video_url", "role": "reference_video", "video_url": map[string]any{"url": "https://example.com/ref.mp4", "refer_type": "feature"}}, + map[string]any{"type": "audio_url", "role": "reference_audio", "audio_url": map[string]any{"url": "https://example.com/ref.mp3"}}, + }, + } + candidate := store.RuntimeModelCandidate{ + ModelType: "omni_video", + Capabilities: map[string]any{ + "omni_video": map[string]any{ + "supported_modes": []any{"video_edit"}, + "max_videos": 1, + "input_audio": false, + "max_audios": 0, + }, + }, + } + + result := preprocessRequestWithLog("videos.generations", body, candidate) + processed := result.Body + content := contentItems(processed["content"]) + if len(content) != 2 { + t.Fatalf("expected text plus one video item, got %+v", content) + } + if stringFromAny(content[1]["role"]) != "video_base" || isAudioContent(content[1]) { + t.Fatalf("unexpected retained content: %+v", content) + } + for _, item := range content { + if isAudioContent(item) || stringFromAny(item["role"]) == "reference_video" { + t.Fatalf("unsupported content was not filtered: %+v", content) + } + } + if !result.Log.Changed || len(result.Log.Changes) < 2 { + t.Fatalf("expected preprocessing log with filtered video and audio changes, got %+v", result.Log) + } + if result.Log.Input["content"] == nil || result.Log.Output["content"] == nil { + t.Fatalf("preprocessing log should keep actual input and converted output: %+v", result.Log) + } + foundAudioReason := false + for _, change := range result.Log.Changes { + if change.Path == "content[3]" && 
change.CapabilityPath == "capabilities.omni_video.input_audio" { + foundAudioReason = true + break + } + } + if !foundAudioReason { + t.Fatalf("expected audio filtering reason to reference omni_video.input_audio, got %+v", result.Log.Changes) + } +} + +func TestParamProcessorOmniFiltersConvenienceReferenceFields(t *testing.T) { + body := map[string]any{ + "model": "可灵V3多模态", + "prompt": "text only", + "reference_video": "https://example.com/ref.mp4", + "reference_audio": "https://example.com/ref.mp3", + } + candidate := store.RuntimeModelCandidate{ + ModelType: "omni_video", + Capabilities: map[string]any{ + "omni_video": map[string]any{ + "supported_modes": []any{"text_to_video"}, + "max_videos": 0, + "input_audio": false, + "max_audios": 0, + }, + }, + } + + result := preprocessRequestWithLog("videos.generations", body, candidate) + processed := result.Body + content := contentItems(processed["content"]) + if len(content) != 1 || stringFromAny(content[0]["type"]) != "text" { + t.Fatalf("expected only text content, got %+v", content) + } + for _, key := range []string{"reference_video", "reference_audio"} { + if processed[key] != nil { + t.Fatalf("%s should be removed when capability rejects it: %+v", key, processed) + } + } + if len(result.Log.Changes) == 0 { + t.Fatalf("expected convenience-field filtering to be logged") + } +} + +func TestParamProcessorOmniCapabilityLogUsesActualCapabilityKey(t *testing.T) { + body := map[string]any{ + "model": "Omni", + "content": []any{ + map[string]any{"type": "text", "text": "animate"}, + map[string]any{"type": "audio_url", "role": "reference_audio", "audio_url": map[string]any{"url": "https://example.com/ref.mp3"}}, + }, + } + candidate := store.RuntimeModelCandidate{ + ModelType: "omni", + Capabilities: map[string]any{ + "omni": map[string]any{ + "input_audio": false, + "max_audios": 0, + }, + }, + } + + result := preprocessRequestWithLog("videos.generations", body, candidate) + for _, change := range result.Log.Changes { 
+ if change.Path == "content[1]" && change.CapabilityPath == "capabilities.omni.input_audio" { + return + } + } + t.Fatalf("expected log to reference capabilities.omni.input_audio, got %+v", result.Log.Changes) +} + +func TestParamProcessorVideoCapabilitiesNormalizeAndFilter(t *testing.T) { + body := map[string]any{ + "model": "Seedance", + "duration": 13, + "aspect_ratio": "4:3", + "resolution": "1080p", + "audio": true, + "output_audio": true, + "content": []any{ + map[string]any{"type": "text", "text": "animate it"}, + map[string]any{"type": "image_url", "role": "first_frame", "image_url": map[string]any{"url": "https://example.com/first.png"}}, + map[string]any{"type": "image_url", "role": "last_frame", "image_url": map[string]any{"url": "https://example.com/last.png"}}, + map[string]any{"type": "audio_url", "role": "reference_audio", "audio_url": map[string]any{"url": "https://example.com/ref.mp3"}}, + }, + } + candidate := store.RuntimeModelCandidate{ + ModelType: "image_to_video", + Capabilities: map[string]any{ + "image_to_video": map[string]any{ + "aspect_ratio_allowed": []any{"16:9", "1:1"}, + "duration_options": []any{4, 8, 12}, + "input_first_last_frame": false, + "input_audio": false, + "output_audio": false, + "max_images_for_last_frame": 0, + }, + }, + } + + result := preprocessRequestWithLog("videos.generations", body, candidate) + processed := result.Body + if processed["duration"] != float64(12) && processed["duration"] != 12 { + t.Fatalf("duration should be snapped to 12, got %+v", processed["duration"]) + } + if processed["aspect_ratio"] != "16:9" { + t.Fatalf("aspect_ratio should fall back to first allowed value, got %+v", processed["aspect_ratio"]) + } + if processed["audio"] != nil || processed["output_audio"] != nil { + t.Fatalf("output audio flags should be removed: %+v", processed) + } + for _, item := range contentItems(processed["content"]) { + if stringFromAny(item["role"]) == "last_frame" || isAudioContent(item) { + 
t.Fatalf("unsupported content remained: %+v", processed["content"]) + } + } + foundDuration := false + for _, change := range result.Log.Changes { + if change.Path == "duration" && change.CapabilityPath == "capabilities.image_to_video.duration_options" { + foundDuration = true + break + } + } + if !foundDuration { + t.Fatalf("expected duration adjustment to reference duration_options, got %+v", result.Log.Changes) + } +} + +func TestParamProcessorImageResolutionAndOutputCount(t *testing.T) { + body := map[string]any{ + "model": "即梦V4.0", + "prompt": "draw", + "size": "2K", + "n": 8, + } + candidate := store.RuntimeModelCandidate{ + ModelType: "image_generate", + Capabilities: map[string]any{ + "image_generate": map[string]any{ + "output_multiple_images": true, + "output_max_images_count": 4, + }, + }, + } + + processed := preprocessRequest("images.generations", body, candidate) + if processed["resolution"] != "2K" { + t.Fatalf("size resolution should be copied to resolution, got %+v", processed) + } + if processed["n"] != 4 { + t.Fatalf("image count should be capped to 4, got %+v", processed["n"]) + } +} diff --git a/apps/api/internal/runner/pricing.go b/apps/api/internal/runner/pricing.go index 3ced86e..0eaf684 100644 --- a/apps/api/internal/runner/pricing.go +++ b/apps/api/internal/runner/pricing.go @@ -16,11 +16,13 @@ type EstimateResult struct { } func (s *Service) Estimate(ctx context.Context, kind string, model string, body map[string]any, user *auth.User) (EstimateResult, error) { + body = normalizeRequest(kind, body) candidates, err := s.store.ListModelCandidates(ctx, model, modelTypeFromKind(kind, body), user) if err != nil { return EstimateResult{}, err } candidate := candidates[0] + body = preprocessRequest(kind, body, candidate) return EstimateResult{ Items: s.estimatedBillings(ctx, user, kind, body, candidate), Resolver: "effective-pricing-v1", diff --git a/apps/api/internal/runner/recording.go b/apps/api/internal/runner/recording.go index 
d70a46c..502ee74 100644 --- a/apps/api/internal/runner/recording.go +++ b/apps/api/internal/runner/recording.go @@ -243,6 +243,9 @@ func summarizeAttempts(attempts []store.TaskAttempt) []map[string]any { if trace, ok := attempt.Metrics["trace"]; ok { item["trace"] = trace } + if preprocessing, ok := attempt.Metrics["parameterPreprocessingSummary"]; ok { + item["parameterPreprocessingSummary"] = preprocessing + } items = append(items, item) } return items diff --git a/apps/api/internal/runner/service.go b/apps/api/internal/runner/service.go index 8fc8780..ae0d2b3 100644 --- a/apps/api/internal/runner/service.go +++ b/apps/api/internal/runner/service.go @@ -96,11 +96,24 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut } return Result{Task: failed, Output: failed.Result}, err } + firstCandidateBody := body + normalizedModelType := modelType + var firstPreprocessing parameterPreprocessingLog if len(candidates) > 0 { - estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, body, candidates[0]) + preprocessing := preprocessRequestWithLog(task.Kind, body, candidates[0]) + firstCandidateBody = preprocessing.Body + firstPreprocessing = preprocessing.Log + normalizedModelType = candidates[0].ModelType + if err := s.store.MarkTaskRunning(ctx, task.ID, candidates[0].ModelType, firstCandidateBody); err != nil { + return Result{}, err + } + estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, firstCandidateBody, candidates[0]) if err := s.ensureWalletBalance(ctx, user, estimatedBillings); err != nil { if errors.Is(err, store.ErrInsufficientWalletBalance) { - failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", err.Error(), task.RunMode == "simulation", err) + if logErr := s.recordTaskParameterPreprocessing(ctx, task.ID, "", 0, candidates[0], firstPreprocessing); logErr != nil { + return Result{}, logErr + } + failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", err.Error(), task.RunMode == 
"simulation", err, parameterPreprocessingMetrics(firstPreprocessing)) if finishErr != nil { return Result{}, finishErr } @@ -109,7 +122,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut return Result{}, err } } - if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": modelType}, task.RunMode == "simulation"); err != nil { + if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": normalizedModelType}, task.RunMode == "simulation"); err != nil { return Result{}, err } @@ -122,6 +135,7 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut attemptNo := task.AttemptCount var lastErr error var lastCandidate store.RuntimeModelCandidate + var lastPreprocessing *parameterPreprocessingLog candidatesLoop: for index, candidate := range candidates { if index >= maxPlatforms { @@ -132,11 +146,16 @@ candidatesLoop: var candidateErr error for clientAttempt := 1; clientAttempt <= clientAttempts; clientAttempt++ { nextAttemptNo := attemptNo + 1 - response, err := s.runCandidate(ctx, task, user, body, candidate, nextAttemptNo, onDelta) + preprocessing := preprocessRequestWithLog(task.Kind, body, candidate) + preprocessingLog := preprocessing.Log + lastPreprocessing = &preprocessingLog + candidateBody := preprocessing.Body + response, err := s.runCandidate(ctx, task, user, candidateBody, preprocessing.Log, candidate, nextAttemptNo, onDelta) if err == nil { attemptNo = nextAttemptNo - billings := s.billings(ctx, user, task.Kind, body, candidate, response, isSimulation(task, candidate)) - record := buildSuccessRecord(task, user, body, candidate, response, billings, isSimulation(task, candidate)) + billings := s.billings(ctx, user, task.Kind, candidateBody, candidate, response, isSimulation(task, candidate)) + record := buildSuccessRecord(task, user, candidateBody, candidate, 
response, billings, isSimulation(task, candidate)) + record.Metrics = mergeMetrics(record.Metrics, parameterPreprocessingMetrics(preprocessing.Log)) record.Metrics = s.withAttemptHistory(ctx, task.ID, record.Metrics) finished, finishErr := s.store.FinishTaskSuccess(ctx, store.FinishTaskSuccessInput{ TaskID: task.ID, @@ -305,15 +324,20 @@ candidatesLoop: } return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: delay} } - failed, err := s.failTask(ctx, task.ID, code, message, task.RunMode == "simulation", lastErr) + extraMetrics := []map[string]any{} + if lastPreprocessing != nil { + extraMetrics = append(extraMetrics, parameterPreprocessingMetrics(*lastPreprocessing)) + } + failed, err := s.failTask(ctx, task.ID, code, message, task.RunMode == "simulation", lastErr, extraMetrics...) if err != nil { return Result{}, err } return Result{Task: failed, Output: failed.Result}, lastErr } -func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user *auth.User, body map[string]any, candidate store.RuntimeModelCandidate, attemptNo int, onDelta clients.StreamDelta) (clients.Response, error) { +func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user *auth.User, body map[string]any, preprocessing parameterPreprocessingLog, candidate store.RuntimeModelCandidate, attemptNo int, onDelta clients.StreamDelta) (clients.Response, error) { simulated := isSimulation(task, candidate) + baseAttemptMetrics := mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), parameterPreprocessingMetrics(preprocessing)) reservations := s.rateLimitReservations(ctx, user, candidate, body) limitResult, err := s.store.ReserveRateLimits(ctx, task.ID, "", reservations) if err != nil { @@ -339,18 +363,30 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user Status: "running", Simulated: simulated, RequestSnapshot: body, - Metrics: attemptMetrics(candidate, attemptNo, simulated), + Metrics: 
baseAttemptMetrics, }) if err != nil { return clients.Response{}, fmt.Errorf("create task attempt: %w", err) } + if err := s.recordTaskParameterPreprocessing(ctx, task.ID, attemptID, attemptNo, candidate, preprocessing); err != nil { + clientErr := &clients.ClientError{Code: "runtime_error", Message: err.Error(), Retryable: false} + _ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{ + AttemptID: attemptID, + Status: "failed", + Retryable: false, + Metrics: mergeMetrics(baseAttemptMetrics, map[string]any{"error": err.Error(), "retryable": false, "trace": []any{failureTraceEntry(clientErr, false)}}), + ErrorCode: clients.ErrorCode(clientErr), + ErrorMessage: err.Error(), + }) + return clients.Response{}, fmt.Errorf("record parameter preprocessing: %w", err) + } if err := s.store.AttachRateLimitResultToAttempt(ctx, attemptID, limitResult); err != nil { clientErr := &clients.ClientError{Code: "runtime_error", Message: err.Error(), Retryable: false} _ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{ AttemptID: attemptID, Status: "failed", Retryable: false, - Metrics: mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{"error": err.Error(), "retryable": false, "trace": []any{failureTraceEntry(clientErr, false)}}), + Metrics: mergeMetrics(baseAttemptMetrics, map[string]any{"error": err.Error(), "retryable": false, "trace": []any{failureTraceEntry(clientErr, false)}}), ErrorCode: clients.ErrorCode(clientErr), ErrorMessage: err.Error(), }) @@ -371,7 +407,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user AttemptID: attemptID, Status: "failed", Retryable: false, - Metrics: mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{"error": err.Error(), "retryable": false, "trace": []any{failureTraceEntry(err, false)}}), + Metrics: mergeMetrics(baseAttemptMetrics, map[string]any{"error": err.Error(), "retryable": false, "trace": []any{failureTraceEntry(err, false)}}), 
ErrorCode: clients.ErrorCode(err), ErrorMessage: err.Error(), }) @@ -425,7 +461,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user responseDurationMS = 0 } } - metrics = mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), metrics) + metrics = mergeMetrics(baseAttemptMetrics, metrics) _ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{ AttemptID: attemptID, Status: "failed", @@ -444,7 +480,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user } uploadedResult, err := s.uploadGeneratedAssets(ctx, response.Result) if err != nil { - metrics := mergeMetrics(taskMetrics(task, user, body, candidate, response, simulated), map[string]any{ + metrics := mergeMetrics(taskMetrics(task, user, body, candidate, response, simulated), parameterPreprocessingMetrics(preprocessing), map[string]any{ "error": err.Error(), "retryable": clients.IsRetryable(err), "trace": []any{failureTraceEntry(err, clients.IsRetryable(err))}, @@ -480,7 +516,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user Status: "succeeded", RequestID: response.RequestID, Usage: usageToMap(response.Usage), - Metrics: taskMetrics(task, user, body, candidate, response, simulated), + Metrics: mergeMetrics(taskMetrics(task, user, body, candidate, response, simulated), parameterPreprocessingMetrics(preprocessing)), ResponseSnapshot: response.Result, ResponseStartedAt: response.ResponseStartedAt, ResponseFinishedAt: response.ResponseFinishedAt, @@ -491,6 +527,25 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user return response, nil } +func (s *Service) recordTaskParameterPreprocessing(ctx context.Context, taskID string, attemptID string, attemptNo int, candidate store.RuntimeModelCandidate, log parameterPreprocessingLog) error { + _, err := s.store.CreateTaskParamPreprocessingLog(ctx, store.CreateTaskParamPreprocessingLogInput{ + TaskID: taskID, + AttemptID: attemptID, 
+ AttemptNo: attemptNo, + ModelType: log.ModelType, + PlatformID: candidate.PlatformID, + PlatformModelID: candidate.PlatformModelID, + ClientID: candidate.ClientID, + Changed: log.Changed, + ChangeCount: len(log.Changes), + ActualInput: log.Input, + ConvertedOutput: log.Output, + Changes: log.Changes, + ModelSnapshot: log.Model, + }) + return err +} + func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated bool) clients.Client { if simulated { return s.clients["simulation"] @@ -505,8 +560,12 @@ func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated boo return s.clients["openai"] } -func (s *Service) failTask(ctx context.Context, taskID string, code string, message string, simulated bool, cause error) (store.GatewayTask, error) { +func (s *Service) failTask(ctx context.Context, taskID string, code string, message string, simulated bool, cause error, extraMetrics ...map[string]any) (store.GatewayTask, error) { requestID, metrics, responseStartedAt, responseFinishedAt, responseDurationMS := failureMetrics(cause, simulated) + if len(extraMetrics) > 0 { + values := append([]map[string]any{metrics}, extraMetrics...) + metrics = mergeMetrics(values...) 
+ } metrics = s.withAttemptHistory(ctx, taskID, metrics) failed, err := s.store.FinishTaskFailure(ctx, store.FinishTaskFailureInput{ TaskID: taskID, @@ -589,6 +648,9 @@ func (s *Service) emit(ctx context.Context, taskID string, eventType string, sta } func modelTypeFromKind(kind string, body map[string]any) string { + if requested := requestedModelTypeFromBody(body); requested != "" { + return requested + } switch kind { case "chat.completions", "responses": return "text_generate" @@ -598,6 +660,9 @@ func modelTypeFromKind(kind string, body map[string]any) string { } return "image_generate" case "videos.generations": + if videoRequestHasVideoOrAudioReference(body) { + return "omni_video" + } if videoRequestHasReferenceImage(body) { return "image_to_video" } @@ -607,6 +672,25 @@ func modelTypeFromKind(kind string, body map[string]any) string { } } +func requestedModelTypeFromBody(body map[string]any) string { + for _, key := range []string{"modelType", "model_type", "capability", "capabilityType"} { + value := strings.TrimSpace(stringFromMap(body, key)) + if isKnownModelType(value) { + return value + } + } + return "" +} + +func isKnownModelType(value string) bool { + switch value { + case "text_generate", "image_generate", "image_edit", "video_generate", "image_to_video", "text_to_video", "video_edit", "omni_video", "omni": + return true + default: + return false + } +} + func videoRequestHasReferenceImage(body map[string]any) bool { if body == nil { return false @@ -622,6 +706,23 @@ func videoRequestHasReferenceImage(body map[string]any) bool { return false } +func videoRequestHasVideoOrAudioReference(body map[string]any) bool { + if body == nil { + return false + } + for _, key := range []string{"video", "video_url", "videoUrl", "reference_video", "referenceVideo", "audio_url", "audioUrl", "reference_audio", "referenceAudio"} { + if hasAnyString(body, key) { + return true + } + } + for _, item := range contentItems(body["content"]) { + if isVideoContent(item) || 
isAudioContent(item) { + return true + } + } + return false +} + func isTextGenerationKind(kind string) bool { return kind == "chat.completions" || kind == "responses" } @@ -692,10 +793,7 @@ func failoverTimeBudgetExceeded(start time.Time, maxDuration time.Duration) bool } func normalizeRequest(kind string, body map[string]any) map[string]any { - out := map[string]any{} - for key, value := range body { - out[key] = value - } + out := cloneMap(body) if kind == "responses" && out["messages"] == nil && out["input"] != nil { out["messages"] = []any{map[string]any{"role": "user", "content": out["input"]}} } diff --git a/apps/api/internal/store/postgres.go b/apps/api/internal/store/postgres.go index 73b1858..adf0170 100644 --- a/apps/api/internal/store/postgres.go +++ b/apps/api/internal/store/postgres.go @@ -488,6 +488,24 @@ type TaskAttempt struct { FinishedAt string `json:"finishedAt,omitempty"` } +type TaskParamPreprocessingLog struct { + ID string `json:"id"` + TaskID string `json:"taskId"` + AttemptID string `json:"attemptId,omitempty"` + AttemptNo int `json:"attemptNo,omitempty"` + ModelType string `json:"modelType,omitempty"` + PlatformID string `json:"platformId,omitempty"` + PlatformModelID string `json:"platformModelId,omitempty"` + ClientID string `json:"clientId,omitempty"` + Changed bool `json:"changed"` + ChangeCount int `json:"changeCount"` + ActualInput map[string]any `json:"actualInput,omitempty"` + ConvertedOutput map[string]any `json:"convertedOutput,omitempty"` + Changes []any `json:"changes,omitempty"` + ModelSnapshot map[string]any `json:"model,omitempty"` + CreatedAt time.Time `json:"createdAt"` +} + func (s *Store) ListPlatforms(ctx context.Context) ([]Platform, error) { rows, err := s.pool.Query(ctx, ` SELECT id::text, provider, platform_key, name, COALESCE(internal_name, ''), COALESCE(base_url, ''), auth_type, status, priority, diff --git a/apps/api/internal/store/runtime_types.go b/apps/api/internal/store/runtime_types.go index 
72d4ffc..fd80fab 100644 --- a/apps/api/internal/store/runtime_types.go +++ b/apps/api/internal/store/runtime_types.go @@ -214,3 +214,19 @@ type FinishTaskFailureInput struct { ResponseFinishedAt time.Time ResponseDurationMS int64 } + +type CreateTaskParamPreprocessingLogInput struct { + TaskID string + AttemptID string + AttemptNo int + ModelType string + PlatformID string + PlatformModelID string + ClientID string + Changed bool + ChangeCount int + ActualInput map[string]any + ConvertedOutput map[string]any + Changes any + ModelSnapshot map[string]any +} diff --git a/apps/api/internal/store/tasks_runtime.go b/apps/api/internal/store/tasks_runtime.go index 2776e4c..80c9670 100644 --- a/apps/api/internal/store/tasks_runtime.go +++ b/apps/api/internal/store/tasks_runtime.go @@ -328,6 +328,47 @@ WHERE id = $1::uuid`, input.TaskID, input.AttemptNo); err != nil { return attemptID, tx.Commit(ctx) } +func (s *Store) CreateTaskParamPreprocessingLog(ctx context.Context, input CreateTaskParamPreprocessingLogInput) (string, error) { + actualInputJSON, _ := json.Marshal(emptyObjectIfNil(input.ActualInput)) + convertedOutputJSON, _ := json.Marshal(emptyObjectIfNil(input.ConvertedOutput)) + changesJSON, _ := json.Marshal(input.Changes) + if input.Changes == nil { + changesJSON = []byte("[]") + } + modelSnapshotJSON, _ := json.Marshal(emptyObjectIfNil(input.ModelSnapshot)) + var attemptNo any + if input.AttemptNo > 0 { + attemptNo = input.AttemptNo + } + var id string + err := s.pool.QueryRow(ctx, ` +INSERT INTO gateway_task_param_preprocessing_logs ( + task_id, attempt_id, attempt_no, model_type, platform_id, platform_model_id, client_id, + changed, change_count, actual_input, converted_output, changes, model_snapshot +) +VALUES ( + $1::uuid, NULLIF($2::text, '')::uuid, $3::int, NULLIF($4::text, ''), + NULLIF($5::text, '')::uuid, NULLIF($6::text, '')::uuid, NULLIF($7::text, ''), + $8, $9::int, $10::jsonb, $11::jsonb, $12::jsonb, $13::jsonb +) +RETURNING id::text`, + 
input.TaskID, + input.AttemptID, + attemptNo, + input.ModelType, + input.PlatformID, + input.PlatformModelID, + input.ClientID, + input.Changed, + input.ChangeCount, + string(actualInputJSON), + string(convertedOutputJSON), + string(changesJSON), + string(modelSnapshotJSON), + ).Scan(&id) + return id, err +} + func (s *Store) attachTaskAttempts(ctx context.Context, items []GatewayTask) ([]GatewayTask, error) { if len(items) == 0 { return items, nil @@ -354,6 +395,31 @@ func (s *Store) ListTaskAttempts(ctx context.Context, taskID string) ([]TaskAtte return attemptsByTaskID[taskID], nil } +func (s *Store) ListTaskParamPreprocessingLogs(ctx context.Context, taskID string) ([]TaskParamPreprocessingLog, error) { + rows, err := s.pool.Query(ctx, ` +SELECT id::text, task_id::text, COALESCE(attempt_id::text, ''), COALESCE(attempt_no, 0), + COALESCE(model_type, ''), COALESCE(platform_id::text, ''), COALESCE(platform_model_id::text, ''), + COALESCE(client_id, ''), changed, change_count, + COALESCE(actual_input, '{}'::jsonb), COALESCE(converted_output, '{}'::jsonb), + COALESCE(changes, '[]'::jsonb), COALESCE(model_snapshot, '{}'::jsonb), created_at +FROM gateway_task_param_preprocessing_logs +WHERE task_id = $1::uuid +ORDER BY COALESCE(attempt_no, 0), created_at`, taskID) + if err != nil { + return nil, err + } + defer rows.Close() + items := make([]TaskParamPreprocessingLog, 0) + for rows.Next() { + item, err := scanTaskParamPreprocessingLog(rows) + if err != nil { + return nil, err + } + items = append(items, item) + } + return items, rows.Err() +} + func (s *Store) AppendTaskAttemptTrace(ctx context.Context, taskID string, attemptNo int, entry map[string]any) error { entryJSON, _ := json.Marshal(emptyObjectIfNil(entry)) _, err := s.pool.Exec(ctx, ` @@ -459,6 +525,38 @@ func scanTaskAttempt(scanner taskScanner) (TaskAttempt, error) { return item, nil } +func scanTaskParamPreprocessingLog(scanner taskScanner) (TaskParamPreprocessingLog, error) { + var item 
TaskParamPreprocessingLog + var actualInputBytes []byte + var convertedOutputBytes []byte + var changesBytes []byte + var modelSnapshotBytes []byte + if err := scanner.Scan( + &item.ID, + &item.TaskID, + &item.AttemptID, + &item.AttemptNo, + &item.ModelType, + &item.PlatformID, + &item.PlatformModelID, + &item.ClientID, + &item.Changed, + &item.ChangeCount, + &actualInputBytes, + &convertedOutputBytes, + &changesBytes, + &modelSnapshotBytes, + &item.CreatedAt, + ); err != nil { + return TaskParamPreprocessingLog{}, err + } + item.ActualInput = decodeObject(actualInputBytes) + item.ConvertedOutput = decodeObject(convertedOutputBytes) + item.Changes = decodeArray(changesBytes) + item.ModelSnapshot = decodeObject(modelSnapshotBytes) + return item, nil +} + func enrichTaskAttemptFromMetrics(item *TaskAttempt) { if item == nil || len(item.Metrics) == 0 { return diff --git a/apps/api/migrations/0033_task_param_preprocessing_logs.sql b/apps/api/migrations/0033_task_param_preprocessing_logs.sql new file mode 100644 index 0000000..6376d7b --- /dev/null +++ b/apps/api/migrations/0033_task_param_preprocessing_logs.sql @@ -0,0 +1,30 @@ +CREATE TABLE IF NOT EXISTS gateway_task_param_preprocessing_logs ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + task_id uuid NOT NULL REFERENCES gateway_tasks(id) ON DELETE CASCADE, + attempt_id uuid REFERENCES gateway_task_attempts(id) ON DELETE SET NULL, + attempt_no integer, + model_type text NOT NULL DEFAULT '', + platform_id uuid REFERENCES integration_platforms(id) ON DELETE SET NULL, + platform_model_id uuid REFERENCES platform_models(id) ON DELETE SET NULL, + client_id text, + changed boolean NOT NULL DEFAULT false, + change_count integer NOT NULL DEFAULT 0, + actual_input jsonb NOT NULL DEFAULT '{}'::jsonb, + converted_output jsonb NOT NULL DEFAULT '{}'::jsonb, + changes jsonb NOT NULL DEFAULT '[]'::jsonb, + model_snapshot jsonb NOT NULL DEFAULT '{}'::jsonb, + created_at timestamptz NOT NULL DEFAULT now() +); + +CREATE INDEX IF 
NOT EXISTS idx_gateway_task_param_pre_logs_task + ON gateway_task_param_preprocessing_logs(task_id, created_at); + +CREATE INDEX IF NOT EXISTS idx_gateway_task_param_pre_logs_attempt + ON gateway_task_param_preprocessing_logs(attempt_id) + WHERE attempt_id IS NOT NULL; + +CREATE INDEX IF NOT EXISTS idx_gateway_task_param_pre_logs_changed + ON gateway_task_param_preprocessing_logs(changed, created_at DESC); + +CREATE INDEX IF NOT EXISTS idx_gateway_task_param_pre_logs_changes + ON gateway_task_param_preprocessing_logs USING gin(changes); diff --git a/apps/web/src/App.tsx b/apps/web/src/App.tsx index a93a750..51fe812 100644 --- a/apps/web/src/App.tsx +++ b/apps/web/src/App.tsx @@ -942,6 +942,7 @@ export function App() { message={coreMessage} section={workspaceSection} state={coreState} + token={token} taskQuery={workspaceTaskQuery} taskTotal={taskTotal} transactionQuery={workspaceTransactionQuery} diff --git a/apps/web/src/api.ts b/apps/web/src/api.ts index 7d256f9..20c471f 100644 --- a/apps/web/src/api.ts +++ b/apps/web/src/api.ts @@ -16,6 +16,7 @@ import type { GatewayTenantUpsertRequest, GatewayNetworkProxyConfig, GatewayTask, + GatewayTaskParamPreprocessingLog, GatewayUser, GatewayUserUpsertRequest, GatewayWalletTransaction, @@ -660,6 +661,13 @@ export async function getTask(token: string, taskId: string): Promise(`/api/workspace/tasks/${taskId}`, { token }); } +export async function listTaskParamPreprocessing( + token: string, + taskId: string, +): Promise> { + return request>(`/api/workspace/tasks/${taskId}/param-preprocessing`, { token }); +} + export async function pollTaskUntilSettled( token: string, task: GatewayTask, diff --git a/apps/web/src/pages/WorkspacePage.tsx b/apps/web/src/pages/WorkspacePage.tsx index 750202b..58a7772 100644 --- a/apps/web/src/pages/WorkspacePage.tsx +++ b/apps/web/src/pages/WorkspacePage.tsx @@ -1,12 +1,13 @@ import { useEffect, useMemo, useState, type FormEvent, type ReactNode } from 'react'; import { Popover as AntPopover } 
from 'antd'; -import { ChevronLeft, ChevronRight, Copy, CreditCard, Eye, KeyRound, ListChecks, Plus, ReceiptText, RotateCcw, Search, ShieldCheck, Trash2, UserRound } from 'lucide-react'; -import type { GatewayAccessRuleBatchRequest, GatewayApiKey, GatewayTask, GatewayWalletAccount, GatewayWalletTransaction, IntegrationPlatform, PlatformModel } from '@easyai-ai-gateway/contracts'; +import { ChevronLeft, ChevronRight, Copy, CreditCard, Eye, KeyRound, ListChecks, Plus, ReceiptText, RotateCcw, Search, ShieldCheck, SlidersHorizontal, Trash2, UserRound } from 'lucide-react'; +import type { GatewayAccessRuleBatchRequest, GatewayApiKey, GatewayTask, GatewayTaskParamPreprocessingLog, GatewayWalletAccount, GatewayWalletTransaction, IntegrationPlatform, PlatformModel } from '@easyai-ai-gateway/contracts'; import type { ConsoleData } from '../app-state'; import { EntityTable } from '../components/EntityTable'; import { Badge, Button, Card, CardContent, CardHeader, CardTitle, ConfirmDialog, DateTimePicker, DateTimeRangePicker, FormDialog, Input, Label, Select, Table, TableCell, TableFooter, TableHead, TablePageActions, TableRow, TableToolbar, TableViewportLayout, Tabs } from '../components/ui'; import { AccessPermissionEditor, countAccessPermissionRules } from './admin/AccessPermissionEditor'; import type { ApiKeyForm, LoadState, WorkspaceSection, WorkspaceTaskQuery, WorkspaceTransactionQuery } from '../types'; +import { listTaskParamPreprocessing } from '../api'; const tabs = [ { value: 'overview', label: '个人总览', icon: }, @@ -27,6 +28,7 @@ export function WorkspacePage(props: { message: string; section: WorkspaceSection; state: LoadState; + token: string; taskQuery: WorkspaceTaskQuery; taskTotal: number; transactionQuery: WorkspaceTransactionQuery; @@ -48,7 +50,7 @@ export function WorkspacePage(props: { {props.section === 'overview' && } {props.section === 'billing' && } {props.section === 'apiKeys' && } - {props.section === 'tasks' && } + {props.section === 'tasks' && } 
{props.section === 'transactions' && } @@ -545,6 +547,7 @@ function ApiKeyPanel(props: { function TaskPanel(props: { data: ConsoleData; query: WorkspaceTaskQuery; + token: string; total: number; onQueryChange: (value: WorkspaceTaskQuery) => void; }) { @@ -675,6 +678,7 @@ function TaskPanel(props: { 状态 模型 尝试链路 + 参数转换 类型 API Key Token @@ -684,7 +688,7 @@ function TaskPanel(props: { 原始 JSON {tasks.map((task) => ( - + ))} ) : ( @@ -759,7 +763,7 @@ function TaskPanel(props: { ); } -function TaskRecord(props: { task: GatewayTask; onCopyRequestId: (task: GatewayTask) => Promise; onOpenJson: (task: GatewayTask) => void }) { +function TaskRecord(props: { task: GatewayTask; token: string; onCopyRequestId: (task: GatewayTask) => Promise; onOpenJson: (task: GatewayTask) => void }) { const usage = props.task.usage ?? {}; const tokenUsage = formatTokenUsage(usage); const chargeText = props.task.finalChargeAmount !== undefined ? formatCellValue(props.task.finalChargeAmount) : '-'; @@ -801,6 +805,9 @@ function TaskRecord(props: { task: GatewayTask; onCopyRequestId: (task: GatewayT + + + {props.task.modelType || '-'} {props.task.apiKeyName || props.task.apiKeyPrefix || props.task.apiKeyId || '-'} {tokenUsage} @@ -817,6 +824,108 @@ function TaskRecord(props: { task: GatewayTask; onCopyRequestId: (task: GatewayT ); } +type TaskParamConversionSummary = { + changed: boolean; + changeCount: number; + actions: string[]; + paths: string[]; + capabilityPaths: string[]; +}; + +function TaskParamConversionCell(props: { task: GatewayTask; token: string }) { + const summary = taskParamConversionSummary(props.task); + const [open, setOpen] = useState(false); + const [logs, setLogs] = useState(null); + const [loadState, setLoadState] = useState<'idle' | 'loading' | 'ready' | 'error'>('idle'); + const [error, setError] = useState(''); + + useEffect(() => { + if (!open || !summary.changed || logs || loadState === 'loading' || !props.token) return; + let cancelled = false; + 
setLoadState('loading'); + setError(''); + listTaskParamPreprocessing(props.token, props.task.id) + .then((response) => { + if (cancelled) return; + setLogs(response.items ?? []); + setLoadState('ready'); + }) + .catch((err) => { + if (cancelled) return; + setError(err instanceof Error ? err.message : '参数转换明细加载失败'); + setLoadState('error'); + }); + return () => { + cancelled = true; + }; + }, [loadState, logs, open, props.task.id, props.token, summary.changed]); + + if (!summary.changed) { + return 无转换; + } + + return ( + } + overlayClassName="taskParamConversionAntPopover" + placement="bottomLeft" + trigger={['hover', 'focus']} + onOpenChange={setOpen} + > + + + ); +} + +function TaskParamConversionPopover(props: { + error: string; + loadState: 'idle' | 'loading' | 'ready' | 'error'; + logs: GatewayTaskParamPreprocessingLog[] | null; + summary: TaskParamConversionSummary; +}) { + const logs = props.logs ?? []; + return ( + + + 参数转换汇总 + {taskParamSummaryText(props.summary)} + + {props.loadState === 'loading' && 正在加载转换明细...} + {props.loadState === 'error' && {props.error || '参数转换明细加载失败'}} + {props.loadState === 'ready' && logs.length === 0 && 暂无转换明细。} + {logs.map((log) => ( + + + {taskParamLogTitle(log)} + {log.changeCount} 项 + + {(log.changes ?? []).slice(0, 8).map((change, index) => ( + + + {taskParamActionLabel(objectString(change, 'action'))} + {objectString(change, 'path') || '-'} + + {objectString(change, 'reason') || '按模型能力配置调整参数。'} + {objectString(change, 'capabilityPath') && 能力配置:{objectString(change, 'capabilityPath')}} + {taskParamChangePreview(change)} + + ))} + {(log.changes?.length ?? 0) > 8 && 还有 {(log.changes?.length ?? 0) - 8} 项转换未展开。} + + ))} + {props.loadState !== 'ready' && props.summary.capabilityPaths.length > 0 && ( + + {props.summary.capabilityPaths.slice(0, 4).map((path) => {path})} + + )} + + ); +} + function TaskAttemptChain(props: { task: GatewayTask }) { const attempts = props.task.attempts ?? 
[]; if (!attempts.length) return -; @@ -999,6 +1108,73 @@ function taskAttemptTraceReasonLabel(reason: string) { return labels[reason] ?? reason; } +function taskParamConversionSummary(task: GatewayTask): TaskParamConversionSummary { + const summary: TaskParamConversionSummary = { + changed: false, + changeCount: 0, + actions: [], + paths: [], + capabilityPaths: [], + }; + mergeTaskParamSummary(summary, metadataObject(task.metrics, 'parameterPreprocessingSummary')); + for (const attempt of task.attempts ?? []) { + mergeTaskParamSummary(summary, metadataObject(attempt.metrics, 'parameterPreprocessingSummary')); + } + return summary; +} + +function mergeTaskParamSummary(target: TaskParamConversionSummary, raw: Record) { + if (!Object.keys(raw).length) return; + target.changed = target.changed || raw.changed === true; + const changeCount = metadataNumber(raw, 'changeCount'); + if (changeCount) target.changeCount += Math.max(0, Math.trunc(changeCount)); + for (const action of metadataStringList(raw, 'actions')) appendUniqueText(target.actions, action); + for (const path of metadataStringList(raw, 'paths')) appendUniqueText(target.paths, path); + for (const path of metadataStringList(raw, 'capabilityPaths')) appendUniqueText(target.capabilityPaths, path); +} + +function taskParamSummaryText(summary: TaskParamConversionSummary) { + const actionText = summary.actions.map(taskParamActionLabel).join('、'); + const parts = [ + `${summary.changeCount || summary.paths.length || 1} 项转换`, + actionText ? `动作 ${actionText}` : '', + summary.capabilityPaths.length ? `涉及 ${summary.capabilityPaths.length} 项能力配置` : '', + ].filter(Boolean); + return parts.join(' · '); +} + +function taskParamLogTitle(log: GatewayTaskParamPreprocessingLog) { + const parts = [ + log.attemptNo ? 
`#${log.attemptNo}` : '', + log.modelType || '', + log.clientId || '', + ].filter(Boolean); + return parts.join(' · ') || '预处理记录'; +} + +function taskParamActionLabel(action: string) { + if (action === 'remove') return '移除'; + if (action === 'adjust') return '调整'; + if (action === 'set') return '补齐'; + return action || '转换'; +} + +function taskParamChangePreview(change: Record) { + const before = previewCompactValue(change.before); + const after = previewCompactValue(change.after); + const action = objectString(change, 'action'); + if (action === 'remove') return `原值 ${before}`; + if (action === 'set') return `新值 ${after}`; + return `原值 ${before} -> 新值 ${after}`; +} + +function previewCompactValue(value: unknown) { + if (value === undefined || value === null || value === '') return '-'; + const text = typeof value === 'string' ? value : JSON.stringify(value); + if (!text) return '-'; + return text.length > 150 ? `${text.slice(0, 150)}...` : text; +} + function formatCellValue(value: unknown) { if (value === undefined || value === null || value === '') return '-'; return String(value); @@ -1039,6 +1215,12 @@ function metadataString(metadata: Record | undefined, key: stri return typeof value === 'string' && value.trim() ? value.trim() : ''; } +function metadataStringList(metadata: Record | undefined, key: string) { + const value = metadata?.[key]; + if (!Array.isArray(value)) return []; + return value.filter((item): item is string => typeof item === 'string' && item.trim() !== '').map((item) => item.trim()); +} + function metadataNumber(metadata: Record | undefined, key: string) { const value = metadata?.[key]; if (value === undefined || value === null || value === '') return null; @@ -1057,6 +1239,11 @@ function objectString(value: Record, key: string) { return typeof next === 'string' && next.trim() ? 
next.trim() : ''; } +function appendUniqueText(values: string[], value: string) { + if (!value || values.includes(value)) return; + values.push(value); +} + function transactionChargeAmount(transaction: GatewayWalletTransaction) { return metadataNumber(transaction.metadata, 'finalChargeAmount') ?? transaction.amount; } diff --git a/apps/web/src/styles.css b/apps/web/src/styles.css index 8f6c273..d29b493 100644 --- a/apps/web/src/styles.css +++ b/apps/web/src/styles.css @@ -261,8 +261,8 @@ strong { } .taskRecordTable .shTableRow { - grid-template-columns: minmax(190px, 0.9fr) minmax(220px, 1fr) minmax(94px, 0.4fr) minmax(280px, 1.45fr) minmax(104px, 0.42fr) minmax(126px, 0.55fr) minmax(150px, 0.66fr) minmax(154px, 0.62fr) minmax(82px, 0.36fr) minmax(98px, 0.42fr) minmax(150px, 0.66fr) minmax(130px, 0.54fr); - min-width: 1778px; + grid-template-columns: minmax(190px, 0.9fr) minmax(220px, 1fr) minmax(94px, 0.4fr) minmax(280px, 1.45fr) minmax(104px, 0.42fr) minmax(118px, 0.44fr) minmax(126px, 0.55fr) minmax(150px, 0.66fr) minmax(154px, 0.62fr) minmax(82px, 0.36fr) minmax(98px, 0.42fr) minmax(150px, 0.66fr) minmax(130px, 0.54fr); + min-width: 1904px; align-items: start; } @@ -359,6 +359,11 @@ strong { white-space: normal; } +.taskRecordParamCell { + overflow: visible; + white-space: normal; +} + .taskRecordAttemptCount { display: inline-flex; align-items: center; @@ -437,6 +442,113 @@ strong { overflow-wrap: anywhere; } +.taskParamConversionEmpty { + color: var(--text-soft); + font-size: var(--font-size-xs); +} + +.taskParamConversionTrigger { + display: inline-flex; + align-items: center; + gap: 0.35rem; + min-height: 1.5rem; + padding: 0; + border: 0; + background: transparent; + color: var(--text-strong); + cursor: default; + font: inherit; + font-weight: var(--font-weight-medium); +} + +.taskParamConversionTrigger svg { + color: var(--primary); +} + +.taskParamConversionAntPopover { + z-index: 1200; +} + +.taskParamConversionPopover { + display: grid; + width: 
min(42rem, calc(100vw - 2rem)); + max-height: min(34rem, calc(100vh - 7rem)); + overflow-y: auto; + gap: 0.65rem; +} + +.taskParamConversionPopoverHeader, +.taskParamConversionLogHeader, +.taskParamConversionChangeTop { + display: flex; + min-width: 0; + align-items: center; + justify-content: space-between; + gap: 0.5rem; +} + +.taskParamConversionPopoverHeader strong, +.taskParamConversionLogHeader strong, +.taskParamConversionChangeTop strong { + min-width: 0; + color: var(--text-strong); + font-weight: var(--font-weight-semibold); + overflow-wrap: anywhere; +} + +.taskParamConversionPopoverHeader small, +.taskParamConversionLog small, +.taskParamConversionChange small { + color: var(--text-soft); + font-size: var(--font-size-xs); + line-height: 1.4; +} + +.taskParamConversionState { + color: var(--text-soft); + font-size: var(--font-size-sm); +} + +.taskParamConversionState.error { + color: var(--destructive); +} + +.taskParamConversionLog { + display: grid; + min-width: 0; + gap: 0.5rem; + padding-top: 0.65rem; + border-top: 1px solid var(--border); +} + +.taskParamConversionChange { + display: grid; + gap: 0.28rem; + padding: 0.5rem; + border: 1px solid var(--border-subtle); + border-radius: var(--radius-sm); + background: var(--surface-subtle); + color: var(--text-normal); + font-size: var(--font-size-xs); + line-height: 1.45; + overflow-wrap: anywhere; +} + +.taskParamConversionChange code, +.taskParamConversionSummaryPaths code { + color: var(--text-soft); + font-family: var(--font-mono); + font-size: var(--font-size-xs); + overflow-wrap: anywhere; + white-space: normal; +} + +.taskParamConversionSummaryPaths { + display: flex; + flex-wrap: wrap; + gap: 0.35rem; +} + .taskRecordJsonButton { width: 100%; justify-content: flex-start; diff --git a/packages/contracts/src/index.ts b/packages/contracts/src/index.ts index fdf801d..7802fb2 100644 --- a/packages/contracts/src/index.ts +++ b/packages/contracts/src/index.ts @@ -829,6 +829,24 @@ export interface 
GatewayTask { updatedAt: string; } +export interface GatewayTaskParamPreprocessingLog { + id: string; + taskId: string; + attemptId?: string; + attemptNo?: number; + modelType?: string; + platformId?: string; + platformModelId?: string; + clientId?: string; + changed: boolean; + changeCount: number; + actualInput?: Record; + convertedOutput?: Record; + changes?: Array>; + model?: Record; + createdAt: string; +} + export interface GatewayTaskAttempt { id: string; taskId: string;