fix: preprocess message content by model capability

This commit is contained in:
wangbo 2026-05-13 23:39:11 +08:00
parent b8a716169f
commit f1535a94c2
3 changed files with 362 additions and 0 deletions

View File

@ -11,6 +11,7 @@ import (
)
type paramProcessContext struct {
kind string
modelCapability map[string]any
candidate store.RuntimeModelCandidate
log *parameterPreprocessingLog
@ -58,6 +59,7 @@ func NewParamProcessorChain() ParamProcessorChain {
processors: []paramProcessor{
resolutionNormalizeProcessor{},
aspectRatioProcessor{},
messageContentProcessor{},
contentFilterProcessor{},
inputAudioProcessor{},
durationProcessor{},
@ -91,6 +93,7 @@ func preprocessRequestWithLog(kind string, body map[string]any, candidate store.
},
}
context := &paramProcessContext{
kind: kind,
modelCapability: effectiveModelCapability(candidate),
candidate: candidate,
log: &log,
@ -333,6 +336,193 @@ func (aspectRatioProcessor) Process(params map[string]any, modelType string, con
return true
}
// messageContentProcessor rewrites structured message content parts whose
// modality (image/video/audio) the target model cannot handle into plain
// text parts carrying the original link.
type messageContentProcessor struct{}

// Name identifies this processor in preprocessing logs.
func (messageContentProcessor) Name() string { return "MessageContentProcessor" }

// ShouldProcess limits this processor to text-generation request kinds that
// actually carry a messages array.
func (messageContentProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool {
	return isTextGenerationKind(context.kind) && params["messages"] != nil
}

// Process rewrites unsupported media parts inside params["messages"]; it
// always returns true so the processor chain continues.
func (messageContentProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool {
	if rewritten, updated := processMessageListContent(params["messages"], context); updated {
		params["messages"] = rewritten
	}
	return true
}
// processMessageListContent walks an OpenAI-style messages array and rewrites
// any structured content parts the model cannot handle. It returns the
// (possibly rewritten) message list and whether anything changed; a non-slice
// input yields (nil, false).
func processMessageListContent(value any, context *paramProcessContext) ([]any, bool) {
	items, isList := value.([]any)
	if !isList {
		return nil, false
	}
	result := make([]any, 0, len(items))
	anyChanged := false
	for idx, item := range items {
		msg, isObject := item.(map[string]any)
		if !isObject {
			// Non-object entries pass through untouched.
			result = append(result, item)
			continue
		}
		next := cloneMap(msg)
		if parts, hasParts := msg["content"].([]any); hasParts {
			path := fmt.Sprintf("messages[%d].content", idx)
			if newParts, partsChanged := processMessageContentParts(parts, path, context); partsChanged {
				next["content"] = newParts
				anyChanged = true
			}
		}
		result = append(result, next)
	}
	return result, anyChanged
}
// processMessageContentParts rewrites media parts within a single message's
// content list. Converted parts are logged on the context with their JSON
// path (basePath[index]); untouched map parts are cloned so callers never
// share mutable state with the input.
func processMessageContentParts(parts []any, basePath string, context *paramProcessContext) ([]any, bool) {
	result := make([]any, 0, len(parts))
	anyChanged := false
	for idx, raw := range parts {
		partMap, isObject := raw.(map[string]any)
		if !isObject {
			result = append(result, raw)
			continue
		}
		replacement, converted := messageContentPartReplacement(partMap, context)
		if !converted {
			result = append(result, cloneMap(partMap))
			continue
		}
		result = append(result, replacement)
		// Record each conversion with its reason and the capability that
		// drove the decision, for the preprocessing audit log.
		context.recordChange(
			"MessageContentProcessor",
			"convert",
			fmt.Sprintf("%s[%d]", basePath, idx),
			partMap,
			replacement,
			messageContentConversionReason(partMap),
			messageContentCapabilityPath(partMap),
			messageContentCapabilityValue(partMap, context),
		)
		anyChanged = true
	}
	return result, anyChanged
}
// messageContentPartReplacement returns a text replacement for a media
// content part when the model lacks the matching modality, along with whether
// a replacement was produced. Supported modalities — and parts without a
// resolvable URL — are left alone via (nil, false).
func messageContentPartReplacement(part map[string]any, context *paramProcessContext) (map[string]any, bool) {
	textPart := func(text string) (map[string]any, bool) {
		return map[string]any{"type": "text", "text": text}, true
	}
	if isImageContent(part) {
		if !modelSupportsMessageModality(context, "image_analysis") {
			if link := imageURLFromContentPart(part); link != "" {
				return textPart("Image link: " + link)
			}
		}
		return nil, false
	}
	if isVideoContent(part) {
		if !modelSupportsMessageModality(context, "video_understanding") {
			if link := videoURLFromContentPart(part); link != "" {
				return textPart("video URL: " + link)
			}
		}
		return nil, false
	}
	if isAudioContent(part) || stringFromAny(part["type"]) == "input_audio" {
		if !modelSupportsMessageModality(context, "audio_understanding") {
			if link := audioURLFromContentPart(part); link != "" {
				return textPart("audio URL: " + link)
			}
		}
	}
	return nil, false
}
// messageContentConversionReason returns the user-facing (Chinese) reason
// string logged for a converted media part. Anything that is neither image
// nor video is treated as audio.
func messageContentConversionReason(part map[string]any) string {
	if isImageContent(part) {
		return "模型不支持图像理解,已将 image_url 转为文本链接。"
	}
	if isVideoContent(part) {
		return "模型不支持视频理解,已将 video_url 转为文本链接。"
	}
	return "模型不支持音频理解,已将音频输入转为文本链接。"
}
// messageContentCapabilityPath returns the dotted capability path recorded in
// the conversion log for a media part. Anything that is neither image nor
// video is treated as audio.
func messageContentCapabilityPath(part map[string]any) string {
	if isImageContent(part) {
		return "capabilities.image_analysis"
	}
	if isVideoContent(part) {
		return "capabilities.video_understanding"
	}
	return "capabilities.audio_understanding"
}
// messageContentCapabilityValue resolves the capability entry matching the
// part's modality from the context's model capabilities, for inclusion in the
// conversion log. A nil context yields nil.
func messageContentCapabilityValue(part map[string]any, context *paramProcessContext) any {
	if context == nil {
		return nil
	}
	// Default to audio; image and video override it.
	capabilityName := "audio_understanding"
	switch {
	case isImageContent(part):
		capabilityName = "image_analysis"
	case isVideoContent(part):
		capabilityName = "video_understanding"
	}
	return capabilityValue(context.modelCapability, capabilityName, "")
}
// modelSupportsMessageModality reports whether the candidate model supports
// the named modality, either via an explicit capability entry, via the "omni"
// capability, or via the declared originalTypes list.
func modelSupportsMessageModality(context *paramProcessContext, capabilityName string) bool {
	if context == nil {
		return false
	}
	caps := context.modelCapability
	for _, name := range []string{capabilityName, "omni"} {
		if capabilityForType(caps, name) != nil {
			return true
		}
	}
	declared := stringListFromAny(caps["originalTypes"])
	return containsString(declared, capabilityName) || containsString(declared, "omni")
}
// imageURLFromContentPart extracts an image URL from a content part, checking
// the image_url / url / imageUrl keys (flat string or nested {"url": ...});
// empty string when none is present.
func imageURLFromContentPart(part map[string]any) string {
	return urlFromNestedContentPart(part, "image_url", "url", "imageUrl")
}
// videoURLFromContentPart extracts a video URL from a content part, checking
// the video_url / url / videoUrl keys (flat string or nested {"url": ...});
// empty string when none is present.
func videoURLFromContentPart(part map[string]any) string {
	return urlFromNestedContentPart(part, "video_url", "url", "videoUrl")
}
// audioURLFromContentPart extracts an audio reference from a content part.
// input_audio parts are checked first (the "data" field wins over "url");
// otherwise the generic audio_url / url / audioUrl keys are consulted.
func audioURLFromContentPart(part map[string]any) string {
	if stringFromAny(part["type"]) == "input_audio" {
		audio, isObject := part["input_audio"].(map[string]any)
		if isObject {
			link := firstNonEmptyString(stringFromAny(audio["data"]), stringFromAny(audio["url"]))
			if link != "" {
				return link
			}
		}
	}
	return urlFromNestedContentPart(part, "audio_url", "url", "audioUrl")
}
// urlFromNestedContentPart returns the first non-empty URL found under any of
// the given keys, accepting either a flat string value or a nested object
// with a "url" field. Empty string when nothing matches.
func urlFromNestedContentPart(part map[string]any, keys ...string) string {
	for _, key := range keys {
		candidate := part[key]
		if direct := stringFromAny(candidate); direct != "" {
			return direct
		}
		nested, isObject := candidate.(map[string]any)
		if !isObject {
			continue
		}
		if inner := stringFromAny(nested["url"]); inner != "" {
			return inner
		}
	}
	return ""
}
// contentFilterProcessor is the next processor in the chain.
// NOTE(review): its Process body is not visible in this chunk — confirm its
// behavior there before documenting further.
type contentFilterProcessor struct{}

// Name identifies this processor in preprocessing logs.
func (contentFilterProcessor) Name() string { return "ContentFilterProcessor" }
@ -1239,6 +1429,9 @@ func capabilityValue(capabilities map[string]any, modelType string, key string)
if capability == nil {
return nil
}
if strings.TrimSpace(key) == "" {
return cloneMap(capability)
}
return cloneAny(capability[key])
}

View File

@ -123,6 +123,163 @@ func TestParamProcessorOmniCapabilityLogUsesActualCapabilityKey(t *testing.T) {
t.Fatalf("expected log to reference capabilities.omni.input_audio, got %+v", result.Log.Changes)
}
// TestParamProcessorChatConvertsUnsupportedMediaMessageContentToText checks
// that for a text-only model, every media content part (image_url, video_url,
// audio_url, input_audio) is rewritten into a text part carrying its original
// link, and that each conversion is logged with its capability path.
func TestParamProcessorChatConvertsUnsupportedMediaMessageContentToText(t *testing.T) {
	// One user message mixing plain text with all four media part shapes.
	body := map[string]any{
		"model": "text-only",
		"messages": []any{
			map[string]any{
				"role": "user",
				"content": []any{
					map[string]any{"type": "text", "text": "describe these"},
					map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}},
					map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}},
					map[string]any{"type": "audio_url", "audio_url": map[string]any{"url": "https://example.com/audio.mp3"}},
					map[string]any{"type": "input_audio", "input_audio": map[string]any{"data": "https://example.com/input.wav"}},
				},
			},
		},
	}
	// The candidate advertises only text_generate, so no media modality is supported.
	candidate := store.RuntimeModelCandidate{
		ModelType: "text_generate",
		Capabilities: map[string]any{
			"text_generate": map[string]any{},
			"originalTypes": []any{"text_generate"},
		},
	}
	result := preprocessRequestWithLog("chat.completions", body, candidate)
	messages, _ := result.Body["messages"].([]any)
	if len(messages) != 1 {
		t.Fatalf("expected one message, got %+v", result.Body["messages"])
	}
	message, _ := messages[0].(map[string]any)
	content, _ := message["content"].([]any)
	// All five parts survive; the four media parts must now be text.
	if len(content) != 5 {
		t.Fatalf("expected five content parts, got %+v", message["content"])
	}
	expectedText := []string{
		"describe these",
		"Image link: https://example.com/image.png",
		"video URL: https://example.com/video.mp4",
		"audio URL: https://example.com/audio.mp3",
		"audio URL: https://example.com/input.wav",
	}
	for index, expected := range expectedText {
		part, _ := content[index].(map[string]any)
		if stringFromAny(part["text"]) != expected {
			t.Fatalf("content[%d] text = %q, want %q; all=%+v", index, stringFromAny(part["text"]), expected, content)
		}
	}
	// Four conversions were performed: image, video, and two audio parts.
	if len(result.Log.Changes) != 4 {
		t.Fatalf("expected four media conversion changes, got %+v", result.Log.Changes)
	}
	// Each capability path must be referenced at least once (audio twice).
	expectedCapabilityPaths := map[string]bool{
		"capabilities.image_analysis":      false,
		"capabilities.video_understanding": false,
		"capabilities.audio_understanding": false,
	}
	for _, change := range result.Log.Changes {
		if _, ok := expectedCapabilityPaths[change.CapabilityPath]; ok {
			expectedCapabilityPaths[change.CapabilityPath] = true
		}
	}
	for path, found := range expectedCapabilityPaths {
		if !found {
			t.Fatalf("expected conversion log for %s, got %+v", path, result.Log.Changes)
		}
	}
}
// TestParamProcessorChatKeepsOmniMessageContent checks that a model declaring
// the omni capability keeps image/video/audio message content untouched and
// that no preprocessing change is logged.
func TestParamProcessorChatKeepsOmniMessageContent(t *testing.T) {
	body := map[string]any{
		"model": "omni",
		"messages": []any{
			map[string]any{
				"role": "user",
				"content": []any{
					map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}},
					map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}},
					map[string]any{"type": "audio_url", "audio_url": map[string]any{"url": "https://example.com/audio.mp3"}},
				},
			},
		},
	}
	// "omni" in both Capabilities and originalTypes marks every modality supported.
	candidate := store.RuntimeModelCandidate{
		ModelType: "text_generate",
		Capabilities: map[string]any{
			"text_generate": map[string]any{},
			"omni":          map[string]any{},
			"originalTypes": []any{"text_generate", "omni"},
		},
	}
	result := preprocessRequestWithLog("chat.completions", body, candidate)
	if result.Log.Changed {
		t.Fatalf("omni model should keep message media content unchanged, got %+v", result.Log.Changes)
	}
	messages, _ := result.Body["messages"].([]any)
	message, _ := messages[0].(map[string]any)
	content, _ := message["content"].([]any)
	// A part converted to text would carry type "text" — none may appear.
	for _, item := range content {
		part, _ := item.(map[string]any)
		if stringFromAny(part["type"]) == "text" {
			t.Fatalf("media content should not be converted for omni model: %+v", content)
		}
	}
}
// TestParamProcessorChatConvertsOnlyUnsupportedModalities checks partial
// capability handling: a model with image_analysis but no video support keeps
// the image part as-is while converting the video part to a text link, and
// only the video conversion is logged.
func TestParamProcessorChatConvertsOnlyUnsupportedModalities(t *testing.T) {
	request := map[string]any{
		"model": "vision-only",
		"messages": []any{
			map[string]any{
				"role": "user",
				"content": []any{
					map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}},
					map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}},
				},
			},
		},
	}
	visionCandidate := store.RuntimeModelCandidate{
		ModelType: "text_generate",
		Capabilities: map[string]any{
			"text_generate":  map[string]any{},
			"image_analysis": map[string]any{},
			"originalTypes":  []any{"text_generate", "image_analysis"},
		},
	}
	outcome := preprocessRequestWithLog("chat.completions", request, visionCandidate)
	messageList, _ := outcome.Body["messages"].([]any)
	firstMessage, _ := messageList[0].(map[string]any)
	contentParts, _ := firstMessage["content"].([]any)
	imagePart, _ := contentParts[0].(map[string]any)
	videoPart, _ := contentParts[1].(map[string]any)
	if stringFromAny(imagePart["type"]) != "image_url" {
		t.Fatalf("image content should be kept when image_analysis is supported: %+v", contentParts)
	}
	if stringFromAny(videoPart["text"]) != "video URL: https://example.com/video.mp4" {
		t.Fatalf("video content should be converted, got %+v", videoPart)
	}
	if len(outcome.Log.Changes) != 1 || outcome.Log.Changes[0].CapabilityPath != "capabilities.video_understanding" {
		t.Fatalf("expected only video conversion to be logged, got %+v", outcome.Log.Changes)
	}
}
// TestSkipTaskParameterPreprocessingLogForTextModelTypes checks that text
// model types suppress the task parameter preprocessing log while media model
// types keep it.
func TestSkipTaskParameterPreprocessingLogForTextModelTypes(t *testing.T) {
	skipped := []string{"text_generate", "chat", "responses", "text"}
	for _, modelType := range skipped {
		if !skipTaskParameterPreprocessingLog(modelType) {
			t.Fatalf("%s should skip task parameter preprocessing log", modelType)
		}
	}
	kept := []string{"image_generate", "image_edit", "video_generate", "omni_video"}
	for _, modelType := range kept {
		if skipTaskParameterPreprocessingLog(modelType) {
			t.Fatalf("%s should keep task parameter preprocessing log", modelType)
		}
	}
}
func TestParamProcessorVideoCapabilitiesNormalizeAndFilter(t *testing.T) {
body := map[string]any{
"model": "Seedance",

View File

@ -531,6 +531,9 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
}
func (s *Service) recordTaskParameterPreprocessing(ctx context.Context, taskID string, attemptID string, attemptNo int, candidate store.RuntimeModelCandidate, log parameterPreprocessingLog) error {
if skipTaskParameterPreprocessingLog(log.ModelType) {
return nil
}
_, err := s.store.CreateTaskParamPreprocessingLog(ctx, store.CreateTaskParamPreprocessingLogInput{
TaskID: taskID,
AttemptID: attemptID,
@ -549,6 +552,15 @@ func (s *Service) recordTaskParameterPreprocessing(ctx context.Context, taskID s
return err
}
// skipTaskParameterPreprocessingLog reports whether parameter preprocessing
// logs should be suppressed for the given model type. Text-oriented types are
// skipped; everything else (image/video/etc.) keeps its log. Surrounding
// whitespace in modelType is ignored.
func skipTaskParameterPreprocessingLog(modelType string) bool {
	trimmed := strings.TrimSpace(modelType)
	return trimmed == "text_generate" ||
		trimmed == "chat" ||
		trimmed == "responses" ||
		trimmed == "text"
}
func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated bool) clients.Client {
if simulated {
return s.clients["simulation"]