diff --git a/apps/api/internal/runner/param_processor.go b/apps/api/internal/runner/param_processor.go index 3bf9550..8957575 100644 --- a/apps/api/internal/runner/param_processor.go +++ b/apps/api/internal/runner/param_processor.go @@ -11,6 +11,7 @@ import ( ) type paramProcessContext struct { + kind string modelCapability map[string]any candidate store.RuntimeModelCandidate log *parameterPreprocessingLog @@ -58,6 +59,7 @@ func NewParamProcessorChain() ParamProcessorChain { processors: []paramProcessor{ resolutionNormalizeProcessor{}, aspectRatioProcessor{}, + messageContentProcessor{}, contentFilterProcessor{}, inputAudioProcessor{}, durationProcessor{}, @@ -91,6 +93,7 @@ func preprocessRequestWithLog(kind string, body map[string]any, candidate store. }, } context := &paramProcessContext{ + kind: kind, modelCapability: effectiveModelCapability(candidate), candidate: candidate, log: &log, @@ -333,6 +336,193 @@ func (aspectRatioProcessor) Process(params map[string]any, modelType string, con return true } +type messageContentProcessor struct{} + +func (messageContentProcessor) Name() string { return "MessageContentProcessor" } + +func (messageContentProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { + return isTextGenerationKind(context.kind) && params["messages"] != nil +} + +func (messageContentProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { + messages, changed := processMessageListContent(params["messages"], context) + if changed { + params["messages"] = messages + } + return true +} + +func processMessageListContent(value any, context *paramProcessContext) ([]any, bool) { + rawMessages, ok := value.([]any) + if !ok { + return nil, false + } + out := make([]any, 0, len(rawMessages)) + changed := false + for messageIndex, rawMessage := range rawMessages { + message, ok := rawMessage.(map[string]any) + if !ok { + out = append(out, rawMessage) + continue + } + 
nextMessage := cloneMap(message) + if contentParts, ok := message["content"].([]any); ok { + nextContent, contentChanged := processMessageContentParts( + contentParts, + fmt.Sprintf("messages[%d].content", messageIndex), + context, + ) + if contentChanged { + nextMessage["content"] = nextContent + changed = true + } + } + out = append(out, nextMessage) + } + return out, changed +} + +func processMessageContentParts(parts []any, basePath string, context *paramProcessContext) ([]any, bool) { + out := make([]any, 0, len(parts)) + changed := false + for partIndex, rawPart := range parts { + part, ok := rawPart.(map[string]any) + if !ok { + out = append(out, rawPart) + continue + } + if replacement, replacementChanged := messageContentPartReplacement(part, context); replacementChanged { + out = append(out, replacement) + context.recordChange( + "MessageContentProcessor", + "convert", + fmt.Sprintf("%s[%d]", basePath, partIndex), + part, + replacement, + messageContentConversionReason(part), + messageContentCapabilityPath(part), + messageContentCapabilityValue(part, context), + ) + changed = true + continue + } + out = append(out, cloneMap(part)) + } + return out, changed +} + +func messageContentPartReplacement(part map[string]any, context *paramProcessContext) (map[string]any, bool) { + switch { + case isImageContent(part): + if modelSupportsMessageModality(context, "image_analysis") { + return nil, false + } + if url := imageURLFromContentPart(part); url != "" { + return map[string]any{"type": "text", "text": "Image link: " + url}, true + } + case isVideoContent(part): + if modelSupportsMessageModality(context, "video_understanding") { + return nil, false + } + if url := videoURLFromContentPart(part); url != "" { + return map[string]any{"type": "text", "text": "video URL: " + url}, true + } + case isAudioContent(part) || stringFromAny(part["type"]) == "input_audio": + if modelSupportsMessageModality(context, "audio_understanding") { + return nil, false + } + if url := 
audioURLFromContentPart(part); url != "" { + return map[string]any{"type": "text", "text": "audio URL: " + url}, true + } + } + return nil, false +} + +func messageContentConversionReason(part map[string]any) string { + switch { + case isImageContent(part): + return "模型不支持图像理解,已将 image_url 转为文本链接。" + case isVideoContent(part): + return "模型不支持视频理解,已将 video_url 转为文本链接。" + default: + return "模型不支持音频理解,已将音频输入转为文本链接。" + } +} + +func messageContentCapabilityPath(part map[string]any) string { + switch { + case isImageContent(part): + return "capabilities.image_analysis" + case isVideoContent(part): + return "capabilities.video_understanding" + default: + return "capabilities.audio_understanding" + } +} + +func messageContentCapabilityValue(part map[string]any, context *paramProcessContext) any { + if context == nil { + return nil + } + switch { + case isImageContent(part): + return capabilityValue(context.modelCapability, "image_analysis", "") + case isVideoContent(part): + return capabilityValue(context.modelCapability, "video_understanding", "") + default: + return capabilityValue(context.modelCapability, "audio_understanding", "") + } +} + +func modelSupportsMessageModality(context *paramProcessContext, capabilityName string) bool { + if context == nil { + return false + } + capabilities := context.modelCapability + if capabilityForType(capabilities, capabilityName) != nil { + return true + } + if capabilityForType(capabilities, "omni") != nil { + return true + } + originalTypes := stringListFromAny(capabilities["originalTypes"]) + return containsString(originalTypes, capabilityName) || containsString(originalTypes, "omni") +} + +func imageURLFromContentPart(part map[string]any) string { + return urlFromNestedContentPart(part, "image_url", "url", "imageUrl") +} + +func videoURLFromContentPart(part map[string]any) string { + return urlFromNestedContentPart(part, "video_url", "url", "videoUrl") +} + +func audioURLFromContentPart(part map[string]any) string { + if 
stringFromAny(part["type"]) == "input_audio" { + if audio, ok := part["input_audio"].(map[string]any); ok { + if url := firstNonEmptyString(stringFromAny(audio["data"]), stringFromAny(audio["url"])); url != "" { + return url + } + } + } + return urlFromNestedContentPart(part, "audio_url", "url", "audioUrl") +} + +func urlFromNestedContentPart(part map[string]any, keys ...string) string { + for _, key := range keys { + value := part[key] + if url := stringFromAny(value); url != "" { + return url + } + if nested, ok := value.(map[string]any); ok { + if url := stringFromAny(nested["url"]); url != "" { + return url + } + } + } + return "" +} + type contentFilterProcessor struct{} func (contentFilterProcessor) Name() string { return "ContentFilterProcessor" } @@ -1239,6 +1429,9 @@ func capabilityValue(capabilities map[string]any, modelType string, key string) if capability == nil { return nil } + if strings.TrimSpace(key) == "" { + return cloneMap(capability) + } return cloneAny(capability[key]) } diff --git a/apps/api/internal/runner/param_processor_test.go b/apps/api/internal/runner/param_processor_test.go index 34ab9f8..f4130ae 100644 --- a/apps/api/internal/runner/param_processor_test.go +++ b/apps/api/internal/runner/param_processor_test.go @@ -123,6 +123,163 @@ func TestParamProcessorOmniCapabilityLogUsesActualCapabilityKey(t *testing.T) { t.Fatalf("expected log to reference capabilities.omni.input_audio, got %+v", result.Log.Changes) } +func TestParamProcessorChatConvertsUnsupportedMediaMessageContentToText(t *testing.T) { + body := map[string]any{ + "model": "text-only", + "messages": []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "describe these"}, + map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}}, + map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}}, + map[string]any{"type": "audio_url", 
"audio_url": map[string]any{"url": "https://example.com/audio.mp3"}}, + map[string]any{"type": "input_audio", "input_audio": map[string]any{"data": "https://example.com/input.wav"}}, + }, + }, + }, + } + candidate := store.RuntimeModelCandidate{ + ModelType: "text_generate", + Capabilities: map[string]any{ + "text_generate": map[string]any{}, + "originalTypes": []any{"text_generate"}, + }, + } + + result := preprocessRequestWithLog("chat.completions", body, candidate) + messages, _ := result.Body["messages"].([]any) + if len(messages) != 1 { + t.Fatalf("expected one message, got %+v", result.Body["messages"]) + } + message, _ := messages[0].(map[string]any) + content, _ := message["content"].([]any) + if len(content) != 5 { + t.Fatalf("expected five content parts, got %+v", message["content"]) + } + expectedText := []string{ + "describe these", + "Image link: https://example.com/image.png", + "video URL: https://example.com/video.mp4", + "audio URL: https://example.com/audio.mp3", + "audio URL: https://example.com/input.wav", + } + for index, expected := range expectedText { + part, _ := content[index].(map[string]any) + if stringFromAny(part["text"]) != expected { + t.Fatalf("content[%d] text = %q, want %q; all=%+v", index, stringFromAny(part["text"]), expected, content) + } + } + if len(result.Log.Changes) != 4 { + t.Fatalf("expected four media conversion changes, got %+v", result.Log.Changes) + } + expectedCapabilityPaths := map[string]bool{ + "capabilities.image_analysis": false, + "capabilities.video_understanding": false, + "capabilities.audio_understanding": false, + } + for _, change := range result.Log.Changes { + if _, ok := expectedCapabilityPaths[change.CapabilityPath]; ok { + expectedCapabilityPaths[change.CapabilityPath] = true + } + } + for path, found := range expectedCapabilityPaths { + if !found { + t.Fatalf("expected conversion log for %s, got %+v", path, result.Log.Changes) + } + } +} + +func TestParamProcessorChatKeepsOmniMessageContent(t 
*testing.T) { + body := map[string]any{ + "model": "omni", + "messages": []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}}, + map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}}, + map[string]any{"type": "audio_url", "audio_url": map[string]any{"url": "https://example.com/audio.mp3"}}, + }, + }, + }, + } + candidate := store.RuntimeModelCandidate{ + ModelType: "text_generate", + Capabilities: map[string]any{ + "text_generate": map[string]any{}, + "omni": map[string]any{}, + "originalTypes": []any{"text_generate", "omni"}, + }, + } + + result := preprocessRequestWithLog("chat.completions", body, candidate) + if result.Log.Changed { + t.Fatalf("omni model should keep message media content unchanged, got %+v", result.Log.Changes) + } + messages, _ := result.Body["messages"].([]any) + message, _ := messages[0].(map[string]any) + content, _ := message["content"].([]any) + for _, item := range content { + part, _ := item.(map[string]any) + if stringFromAny(part["type"]) == "text" { + t.Fatalf("media content should not be converted for omni model: %+v", content) + } + } +} + +func TestParamProcessorChatConvertsOnlyUnsupportedModalities(t *testing.T) { + body := map[string]any{ + "model": "vision-only", + "messages": []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}}, + map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}}, + }, + }, + }, + } + candidate := store.RuntimeModelCandidate{ + ModelType: "text_generate", + Capabilities: map[string]any{ + "text_generate": map[string]any{}, + "image_analysis": map[string]any{}, + "originalTypes": []any{"text_generate", "image_analysis"}, + }, + } + + result := 
preprocessRequestWithLog("chat.completions", body, candidate) + messages, _ := result.Body["messages"].([]any) + message, _ := messages[0].(map[string]any) + content, _ := message["content"].([]any) + first, _ := content[0].(map[string]any) + second, _ := content[1].(map[string]any) + if stringFromAny(first["type"]) != "image_url" { + t.Fatalf("image content should be kept when image_analysis is supported: %+v", content) + } + if stringFromAny(second["text"]) != "video URL: https://example.com/video.mp4" { + t.Fatalf("video content should be converted, got %+v", second) + } + if len(result.Log.Changes) != 1 || result.Log.Changes[0].CapabilityPath != "capabilities.video_understanding" { + t.Fatalf("expected only video conversion to be logged, got %+v", result.Log.Changes) + } +} + +func TestSkipTaskParameterPreprocessingLogForTextModelTypes(t *testing.T) { + for _, modelType := range []string{"text_generate", "chat", "responses", "text"} { + if !skipTaskParameterPreprocessingLog(modelType) { + t.Fatalf("%s should skip task parameter preprocessing log", modelType) + } + } + for _, modelType := range []string{"image_generate", "image_edit", "video_generate", "omni_video"} { + if skipTaskParameterPreprocessingLog(modelType) { + t.Fatalf("%s should keep task parameter preprocessing log", modelType) + } + } +} + func TestParamProcessorVideoCapabilitiesNormalizeAndFilter(t *testing.T) { body := map[string]any{ "model": "Seedance", diff --git a/apps/api/internal/runner/service.go b/apps/api/internal/runner/service.go index 116da1a..e5cc2e8 100644 --- a/apps/api/internal/runner/service.go +++ b/apps/api/internal/runner/service.go @@ -531,6 +531,9 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user } func (s *Service) recordTaskParameterPreprocessing(ctx context.Context, taskID string, attemptID string, attemptNo int, candidate store.RuntimeModelCandidate, log parameterPreprocessingLog) error { + if 
skipTaskParameterPreprocessingLog(log.ModelType) { + return nil + } _, err := s.store.CreateTaskParamPreprocessingLog(ctx, store.CreateTaskParamPreprocessingLogInput{ TaskID: taskID, AttemptID: attemptID, @@ -549,6 +552,15 @@ func (s *Service) recordTaskParameterPreprocessing(ctx context.Context, taskID s return err } +func skipTaskParameterPreprocessingLog(modelType string) bool { + switch strings.TrimSpace(modelType) { + case "text_generate", "chat", "responses", "text": + return true + default: + return false + } +} + func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated bool) clients.Client { if simulated { return s.clients["simulation"]