fix: preprocess message content by model capability
This commit is contained in:
parent
b8a716169f
commit
f1535a94c2
@ -11,6 +11,7 @@ import (
|
||||
)
|
||||
|
||||
type paramProcessContext struct {
|
||||
kind string
|
||||
modelCapability map[string]any
|
||||
candidate store.RuntimeModelCandidate
|
||||
log *parameterPreprocessingLog
|
||||
@ -58,6 +59,7 @@ func NewParamProcessorChain() ParamProcessorChain {
|
||||
processors: []paramProcessor{
|
||||
resolutionNormalizeProcessor{},
|
||||
aspectRatioProcessor{},
|
||||
messageContentProcessor{},
|
||||
contentFilterProcessor{},
|
||||
inputAudioProcessor{},
|
||||
durationProcessor{},
|
||||
@ -91,6 +93,7 @@ func preprocessRequestWithLog(kind string, body map[string]any, candidate store.
|
||||
},
|
||||
}
|
||||
context := ¶mProcessContext{
|
||||
kind: kind,
|
||||
modelCapability: effectiveModelCapability(candidate),
|
||||
candidate: candidate,
|
||||
log: &log,
|
||||
@ -333,6 +336,193 @@ func (aspectRatioProcessor) Process(params map[string]any, modelType string, con
|
||||
return true
|
||||
}
|
||||
|
||||
type messageContentProcessor struct{}
|
||||
|
||||
func (messageContentProcessor) Name() string { return "MessageContentProcessor" }
|
||||
|
||||
func (messageContentProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool {
|
||||
return isTextGenerationKind(context.kind) && params["messages"] != nil
|
||||
}
|
||||
|
||||
func (messageContentProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool {
|
||||
messages, changed := processMessageListContent(params["messages"], context)
|
||||
if changed {
|
||||
params["messages"] = messages
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func processMessageListContent(value any, context *paramProcessContext) ([]any, bool) {
|
||||
rawMessages, ok := value.([]any)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
out := make([]any, 0, len(rawMessages))
|
||||
changed := false
|
||||
for messageIndex, rawMessage := range rawMessages {
|
||||
message, ok := rawMessage.(map[string]any)
|
||||
if !ok {
|
||||
out = append(out, rawMessage)
|
||||
continue
|
||||
}
|
||||
nextMessage := cloneMap(message)
|
||||
if contentParts, ok := message["content"].([]any); ok {
|
||||
nextContent, contentChanged := processMessageContentParts(
|
||||
contentParts,
|
||||
fmt.Sprintf("messages[%d].content", messageIndex),
|
||||
context,
|
||||
)
|
||||
if contentChanged {
|
||||
nextMessage["content"] = nextContent
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
out = append(out, nextMessage)
|
||||
}
|
||||
return out, changed
|
||||
}
|
||||
|
||||
func processMessageContentParts(parts []any, basePath string, context *paramProcessContext) ([]any, bool) {
|
||||
out := make([]any, 0, len(parts))
|
||||
changed := false
|
||||
for partIndex, rawPart := range parts {
|
||||
part, ok := rawPart.(map[string]any)
|
||||
if !ok {
|
||||
out = append(out, rawPart)
|
||||
continue
|
||||
}
|
||||
if replacement, replacementChanged := messageContentPartReplacement(part, context); replacementChanged {
|
||||
out = append(out, replacement)
|
||||
context.recordChange(
|
||||
"MessageContentProcessor",
|
||||
"convert",
|
||||
fmt.Sprintf("%s[%d]", basePath, partIndex),
|
||||
part,
|
||||
replacement,
|
||||
messageContentConversionReason(part),
|
||||
messageContentCapabilityPath(part),
|
||||
messageContentCapabilityValue(part, context),
|
||||
)
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, cloneMap(part))
|
||||
}
|
||||
return out, changed
|
||||
}
|
||||
|
||||
func messageContentPartReplacement(part map[string]any, context *paramProcessContext) (map[string]any, bool) {
|
||||
switch {
|
||||
case isImageContent(part):
|
||||
if modelSupportsMessageModality(context, "image_analysis") {
|
||||
return nil, false
|
||||
}
|
||||
if url := imageURLFromContentPart(part); url != "" {
|
||||
return map[string]any{"type": "text", "text": "Image link: " + url}, true
|
||||
}
|
||||
case isVideoContent(part):
|
||||
if modelSupportsMessageModality(context, "video_understanding") {
|
||||
return nil, false
|
||||
}
|
||||
if url := videoURLFromContentPart(part); url != "" {
|
||||
return map[string]any{"type": "text", "text": "video URL: " + url}, true
|
||||
}
|
||||
case isAudioContent(part) || stringFromAny(part["type"]) == "input_audio":
|
||||
if modelSupportsMessageModality(context, "audio_understanding") {
|
||||
return nil, false
|
||||
}
|
||||
if url := audioURLFromContentPart(part); url != "" {
|
||||
return map[string]any{"type": "text", "text": "audio URL: " + url}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func messageContentConversionReason(part map[string]any) string {
|
||||
switch {
|
||||
case isImageContent(part):
|
||||
return "模型不支持图像理解,已将 image_url 转为文本链接。"
|
||||
case isVideoContent(part):
|
||||
return "模型不支持视频理解,已将 video_url 转为文本链接。"
|
||||
default:
|
||||
return "模型不支持音频理解,已将音频输入转为文本链接。"
|
||||
}
|
||||
}
|
||||
|
||||
func messageContentCapabilityPath(part map[string]any) string {
|
||||
switch {
|
||||
case isImageContent(part):
|
||||
return "capabilities.image_analysis"
|
||||
case isVideoContent(part):
|
||||
return "capabilities.video_understanding"
|
||||
default:
|
||||
return "capabilities.audio_understanding"
|
||||
}
|
||||
}
|
||||
|
||||
func messageContentCapabilityValue(part map[string]any, context *paramProcessContext) any {
|
||||
if context == nil {
|
||||
return nil
|
||||
}
|
||||
switch {
|
||||
case isImageContent(part):
|
||||
return capabilityValue(context.modelCapability, "image_analysis", "")
|
||||
case isVideoContent(part):
|
||||
return capabilityValue(context.modelCapability, "video_understanding", "")
|
||||
default:
|
||||
return capabilityValue(context.modelCapability, "audio_understanding", "")
|
||||
}
|
||||
}
|
||||
|
||||
func modelSupportsMessageModality(context *paramProcessContext, capabilityName string) bool {
|
||||
if context == nil {
|
||||
return false
|
||||
}
|
||||
capabilities := context.modelCapability
|
||||
if capabilityForType(capabilities, capabilityName) != nil {
|
||||
return true
|
||||
}
|
||||
if capabilityForType(capabilities, "omni") != nil {
|
||||
return true
|
||||
}
|
||||
originalTypes := stringListFromAny(capabilities["originalTypes"])
|
||||
return containsString(originalTypes, capabilityName) || containsString(originalTypes, "omni")
|
||||
}
|
||||
|
||||
func imageURLFromContentPart(part map[string]any) string {
|
||||
return urlFromNestedContentPart(part, "image_url", "url", "imageUrl")
|
||||
}
|
||||
|
||||
func videoURLFromContentPart(part map[string]any) string {
|
||||
return urlFromNestedContentPart(part, "video_url", "url", "videoUrl")
|
||||
}
|
||||
|
||||
func audioURLFromContentPart(part map[string]any) string {
|
||||
if stringFromAny(part["type"]) == "input_audio" {
|
||||
if audio, ok := part["input_audio"].(map[string]any); ok {
|
||||
if url := firstNonEmptyString(stringFromAny(audio["data"]), stringFromAny(audio["url"])); url != "" {
|
||||
return url
|
||||
}
|
||||
}
|
||||
}
|
||||
return urlFromNestedContentPart(part, "audio_url", "url", "audioUrl")
|
||||
}
|
||||
|
||||
func urlFromNestedContentPart(part map[string]any, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
value := part[key]
|
||||
if url := stringFromAny(value); url != "" {
|
||||
return url
|
||||
}
|
||||
if nested, ok := value.(map[string]any); ok {
|
||||
if url := stringFromAny(nested["url"]); url != "" {
|
||||
return url
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type contentFilterProcessor struct{}
|
||||
|
||||
func (contentFilterProcessor) Name() string { return "ContentFilterProcessor" }
|
||||
@ -1239,6 +1429,9 @@ func capabilityValue(capabilities map[string]any, modelType string, key string)
|
||||
if capability == nil {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(key) == "" {
|
||||
return cloneMap(capability)
|
||||
}
|
||||
return cloneAny(capability[key])
|
||||
}
|
||||
|
||||
|
||||
@ -123,6 +123,163 @@ func TestParamProcessorOmniCapabilityLogUsesActualCapabilityKey(t *testing.T) {
|
||||
t.Fatalf("expected log to reference capabilities.omni.input_audio, got %+v", result.Log.Changes)
|
||||
}
|
||||
|
||||
func TestParamProcessorChatConvertsUnsupportedMediaMessageContentToText(t *testing.T) {
|
||||
body := map[string]any{
|
||||
"model": "text-only",
|
||||
"messages": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "describe these"},
|
||||
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}},
|
||||
map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}},
|
||||
map[string]any{"type": "audio_url", "audio_url": map[string]any{"url": "https://example.com/audio.mp3"}},
|
||||
map[string]any{"type": "input_audio", "input_audio": map[string]any{"data": "https://example.com/input.wav"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
candidate := store.RuntimeModelCandidate{
|
||||
ModelType: "text_generate",
|
||||
Capabilities: map[string]any{
|
||||
"text_generate": map[string]any{},
|
||||
"originalTypes": []any{"text_generate"},
|
||||
},
|
||||
}
|
||||
|
||||
result := preprocessRequestWithLog("chat.completions", body, candidate)
|
||||
messages, _ := result.Body["messages"].([]any)
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected one message, got %+v", result.Body["messages"])
|
||||
}
|
||||
message, _ := messages[0].(map[string]any)
|
||||
content, _ := message["content"].([]any)
|
||||
if len(content) != 5 {
|
||||
t.Fatalf("expected five content parts, got %+v", message["content"])
|
||||
}
|
||||
expectedText := []string{
|
||||
"describe these",
|
||||
"Image link: https://example.com/image.png",
|
||||
"video URL: https://example.com/video.mp4",
|
||||
"audio URL: https://example.com/audio.mp3",
|
||||
"audio URL: https://example.com/input.wav",
|
||||
}
|
||||
for index, expected := range expectedText {
|
||||
part, _ := content[index].(map[string]any)
|
||||
if stringFromAny(part["text"]) != expected {
|
||||
t.Fatalf("content[%d] text = %q, want %q; all=%+v", index, stringFromAny(part["text"]), expected, content)
|
||||
}
|
||||
}
|
||||
if len(result.Log.Changes) != 4 {
|
||||
t.Fatalf("expected four media conversion changes, got %+v", result.Log.Changes)
|
||||
}
|
||||
expectedCapabilityPaths := map[string]bool{
|
||||
"capabilities.image_analysis": false,
|
||||
"capabilities.video_understanding": false,
|
||||
"capabilities.audio_understanding": false,
|
||||
}
|
||||
for _, change := range result.Log.Changes {
|
||||
if _, ok := expectedCapabilityPaths[change.CapabilityPath]; ok {
|
||||
expectedCapabilityPaths[change.CapabilityPath] = true
|
||||
}
|
||||
}
|
||||
for path, found := range expectedCapabilityPaths {
|
||||
if !found {
|
||||
t.Fatalf("expected conversion log for %s, got %+v", path, result.Log.Changes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParamProcessorChatKeepsOmniMessageContent(t *testing.T) {
|
||||
body := map[string]any{
|
||||
"model": "omni",
|
||||
"messages": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}},
|
||||
map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}},
|
||||
map[string]any{"type": "audio_url", "audio_url": map[string]any{"url": "https://example.com/audio.mp3"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
candidate := store.RuntimeModelCandidate{
|
||||
ModelType: "text_generate",
|
||||
Capabilities: map[string]any{
|
||||
"text_generate": map[string]any{},
|
||||
"omni": map[string]any{},
|
||||
"originalTypes": []any{"text_generate", "omni"},
|
||||
},
|
||||
}
|
||||
|
||||
result := preprocessRequestWithLog("chat.completions", body, candidate)
|
||||
if result.Log.Changed {
|
||||
t.Fatalf("omni model should keep message media content unchanged, got %+v", result.Log.Changes)
|
||||
}
|
||||
messages, _ := result.Body["messages"].([]any)
|
||||
message, _ := messages[0].(map[string]any)
|
||||
content, _ := message["content"].([]any)
|
||||
for _, item := range content {
|
||||
part, _ := item.(map[string]any)
|
||||
if stringFromAny(part["type"]) == "text" {
|
||||
t.Fatalf("media content should not be converted for omni model: %+v", content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParamProcessorChatConvertsOnlyUnsupportedModalities(t *testing.T) {
|
||||
body := map[string]any{
|
||||
"model": "vision-only",
|
||||
"messages": []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}},
|
||||
map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
candidate := store.RuntimeModelCandidate{
|
||||
ModelType: "text_generate",
|
||||
Capabilities: map[string]any{
|
||||
"text_generate": map[string]any{},
|
||||
"image_analysis": map[string]any{},
|
||||
"originalTypes": []any{"text_generate", "image_analysis"},
|
||||
},
|
||||
}
|
||||
|
||||
result := preprocessRequestWithLog("chat.completions", body, candidate)
|
||||
messages, _ := result.Body["messages"].([]any)
|
||||
message, _ := messages[0].(map[string]any)
|
||||
content, _ := message["content"].([]any)
|
||||
first, _ := content[0].(map[string]any)
|
||||
second, _ := content[1].(map[string]any)
|
||||
if stringFromAny(first["type"]) != "image_url" {
|
||||
t.Fatalf("image content should be kept when image_analysis is supported: %+v", content)
|
||||
}
|
||||
if stringFromAny(second["text"]) != "video URL: https://example.com/video.mp4" {
|
||||
t.Fatalf("video content should be converted, got %+v", second)
|
||||
}
|
||||
if len(result.Log.Changes) != 1 || result.Log.Changes[0].CapabilityPath != "capabilities.video_understanding" {
|
||||
t.Fatalf("expected only video conversion to be logged, got %+v", result.Log.Changes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSkipTaskParameterPreprocessingLogForTextModelTypes(t *testing.T) {
|
||||
for _, modelType := range []string{"text_generate", "chat", "responses", "text"} {
|
||||
if !skipTaskParameterPreprocessingLog(modelType) {
|
||||
t.Fatalf("%s should skip task parameter preprocessing log", modelType)
|
||||
}
|
||||
}
|
||||
for _, modelType := range []string{"image_generate", "image_edit", "video_generate", "omni_video"} {
|
||||
if skipTaskParameterPreprocessingLog(modelType) {
|
||||
t.Fatalf("%s should keep task parameter preprocessing log", modelType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParamProcessorVideoCapabilitiesNormalizeAndFilter(t *testing.T) {
|
||||
body := map[string]any{
|
||||
"model": "Seedance",
|
||||
|
||||
@ -531,6 +531,9 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user
|
||||
}
|
||||
|
||||
func (s *Service) recordTaskParameterPreprocessing(ctx context.Context, taskID string, attemptID string, attemptNo int, candidate store.RuntimeModelCandidate, log parameterPreprocessingLog) error {
|
||||
if skipTaskParameterPreprocessingLog(log.ModelType) {
|
||||
return nil
|
||||
}
|
||||
_, err := s.store.CreateTaskParamPreprocessingLog(ctx, store.CreateTaskParamPreprocessingLogInput{
|
||||
TaskID: taskID,
|
||||
AttemptID: attemptID,
|
||||
@ -549,6 +552,15 @@ func (s *Service) recordTaskParameterPreprocessing(ctx context.Context, taskID s
|
||||
return err
|
||||
}
|
||||
|
||||
func skipTaskParameterPreprocessingLog(modelType string) bool {
|
||||
switch strings.TrimSpace(modelType) {
|
||||
case "text_generate", "chat", "responses", "text":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated bool) clients.Client {
|
||||
if simulated {
|
||||
return s.clients["simulation"]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user