diff --git a/.gitignore b/.gitignore index 8313207..e32f60d 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ node_modules/ apps/api/bin/ apps/api/tmp/ +apps/api/data/ coverage/ diff --git a/apps/api/internal/clients/clients_test.go b/apps/api/internal/clients/clients_test.go index 6e6ca0b..5476949 100644 --- a/apps/api/internal/clients/clients_test.go +++ b/apps/api/internal/clients/clients_test.go @@ -329,6 +329,12 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) { var gotModel string var gotText string var gotFirstFrameRole string + var gotDuration float64 + var gotRatio string + var gotResolution string + var gotSeed float64 + var gotCameraFixed bool + var gotWatermark bool var submittedRemoteTaskID string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotAuth = r.Header.Get("Authorization") @@ -343,6 +349,17 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) { if body["prompt"] != nil || body["first_frame"] != nil { t.Fatalf("video convenience fields leaked upstream: %+v", body) } + for _, key := range []string{"duration_seconds", "aspect_ratio", "audio", "cameraFixed"} { + if _, ok := body[key]; ok { + t.Fatalf("volces video task body should not include top-level %s: %+v", key, body) + } + } + gotDuration, _ = body["duration"].(float64) + gotRatio, _ = body["ratio"].(string) + gotResolution, _ = body["resolution"].(string) + gotSeed, _ = body["seed"].(float64) + gotCameraFixed, _ = body["camera_fixed"].(bool) + gotWatermark, _ = body["watermark"].(bool) content, _ := body["content"].([]any) textItem, _ := content[0].(map[string]any) gotText, _ = textItem["text"].(string) @@ -375,6 +392,10 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) { "first_frame": "https://example.com/first.png", "duration": 6, "aspect_ratio": "16:9", + "resolution": "720p", + "seed": 11, + "cameraFixed": false, + "watermark": true, }, Candidate: store.RuntimeModelCandidate{ BaseURL: server.URL, @@ -406,10 +427,11 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) { if gotModel != "doubao-seedance-2-0-260128" || gotFirstFrameRole != "first_frame" { t.Fatalf("unexpected submitted model=%s role=%s", gotModel, gotFirstFrameRole) } - for _, fragment := range []string{"A clean product reveal", "--dur 6", "--ratio 16:9", "--watermark false", "--seed -1"} { - if !strings.Contains(gotText, fragment) { - t.Fatalf("expected text to contain %q, got %q", fragment, gotText) - } + if gotText != "A clean product reveal" { + t.Fatalf("video params should not be appended to prompt text, got %q", gotText) + } + if gotDuration != 6 || gotRatio != "16:9" || gotResolution != "720p" || gotSeed != 11 || gotCameraFixed != false || gotWatermark != true { + t.Fatalf("unexpected submitted video params duration=%v ratio=%s resolution=%s seed=%v camera_fixed=%v watermark=%v", gotDuration, gotRatio, gotResolution, gotSeed, gotCameraFixed, gotWatermark) } data, _ := response.Result["data"].([]any) item, _ := data[0].(map[string]any) @@ -418,6 +440,181 @@ func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) { } } +func TestVolcesClientVideoRejectsDuplicateFirstFrameBeforeSubmit(t *testing.T) { + var submitted bool + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + submitted = true + t.Fatalf("duplicate first_frame request should not be submitted upstream") + })) + defer server.Close() + + _, err := (VolcesClient{HTTPClient: server.Client()}).Run(context.Background(), Request{ + Kind: "videos.generations", + ModelType: "image_to_video", + Model: "豆包Seedance", + Body: map[string]any{ + "model": "豆包Seedance", + "content": []any{ + map[string]any{"type": "text", "text": "animate it"}, + map[string]any{"type": "image_url", "role": "first_frame", "image_url": map[string]any{"url": "https://example.com/first.png"}}, + map[string]any{"type": "image_url", "role": "first_frame", "image_url": map[string]any{"url": "https://example.com/second.png"}}, + }, + }, + Candidate: store.RuntimeModelCandidate{ + BaseURL: server.URL, + ProviderModelName: "doubao-seedance-1-5-pro-251215", + Credentials: map[string]any{"apiKey": "volces-key"}, + }, + }) + if err == nil || ErrorCode(err) != "invalid_parameter" { + t.Fatalf("expected local invalid_parameter error, got %v", err) + } + if submitted { + t.Fatal("request was submitted upstream") + } +} + +func TestVolcesVideoBodyAllowsOnlyTaskPayloadFields(t *testing.T) { + body := volcesVideoBody(Request{ + Kind: "videos.generations", + ModelType: "omni_video", + Model: "豆包Seedance", + Body: map[string]any{ + "model": "豆包Seedance", + "duration": 8, + "duration_seconds": 8, + "aspect_ratio": "9:16", + "resolution": "720p", + "audio": true, + "callback_url": "https://example.com/callback", + "returnLastFrame": true, + "executionExpiresAfter": 3600, + "draft": false, + "cameraFixed": false, + "watermark": true, + "seed": -1, + "task_id": "local-task-id", + "runMode": "simulation", + "fps": 24, + "content": []any{ + map[string]any{"type": "text", "text": "Use <<>> in a product reveal"}, + map[string]any{ + "type": "element", + "element": map[string]any{ + "inline_element": map[string]any{ + "name": "subject", + "frontal_image_url": "https://example.com/subject.png", + "refer_images": []any{map[string]any{"url": "https://example.com/side.png", "slot_key": "side"}}, + }, + }, + }, + map[string]any{ + "type": "image_url", + "role": "unexpected_role", + "name": "drop-me", + "image_url": map[string]any{"url": "https://example.com/ref.png", "extra": "drop-me"}, + }, + map[string]any{ + "type": "video_url", + "duration": 3, + "video_url": map[string]any{ + "url": "https://example.com/ref.mp4", + "refer_type": "feature", + "keep_original_sound": "yes", + "extra": "drop-me", + }, + }, + map[string]any{ + "type": "audio_url", + "audio_url": map[string]any{"url": "https://example.com/ref.mp3", "extra": "drop-me"}, + }, + }, + }, + Candidate: store.RuntimeModelCandidate{ + ModelName: "豆包Seedance", + ProviderModelName: "doubao-seedance-2-0-260128", + Credentials: map[string]any{"apiKey": "volces-key"}, + }, + }) + + allowedTopLevel := map[string]bool{ + "model": true, "content": true, "callback_url": true, "return_last_frame": true, "execution_expires_after": true, + "generate_audio": true, "draft": true, "resolution": true, "ratio": true, "duration": true, + "seed": true, "camera_fixed": true, "watermark": true, + } + for key := range body { + if !allowedTopLevel[key] { + t.Fatalf("unexpected top-level volces field %q in %+v", key, body) + } + } + if body["model"] != "doubao-seedance-2-0-260128" || + body["generate_audio"] != true || + body["callback_url"] != "https://example.com/callback" || + body["return_last_frame"] != true || + body["execution_expires_after"] != 3600 || + body["draft"] != false || + body["resolution"] != "720p" || + body["ratio"] != "9:16" || + body["duration"] != 8 || + body["seed"] != -1 || + body["camera_fixed"] != false || + body["watermark"] != true { + t.Fatalf("unexpected direct video fields: %+v", body) + } + + content, ok := body["content"].([]map[string]any) + if !ok || len(content) != 5 { + t.Fatalf("unexpected sanitized content: %#v", body["content"]) + } + text := content[0] + if text["type"] != "text" || strings.Contains(text["text"].(string), "--dur") || strings.Contains(text["text"].(string), "--ratio") { + t.Fatalf("video params should not be appended to the text item: %+v", text) + } + elementImage := content[1] + if elementImage["type"] != "image_url" || elementImage["role"] != "reference_image" { + t.Fatalf("referenced element should be converted to reference image: %+v", elementImage) + } + imageURL, _ := elementImage["image_url"].(map[string]any) + if imageURL["url"] != "https://example.com/subject.png" || len(imageURL) != 1 { + t.Fatalf("element image payload should only include url: %+v", imageURL) + } + referenceImage := content[2] + if referenceImage["role"] != "reference_image" || referenceImage["name"] != nil { + t.Fatalf("image references should be role-normalized and scrubbed: %+v", referenceImage) + } + videoItem := content[3] + videoURL, _ := videoItem["video_url"].(map[string]any) + if videoItem["role"] != "reference_video" || videoURL["url"] != "https://example.com/ref.mp4" || videoURL["refer_type"] != "feature" || videoURL["extra"] != nil { + t.Fatalf("video references should keep only allowed nested fields: %+v", videoItem) + } + audioItem := content[4] + audioURL, _ := audioItem["audio_url"].(map[string]any) + if audioItem["role"] != "reference_audio" || audioURL["url"] != "https://example.com/ref.mp3" || len(audioURL) != 1 { + t.Fatalf("audio references should keep only url: %+v", audioItem) + } +} + +func TestVolcesVideoBodyPrefersFramesOverDuration(t *testing.T) { + body := volcesVideoBody(Request{ + Kind: "videos.generations", + ModelType: "video_generate", + Body: map[string]any{ + "prompt": "A quick camera move", + "duration": 8, + "frames": 57, + }, + Candidate: store.RuntimeModelCandidate{ + ProviderModelName: "doubao-seedance-1-0-pro-250528", + }, + }) + if body["frames"] != 57 { + t.Fatalf("frames should be passed through as the official duration control: %+v", body) + } + if _, ok := body["duration"]; ok { + t.Fatalf("duration should not be sent when frames is present: %+v", body) + } +} + func TestVolcesClientVideoResumePollsExistingTaskID(t *testing.T) { var submitCalled bool var pollPath string diff --git a/apps/api/internal/clients/simulation.go b/apps/api/internal/clients/simulation.go index c4e7085..2cbff49 100644 --- a/apps/api/internal/clients/simulation.go +++ b/apps/api/internal/clients/simulation.go @@ -339,5 +339,12 @@ func firstNonEmptyPrompt(body map[string]any, fallback string) string { return value } } + for _, item := range contentItems(body["content"]) { + if stringValue(item, "type") == "text" { + if value := strings.TrimSpace(stringValue(item, "text")); value != "" { + return value + } + } + } return fallback } diff --git a/apps/api/internal/clients/volces.go b/apps/api/internal/clients/volces.go index 020d134..60ec856 100644 --- a/apps/api/internal/clients/volces.go +++ b/apps/api/internal/clients/volces.go @@ -7,10 +7,14 @@ import ( "fmt" "math" "net/http" + "regexp" + "strconv" "strings" "time" ) +var volcesElementReferencePattern = regexp.MustCompile(`(?i)<<<[[:space:]]*element[_-]?([0-9]+)[[:space:]]*>>>|@element([0-9]+)`) + type VolcesClient struct { HTTPClient *http.Client } @@ -72,6 +76,9 @@ func (c VolcesClient) runVideo(ctx context.Context, request Request, apiKey stri upstreamTaskID := strings.TrimSpace(request.RemoteTaskID) if upstreamTaskID == "" { body := volcesVideoBody(request) + if err := validateVolcesVideoTaskBody(body); err != nil { + return Response{}, err + } submitResult, requestID, err := c.postJSON(ctx, request, request.Candidate.BaseURL, "/contents/generations/tasks", apiKey, body) submitRequestID = requestID if err != nil { @@ -215,11 +222,9 @@ func volcesVideoBody(request Request) map[string]any { content = buildVolcesContentFromBody(body) } appendMultiShotTimeline(&content) + convertVolcesElementsToImageReferences(&content) normalizeVolcesContentRoles(content) - appendVolcesVideoParams(&content, body) - body["content"] = content - stripVolcesVideoConvenienceFields(body) - return body + return volcesVideoTaskBody(body, content) } func cleanProviderBody(body map[string]any) map[string]any { @@ -286,56 +291,267 @@ func buildVolcesContentFromBody(body map[string]any) []map[string]any { return content } -func stripVolcesVideoConvenienceFields(body map[string]any) { - for _, key := range []string{ - "prompt", - "input", - "image", - "images", - "image_url", - "imageUrl", - "image_urls", - "imageUrls", - "reference_image", - "referenceImage", - "first_frame", - "firstFrame", - "last_frame", - "lastFrame", - "video", - "video_url", - "videoUrl", - "reference_video", - "referenceVideo", - "audio_url", - "audioUrl", - "reference_audio", - "referenceAudio", - } { - delete(body, key) +func volcesVideoTaskBody(body map[string]any, content []map[string]any) map[string]any { + out := map[string]any{ + "model": body["model"], + "content": sanitizeVolcesVideoContent(content), + } + addVolcesVideoTaskParams(out, body) + return out +} + +func validateVolcesVideoTaskBody(body map[string]any) error { + firstFrameCount := 0 + lastFrameCount := 0 + for _, item := range contentItems(body["content"]) { + if stringFromAny(item["type"]) != "image_url" { + continue + } + switch stringFromAny(item["role"]) { + case "first_frame": + firstFrameCount++ + case "last_frame": + lastFrameCount++ + } + } + if firstFrameCount > 1 { + return &ClientError{ + Code: "invalid_parameter", + Message: fmt.Sprintf("content contains %d first_frame image items; expected at most one first frame image content", firstFrameCount), + StatusCode: 400, + Retryable: false, + } + } + if lastFrameCount > 1 { + return &ClientError{ + Code: "invalid_parameter", + Message: fmt.Sprintf("content contains %d last_frame image items; expected at most one last frame image content", lastFrameCount), + StatusCode: 400, + Retryable: false, + } + } + return nil +} + +func addVolcesVideoTaskParams(out map[string]any, body map[string]any) { + copyVolcesStringParam(out, "callback_url", body, "callback_url", "callbackUrl") + copyVolcesBoolParam(out, "return_last_frame", body, "return_last_frame", "returnLastFrame") + copyVolcesIntParam(out, "execution_expires_after", body, "execution_expires_after", "executionExpiresAfter") + copyVolcesBoolParam(out, "generate_audio", body, "generate_audio", "generateAudio", "audio") + copyVolcesBoolParam(out, "draft", body, "draft") + copyVolcesStringParam(out, "resolution", body, "resolution", "size") + copyVolcesStringParam(out, "ratio", body, "ratio", "aspect_ratio", "aspectRatio") + if copyVolcesIntParam(out, "frames", body, "frames") { + delete(out, "duration") + } else { + copyVolcesIntParam(out, "duration", body, "duration", "duration_seconds", "durationSeconds", "dur") + } + copyVolcesIntParam(out, "seed", body, "seed") + copyVolcesBoolParam(out, "camera_fixed", body, "camera_fixed", "cameraFixed", "camerafixed", "cf") + copyVolcesBoolParam(out, "watermark", body, "watermark") +} + +func copyVolcesStringParam(out map[string]any, target string, body map[string]any, keys ...string) bool { + for _, key := range keys { + if value := strings.TrimSpace(stringFromAny(body[key])); value != "" { + out[target] = value + return true + } + } + return false +} + +func copyVolcesIntParam(out map[string]any, target string, body map[string]any, keys ...string) bool { + for _, key := range keys { + if value, ok := volcesIntFromAny(body[key]); ok { + out[target] = value + return true + } + } + return false +} + +func copyVolcesBoolParam(out map[string]any, target string, body map[string]any, keys ...string) bool { + for _, key := range keys { + if value, ok := volcesBoolFromAny(body[key]); ok { + out[target] = value + return true + } + } + return false +} + +func volcesIntFromAny(value any) (int, bool) { + switch typed := value.(type) { + case nil: + return 0, false + case int: + return typed, true + case int64: + return int(typed), true + case float64: + return int(math.Round(typed)), true + case string: + text := strings.TrimSpace(typed) + if text == "" { + return 0, false + } + if parsed, err := strconv.ParseFloat(text, 64); err == nil { + return int(math.Round(parsed)), true + } + return 0, false + default: + return 0, false } } -func contentItems(value any) []map[string]any { - rawItems, ok := value.([]any) - if !ok { - return nil +func volcesBoolFromAny(value any) (bool, bool) { + switch typed := value.(type) { + case nil: + return false, false + case bool: + return typed, true + case int: + if typed == 1 { + return true, true + } + if typed == 0 { + return false, true + } + case int64: + if typed == 1 { + return true, true + } + if typed == 0 { + return false, true + } + case float64: + if typed == 1 { + return true, true + } + if typed == 0 { + return false, true + } + case string: + normalized := strings.ToLower(strings.TrimSpace(typed)) + if normalized == "true" || normalized == "1" { + return true, true + } + if normalized == "false" || normalized == "0" { + return false, true + } } - out := make([]map[string]any, 0, len(rawItems)) - for _, raw := range rawItems { - item, ok := raw.(map[string]any) - if !ok { - continue + return false, false +} + +func sanitizeVolcesVideoContent(content []map[string]any) []map[string]any { + out := make([]map[string]any, 0, len(content)) + for _, item := range content { + switch stringFromAny(item["type"]) { + case "text": + out = append(out, map[string]any{ + "type": "text", + "text": strings.TrimSpace(stringFromAny(item["text"])), + }) + case "image_url": + url := volcesNestedURL(item, "image_url") + if url == "" { + continue + } + out = append(out, map[string]any{ + "type": "image_url", + "role": volcesImageRole(item), + "image_url": map[string]any{"url": url}, + }) + case "video_url": + url := volcesNestedURL(item, "video_url") + if url == "" { + continue + } + videoURL := map[string]any{"url": url} + if value := strings.TrimSpace(stringFromAny(mapFromAny(item["video_url"])["refer_type"])); value != "" { + videoURL["refer_type"] = value + } + if value := strings.TrimSpace(stringFromAny(mapFromAny(item["video_url"])["keep_original_sound"])); value != "" { + videoURL["keep_original_sound"] = value + } + out = append(out, map[string]any{ + "type": "video_url", + "role": "reference_video", + "video_url": videoURL, + }) + case "audio_url": + url := volcesNestedURL(item, "audio_url") + if url == "" { + continue + } + out = append(out, map[string]any{ + "type": "audio_url", + "role": "reference_audio", + "audio_url": map[string]any{"url": url}, + }) } - copied := map[string]any{} - for key, value := range item { - copied[key] = value - } - out = append(out, copied) + } + if len(out) == 0 { + return []map[string]any{{"type": "text", "text": ""}} } return out } +func volcesImageRole(item map[string]any) string { + switch strings.TrimSpace(stringFromAny(item["role"])) { + case "first_frame": + return "first_frame" + case "last_frame": + return "last_frame" + default: + return "reference_image" + } +} + +func volcesNestedURL(item map[string]any, key string) string { + nested := mapFromAny(item[key]) + return strings.TrimSpace(stringFromAny(nested["url"])) +} + +func mapFromAny(value any) map[string]any { + if object, ok := value.(map[string]any); ok { + return object + } + return nil +} + +func contentItems(value any) []map[string]any { + switch typed := value.(type) { + case []any: + out := make([]map[string]any, 0, len(typed)) + for _, raw := range typed { + item, ok := raw.(map[string]any) + if !ok { + continue + } + copied := map[string]any{} + for key, value := range item { + copied[key] = value + } + out = append(out, copied) + } + return out + case []map[string]any: + out := make([]map[string]any, 0, len(typed)) + for _, item := range typed { + copied := map[string]any{} + for key, value := range item { + copied[key] = value + } + out = append(out, copied) + } + return out + default: + return nil + } +} + func normalizeVolcesContentRoles(content []map[string]any) { for _, item := range content { itemType := strings.TrimSpace(stringFromAny(item["type"])) @@ -353,32 +569,115 @@ func normalizeVolcesContentRoles(content []map[string]any) { } } -func appendVolcesVideoParams(content *[]map[string]any, body map[string]any) { - textItem := ensureTextContent(content) - current := strings.TrimSpace(stringFromAny(textItem["text"])) - values := []struct { - key string - value any - }{ - {"dur", firstPresent(body["duration"], body["dur"])}, - {"ratio", firstPresent(body["aspect_ratio"], body["aspectRatio"], body["ratio"])}, - {"fps", firstPresent(body["framespersecond"], body["framesPerSecond"], body["fps"])}, - {"watermark", firstPresent(body["watermark"], false)}, - {"seed", firstPresent(body["seed"], -1)}, - {"cf", firstPresent(body["camerafixed"], body["cameraFixed"])}, - {"rs", firstPresent(body["resolution"], body["size"])}, - } - for _, item := range values { - valueText := volcesParamString(item.value) - if valueText == "" || strings.Contains(current, "--"+item.key) { +func convertVolcesElementsToImageReferences(content *[]map[string]any) { + referenced := referencedVolcesElementIndexes(*content) + out := make([]map[string]any, 0, len(*content)) + elementIndex := 0 + for _, item := range *content { + if stringFromAny(item["type"]) != "element" { + out = append(out, item) continue } - if current != "" { - current += " " + elementIndex++ + if !referenced[elementIndex] { + continue } - current += "--" + item.key + " " + valueText + url := volcesElementFrontalImageURL(item) + if url == "" { + continue + } + role := stringFromAny(item["role"]) + if role != "first_frame" && role != "last_frame" { + role = "reference_image" + } + out = append(out, map[string]any{ + "type": "image_url", + "role": role, + "image_url": map[string]any{"url": url}, + }) + } + *content = out +} + +func referencedVolcesElementIndexes(content []map[string]any) map[int]bool { + out := map[int]bool{} + for _, item := range content { + if stringFromAny(item["type"]) != "text" { + continue + } + text := stringFromAny(item["text"]) + if strings.TrimSpace(text) == "" { + continue + } + for _, match := range volcesElementReferencePattern.FindAllStringSubmatch(text, -1) { + raw := "" + if len(match) > 1 && match[1] != "" { + raw = match[1] + } else if len(match) > 2 { + raw = match[2] + } + index, err := strconv.Atoi(raw) + if err == nil && index > 0 { + out[index] = true + } + } + } + return out +} + +func volcesElementFrontalImageURL(item map[string]any) string { + element := mapFromAny(item["element"]) + if element == nil { + return "" + } + inline := mapFromAny(element["inline_element"]) + for _, value := range []any{ + inline["frontal_image_url"], + element["frontal_image_url"], + element["front_image_url"], + element["image_url"], + } { + if url := strings.TrimSpace(stringFromAny(value)); url != "" { + return url + } + } + return volcesReferImageURL(firstPresent(inline["refer_images"], element["refer_images"])) +} + +func volcesReferImageURL(value any) string { + images := mapListFromAny(value) + firstURL := "" + for _, image := range images { + url := strings.TrimSpace(stringFromAny(image["url"])) + if url == "" { + continue + } + if firstURL == "" { + firstURL = url + } + slot := strings.ToLower(strings.TrimSpace(stringFromAny(image["slot_key"]))) + if slot == "frontal" || slot == "front" { + return url + } + } + return firstURL +} + +func mapListFromAny(value any) []map[string]any { + switch typed := value.(type) { + case []any: + out := make([]map[string]any, 0, len(typed)) + for _, item := range typed { + if object := mapFromAny(item); object != nil { + out = append(out, object) + } + } + return out + case []map[string]any: + return typed + default: + return nil } - textItem["text"] = current } func appendMultiShotTimeline(content *[]map[string]any) { @@ -625,31 +924,6 @@ func firstNonEmptyStringListFromAny(values ...any) []string { return nil } -func volcesParamString(value any) string { - switch typed := value.(type) { - case nil: - return "" - case string: - return strings.TrimSpace(typed) - case bool: - if typed { - return "true" - } - return "false" - case int: - return fmt.Sprintf("%d", typed) - case int64: - return fmt.Sprintf("%d", typed) - case float64: - if math.Mod(typed, 1) == 0 { - return fmt.Sprintf("%d", int64(typed)) - } - return fmt.Sprintf("%g", typed) - default: - return fmt.Sprintf("%v", typed) - } -} - func numericValue(value any, fallback float64) float64 { switch typed := value.(type) { case int: diff --git a/apps/api/internal/config/config.go b/apps/api/internal/config/config.go index e8498c2..7b2188d 100644 --- a/apps/api/internal/config/config.go +++ b/apps/api/internal/config/config.go @@ -7,6 +7,11 @@ import ( "strings" ) +const ( + DefaultLocalGeneratedStorageDir = "data/static/generated" + DefaultLocalUploadedStorageDir = "data/static/uploaded" +) + type Config struct { AppEnv string HTTPAddr string @@ -15,6 +20,9 @@ type Config struct { JWTSecret string ServerMainBaseURL string ServerMainInternalToken string + PublicBaseURL string + LocalGeneratedStorageDir string + LocalUploadedStorageDir string TaskProgressCallbackEnabled bool TaskProgressCallbackURL string TaskProgressCallbackTimeoutMS string @@ -38,6 +46,9 @@ func Load() Config { "/", ), ServerMainInternalToken: env("SERVER_MAIN_INTERNAL_TOKEN", ""), + PublicBaseURL: strings.TrimRight(env("AI_GATEWAY_PUBLIC_BASE_URL", env("PUBLIC_BASE_URL", "")), "/"), + LocalGeneratedStorageDir: env("AI_GATEWAY_GENERATED_STORAGE_DIR", env("LOCAL_GENERATED_STORAGE_DIR", env("AI_GATEWAY_STATIC_STORAGE_DIR", DefaultLocalGeneratedStorageDir))), + LocalUploadedStorageDir: env("AI_GATEWAY_UPLOADED_STORAGE_DIR", env("LOCAL_UPLOADED_STORAGE_DIR", DefaultLocalUploadedStorageDir)), TaskProgressCallbackEnabled: env("TASK_PROGRESS_CALLBACK_ENABLED", "true") == "true", TaskProgressCallbackURL: env("TASK_PROGRESS_CALLBACK_URL", strings.TrimRight(env("SERVER_MAIN_BASE_URL", "http://localhost:3000"), "/")+"/internal/platform/task-progress-callbacks", diff --git a/apps/api/internal/httpapi/file_upload_handlers.go b/apps/api/internal/httpapi/file_upload_handlers.go new file mode 100644 index 0000000..68db37a --- /dev/null +++ b/apps/api/internal/httpapi/file_upload_handlers.go @@ -0,0 +1,58 @@ +package httpapi + +import ( + "io" + "net/http" + "strings" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/runner" +) + +const maxGatewayUploadBytes = 256 << 20 + +func (s *Server) uploadFile(w http.ResponseWriter, r *http.Request) { + r.Body = http.MaxBytesReader(w, r.Body, maxGatewayUploadBytes) + if err := r.ParseMultipartForm(32 << 20); err != nil { + writeError(w, http.StatusBadRequest, "invalid multipart upload") + return + } + file, header, err := r.FormFile("file") + if err != nil { + writeError(w, http.StatusBadRequest, "file is required") + return + } + defer file.Close() + payload, err := io.ReadAll(file) + if err != nil { + writeError(w, http.StatusBadRequest, "read upload file failed") + return + } + contentType := strings.TrimSpace(header.Header.Get("Content-Type")) + if contentType == "" && len(payload) > 0 { + contentType = http.DetectContentType(payload) + } + upload, err := s.runner.UploadFile(r.Context(), runner.FileUploadPayload{ + Bytes: payload, + ContentType: contentType, + FileName: header.Filename, + Source: firstNonEmptyFormValue(r, "source", "ai-gateway-openapi"), + }) + if err != nil { + s.logger.Error("upload file failed", "error", err) + status := http.StatusBadGateway + if clients.ErrorCode(err) == "upload_no_channel" { + status = http.StatusServiceUnavailable + } + writeError(w, status, err.Error()) + return + } + writeJSON(w, http.StatusOK, upload) +} + +func firstNonEmptyFormValue(r *http.Request, key string, fallback string) string { + if value := strings.TrimSpace(r.FormValue(key)); value != "" { + return value + } + return fallback +} diff --git a/apps/api/internal/httpapi/handlers.go b/apps/api/internal/httpapi/handlers.go index 16ac3fd..a8ea8de 100644 --- a/apps/api/internal/httpapi/handlers.go +++ b/apps/api/internal/httpapi/handlers.go @@ -597,7 +597,7 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler { status := statusFromRunError(runErr) errorPayload := map[string]any{ "code": runErrorCode(runErr), - "message": runErr.Error(), + "message": runErrorMessage(runErr), "status": status, } if result.Task.ID != "" { @@ -606,6 +606,9 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler { if result.Task.RequestID != "" { errorPayload["requestId"] = result.Task.RequestID } + for key, value := range runErrorDetails(runErr) { + errorPayload[key] = value + } sendSSE(w, "error", map[string]any{"error": errorPayload}) if flusher != nil { flusher.Flush() @@ -626,7 +629,7 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler { if !requestStillConnected(r) { return } - writeError(w, statusFromRunError(runErr), runErr.Error(), runErrorCode(runErr)) + writeErrorWithDetails(w, statusFromRunError(runErr), runErrorMessage(runErr), runErrorDetails(runErr), runErrorCode(runErr)) return } if !requestStillConnected(r) { @@ -742,6 +745,138 @@ func runErrorCode(err error) string { return clients.ErrorCode(err) } +func runErrorMessage(err error) string { + if err == nil { + return "" + } + if summary := rateLimitErrorSummary(err); summary != "" { + return err.Error() + ";" + summary + } + return err.Error() +} + +func runErrorDetails(err error) map[string]any { + if detail := rateLimitErrorDetail(err); len(detail) > 0 { + return map[string]any{"rateLimit": detail} + } + return nil +} + +func rateLimitErrorSummary(err error) string { + var limitErr *store.RateLimitExceededError + if !errors.As(err, &limitErr) { + return "" + } + scopeLabel := "限流对象" + switch limitErr.ScopeType { + case "user_group": + scopeLabel = "用户组" + case "platform_model": + scopeLabel = "平台模型" + } + scopeName := strings.TrimSpace(limitErr.ScopeName) + if scopeName == "" { + scopeName = strings.TrimSpace(limitErr.ScopeKey) + } + if groupKey := stringValue(limitErr.ScopeMetadata["groupKey"]); limitErr.ScopeType == "user_group" && groupKey != "" && groupKey != scopeName { + scopeName = fmt.Sprintf("%s(%s)", scopeName, groupKey) + } + projected := limitErr.Projected + if projected <= 0 { + projected = limitErr.Current + limitErr.Amount + } + parts := []string{ + fmt.Sprintf("限流摘要:%s %s 的 %s 超限", scopeLabel, scopeName, limitErr.Metric), + fmt.Sprintf("当前 %s,本次 %s,预计 %s,限制 %s", formatRateLimitValue(limitErr.Current), formatRateLimitValue(limitErr.Amount), formatRateLimitValue(projected), formatRateLimitValue(limitErr.Limit)), + } + if limitErr.WindowSeconds > 0 { + parts = append(parts, fmt.Sprintf("窗口 %d 秒", limitErr.WindowSeconds)) + } + if limitErr.RetryAfter > 0 { + parts = append(parts, fmt.Sprintf("约%s后可重试", formatRateLimitDuration(limitErr.RetryAfter))) + } else if !limitErr.Retryable { + parts = append(parts, "该请求超过单次限额,不能排队重试") + } + return strings.Join(parts, ",") +} + +func rateLimitErrorDetail(err error) map[string]any { + var limitErr *store.RateLimitExceededError + if !errors.As(err, &limitErr) { + return nil + } + detail := map[string]any{ + "scopeType": limitErr.ScopeType, + "scopeKey": limitErr.ScopeKey, + "scopeName": limitErr.ScopeName, + "metric": limitErr.Metric, + "limit": limitErr.Limit, + "amount": limitErr.Amount, + "current": limitErr.Current, + "used": limitErr.Used, + "reserved": limitErr.Reserved, + "projected": limitErr.Projected, + "windowSeconds": limitErr.WindowSeconds, + "retryable": limitErr.Retryable, + "exceeded": map[string]any{ + "metric": limitErr.Metric, + "current": limitErr.Current, + "amount": limitErr.Amount, + "projected": limitErr.Projected, + "limit": limitErr.Limit, + }, + } + if limitErr.RetryAfter > 0 { + detail["retryAfterMs"] = limitErr.RetryAfter.Milliseconds() + } + if !limitErr.ResetAt.IsZero() { + detail["resetAt"] = limitErr.ResetAt.UTC().Format(time.RFC3339Nano) + } + if len(limitErr.Policy) > 0 { + detail["rateLimitPolicy"] = limitErr.Policy + if matchedRule := matchedRateLimitRule(limitErr.Policy, limitErr.Metric); len(matchedRule) > 0 { + detail["matchedRule"] = matchedRule + } + } + if len(limitErr.ScopeMetadata) > 0 { + detail["scopeMetadata"] = limitErr.ScopeMetadata + } + if limitErr.ScopeType == "user_group" { + userGroup := map[string]any{ + "id": limitErr.ScopeKey, + "name": limitErr.ScopeName, + } + if groupKey := stringValue(limitErr.ScopeMetadata["groupKey"]); groupKey != "" { + userGroup["groupKey"] = groupKey + } + detail["userGroup"] = userGroup + } + return detail +} + +func formatRateLimitValue(value float64) string { + return strconv.FormatFloat(value, 'f', -1, 64) +} + +func formatRateLimitDuration(duration time.Duration) string { + if duration < time.Second { + return strconv.FormatInt(duration.Milliseconds(), 10) + "毫秒" + } + seconds := duration.Seconds() + return strconv.FormatFloat(seconds, 'f', -1, 64) + "秒" +} + +func matchedRateLimitRule(policy map[string]any, metric string) map[string]any { + rules, _ := policy["rules"].([]any) + for _, rawRule := range rules { + rule, _ := rawRule.(map[string]any) + if stringValue(rule["metric"]) == metric { + return rule + } + } + return nil +} + func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) { user, ok := auth.UserFromContext(r.Context()) if !ok { diff --git a/apps/api/internal/httpapi/rate_limit_error_detail_test.go b/apps/api/internal/httpapi/rate_limit_error_detail_test.go new file mode 100644 index 0000000..fe1222b --- /dev/null +++ b/apps/api/internal/httpapi/rate_limit_error_detail_test.go @@ -0,0 +1,72 @@ +package httpapi + +import ( + "strings" + "testing" + "time" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func TestRateLimitErrorDetailIncludesUserGroupAndExceededMetric(t *testing.T) { + resetAt := time.Date(2026, 5, 15, 10, 30, 0, 0, time.UTC) + detail := rateLimitErrorDetail(&store.RateLimitExceededError{ + ScopeType: "user_group", + ScopeKey: "group-1", + ScopeName: "VIP 用户组", + ScopeMetadata: map[string]any{"groupKey": "vip"}, + Metric: "rpm", + Limit: 2, + Amount: 1, + Current: 2, + Used: 1, + Reserved: 1, + Projected: 3, + WindowSeconds: 60, + ResetAt: resetAt, + RetryAfter: 5 * time.Second, + Retryable: true, + Policy: map[string]any{ + "rules": []any{ + map[string]any{"metric": "rpm", "limit": float64(2), "windowSeconds": float64(60)}, + }, + }, + }) + if detail["metric"] != "rpm" || detail["projected"] != float64(3) || detail["limit"] != float64(2) { + t.Fatalf("unexpected exceeded detail: %+v", detail) + } + userGroup, _ := detail["userGroup"].(map[string]any) + if userGroup["id"] != "group-1" || userGroup["groupKey"] != "vip" || userGroup["name"] != "VIP 用户组" { + t.Fatalf("missing user group detail: %+v", detail) + } + matchedRule, _ := detail["matchedRule"].(map[string]any) + if matchedRule["metric"] != "rpm" { + t.Fatalf("missing matched rule: %+v", detail) + } + if detail["retryAfterMs"] != int64(5000) || detail["resetAt"] != resetAt.Format(time.RFC3339Nano) { + t.Fatalf("missing retry/reset detail: %+v", detail) + } +} + +func TestRunErrorMessageIncludesRateLimitSummary(t *testing.T) { + message := runErrorMessage(&store.RateLimitExceededError{ + ScopeType: "user_group", + ScopeKey: "group-1", + ScopeName: "VIP 用户组", + ScopeMetadata: map[string]any{"groupKey": "vip"}, + Metric: "rpm", + Limit: 2, + Amount: 1, + Current: 2, + Projected: 3, + WindowSeconds: 60, + RetryAfter: 5 * time.Second, + Retryable: true, + Message: "rate limit exceeded: rpm window has no remaining capacity", + }) + for _, expected := range []string{"限流摘要", "用户组 VIP 用户组(vip)", "rpm 超限", "当前 2", "本次 1", "预计 3", "限制 2", "窗口 60 秒", "约5秒后可重试"} { + if !strings.Contains(message, expected) { + t.Fatalf("message %q should contain %q", message, expected) + } + } +} diff --git a/apps/api/internal/httpapi/response.go b/apps/api/internal/httpapi/response.go index 7939d48..5f3a004 100644 --- a/apps/api/internal/httpapi/response.go +++ b/apps/api/internal/httpapi/response.go @@ -14,6 +14,10 @@ func writeJSON(w http.ResponseWriter, status int, value any) { } func writeError(w http.ResponseWriter, status int, message string, codes ...string) { + writeErrorWithDetails(w, status, message, nil, codes...) +} + +func writeErrorWithDetails(w http.ResponseWriter, status int, message string, details map[string]any, codes ...string) { errorPayload := map[string]any{ "message": message, "status": status, @@ -23,6 +27,9 @@ func writeError(w http.ResponseWriter, status int, message string, codes ...stri errorPayload["code"] = code } } + for key, value := range details { + errorPayload[key] = value + } writeJSON(w, status, map[string]any{"error": errorPayload}) } diff --git a/apps/api/internal/httpapi/server.go b/apps/api/internal/httpapi/server.go index 5cff51b..85256ca 100644 --- a/apps/api/internal/httpapi/server.go +++ b/apps/api/internal/httpapi/server.go @@ -41,6 +41,8 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor mux.HandleFunc("GET /healthz", server.health) mux.HandleFunc("GET /readyz", server.ready) mux.HandleFunc("GET /static/simulation/{asset}", serveSimulationAsset) + mux.HandleFunc("GET /static/generated/{asset}", server.serveGeneratedStaticAsset) + mux.HandleFunc("GET /static/uploaded/{asset}", server.serveUploadedStaticAsset) mux.Handle("POST /api/v1/auth/register", server.auth.Require(auth.PermissionPublic, http.HandlerFunc(server.register))) mux.Handle("POST /api/v1/auth/login", server.auth.Require(auth.PermissionPublic, http.HandlerFunc(server.login))) @@ -102,6 +104,12 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor mux.Handle("GET /api/admin/runtime/runner-policy", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getRunnerPolicy))) mux.Handle("PATCH /api/admin/runtime/runner-policy", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updateRunnerPolicy))) mux.Handle("GET /api/admin/config/network-proxy", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getNetworkProxyConfig))) + mux.Handle("GET /api/admin/system/file-storage/settings", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.getFileStorageSettings))) + mux.Handle("PATCH /api/admin/system/file-storage/settings", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updateFileStorageSettings))) + mux.Handle("GET /api/admin/system/file-storage/channels", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listFileStorageChannels))) + mux.Handle("POST /api/admin/system/file-storage/channels", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.createFileStorageChannel))) + mux.Handle("PATCH /api/admin/system/file-storage/channels/{channelID}", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updateFileStorageChannel))) + mux.Handle("DELETE /api/admin/system/file-storage/channels/{channelID}", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.deleteFileStorageChannel))) mux.Handle("GET /api/admin/platforms", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listPlatforms))) mux.Handle("POST /api/admin/platforms", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.createPlatform))) mux.Handle("PATCH /api/admin/platforms/{platformID}", server.requireAdmin(auth.PermissionManager, http.HandlerFunc(server.updatePlatform))) @@ -123,6 +131,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor mux.Handle("POST /api/v1/images/generations", server.auth.Require(auth.PermissionBasic, server.createTask("images.generations", false))) mux.Handle("POST /api/v1/images/edits", server.auth.Require(auth.PermissionBasic, server.createTask("images.edits", false))) mux.Handle("POST /api/v1/videos/generations", server.auth.Require(auth.PermissionBasic, server.createTask("videos.generations", false))) + mux.Handle("POST /api/v1/files/upload", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.uploadFile))) mux.Handle("GET /api/v1/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks))) mux.Handle("GET /api/v1/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask))) mux.Handle("GET /api/v1/tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing))) @@ -135,6 +144,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor mux.Handle("POST /v1/images/generations", server.auth.Require(auth.PermissionBasic, server.createTask("images.generations", true))) mux.Handle("POST /images/edits", server.auth.Require(auth.PermissionBasic, server.createTask("images.edits", true))) mux.Handle("POST /v1/images/edits", server.auth.Require(auth.PermissionBasic, server.createTask("images.edits", true))) + mux.Handle("POST /v1/files/upload", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.uploadFile))) return server.recover(server.cors(mux)) } diff --git a/apps/api/internal/httpapi/static_assets.go b/apps/api/internal/httpapi/static_assets.go new file mode 100644 index 0000000..72d8c3d --- /dev/null +++ b/apps/api/internal/httpapi/static_assets.go @@ -0,0 +1,37 @@ +package httpapi + +import ( + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/config" +) + +func (s *Server) serveGeneratedStaticAsset(w http.ResponseWriter, r *http.Request) { + s.serveLocalStaticAsset(w, r, s.cfg.LocalGeneratedStorageDir, config.DefaultLocalGeneratedStorageDir) +} + +func (s *Server) serveUploadedStaticAsset(w http.ResponseWriter, r *http.Request) { + s.serveLocalStaticAsset(w, r, s.cfg.LocalUploadedStorageDir, config.DefaultLocalUploadedStorageDir) +} + +func (s *Server) serveLocalStaticAsset(w http.ResponseWriter, r *http.Request, storageDir string, fallbackStorageDir string) { + fileName := filepath.Base(strings.TrimSpace(r.PathValue("asset"))) + if fileName == "" || fileName == "." || fileName == ".." || fileName == string(filepath.Separator) { + http.NotFound(w, r) + return + } + storageDir = strings.TrimSpace(storageDir) + if storageDir == "" { + storageDir = fallbackStorageDir + } + filePath := filepath.Join(storageDir, fileName) + info, err := os.Stat(filePath) + if err != nil || info.IsDir() { + http.NotFound(w, r) + return + } + http.ServeFile(w, r, filePath) +} diff --git a/apps/api/internal/httpapi/static_assets_test.go b/apps/api/internal/httpapi/static_assets_test.go new file mode 100644 index 0000000..88e3944 --- /dev/null +++ b/apps/api/internal/httpapi/static_assets_test.go @@ -0,0 +1,65 @@ +package httpapi + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/config" +) + +func TestServeGeneratedStaticAsset(t *testing.T) { + storageDir := t.TempDir() + if err := os.WriteFile(filepath.Join(storageDir, "result.png"), []byte("png"), 0o644); err != nil { + t.Fatalf("failed to write generated asset fixture: %v", err) + } + server := &Server{cfg: config.Config{LocalGeneratedStorageDir: storageDir}} + request := httptest.NewRequest(http.MethodGet, "/static/generated/result.png", nil) + request.SetPathValue("asset", "result.png") + response := httptest.NewRecorder() + + server.serveGeneratedStaticAsset(response, request) + + if response.Code != http.StatusOK { + t.Fatalf("expected generated asset to be served, got status %d", response.Code) + } + if response.Body.String() != "png" { + t.Fatalf("unexpected generated asset payload: %q", response.Body.String()) + } +} + +func TestServeUploadedStaticAsset(t *testing.T) { + storageDir := t.TempDir() + if err := os.WriteFile(filepath.Join(storageDir, "upload.pdf"), []byte("pdf"), 0o644); err != nil { + t.Fatalf("failed to write uploaded asset fixture: %v", err) + } + server := &Server{cfg: config.Config{LocalUploadedStorageDir: storageDir}} + request := httptest.NewRequest(http.MethodGet, "/static/uploaded/upload.pdf", nil) + request.SetPathValue("asset", "upload.pdf") + response := httptest.NewRecorder() + + server.serveUploadedStaticAsset(response, request) + + if response.Code != http.StatusOK { + t.Fatalf("expected uploaded asset to be served, got status %d", response.Code) + } + if response.Body.String() != "pdf" { + t.Fatalf("unexpected uploaded asset payload: %q", response.Body.String()) + } +} + +func TestServeLocalStaticAssetRejectsTraversal(t *testing.T) { + storageDir := t.TempDir() + server := &Server{cfg: config.Config{LocalGeneratedStorageDir: storageDir}} + request := httptest.NewRequest(http.MethodGet, "/static/generated/..", nil) + request.SetPathValue("asset", "..") + response := httptest.NewRecorder() + + server.serveGeneratedStaticAsset(response, request) + + if response.Code != http.StatusNotFound { + t.Fatalf("expected traversal-like generated asset name to 404, got status %d", response.Code) + } +} diff --git a/apps/api/internal/httpapi/system_settings_handlers.go b/apps/api/internal/httpapi/system_settings_handlers.go new file mode 100644 index 0000000..649315d --- /dev/null +++ b/apps/api/internal/httpapi/system_settings_handlers.go @@ -0,0 +1,150 @@ +package httpapi + +import ( + "encoding/json" + "net/http" + "strings" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func (s *Server) listFileStorageChannels(w http.ResponseWriter, r *http.Request) { + items, err := s.store.ListFileStorageChannels(r.Context()) + if err != nil { + s.logger.Error("list file storage channels failed", "error", err) + writeError(w, http.StatusInternalServerError, "list file storage channels failed") + return + } + writeJSON(w, http.StatusOK, map[string]any{"items": items}) +} + +func (s *Server) getFileStorageSettings(w http.ResponseWriter, r *http.Request) { + settings, err := s.store.GetFileStorageSettings(r.Context()) + if err != nil { + if store.IsUndefinedDatabaseObject(err) { + writeJSON(w, http.StatusOK, store.DefaultFileStorageSettings()) + return + } + s.logger.Error("get file storage settings failed", "error", err) + writeError(w, http.StatusInternalServerError, "get file storage settings failed") + return + } + writeJSON(w, http.StatusOK, settings) +} + +func (s *Server) updateFileStorageSettings(w http.ResponseWriter, r *http.Request) { + var input store.FileStorageSettingsInput + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeError(w, http.StatusBadRequest, "invalid json body") + return + } + settings, err := s.store.UpdateFileStorageSettings(r.Context(), input) + if err != nil { + s.logger.Error("update file storage settings failed", "error", err) + writeError(w, http.StatusInternalServerError, "update file storage settings failed") + return + } + writeJSON(w, http.StatusOK, settings) +} + +func (s *Server) createFileStorageChannel(w http.ResponseWriter, r *http.Request) { + var input store.FileStorageChannelInput + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeError(w, http.StatusBadRequest, "invalid json body") + return + } + if message := validateFileStorageChannelInput(input, nil); message != "" { + writeError(w, http.StatusBadRequest, message) + return + } + item, err := s.store.CreateFileStorageChannel(r.Context(), input) + if err != nil { + if store.IsUniqueViolation(err) { + writeError(w, http.StatusConflict, "file storage channel key already exists") + return + } + s.logger.Error("create file storage channel failed", "error", err) + writeError(w, http.StatusInternalServerError, "create file storage channel failed") + return + } + writeJSON(w, http.StatusCreated, item) +} + +func (s *Server) updateFileStorageChannel(w http.ResponseWriter, r *http.Request) { + var input store.FileStorageChannelInput + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + writeError(w, http.StatusBadRequest, "invalid json body") + return + } + existing, err := s.store.GetFileStorageChannel(r.Context(), r.PathValue("channelID")) + if err != nil { + if store.IsNotFound(err) { + writeError(w, http.StatusNotFound, "file storage channel not found") + return + } + s.logger.Error("get file storage channel failed", "error", err) + writeError(w, http.StatusInternalServerError, "get file storage channel failed") + return + } + if message := validateFileStorageChannelInput(input, &existing); message != "" { + writeError(w, http.StatusBadRequest, message) + return + } + item, err := s.store.UpdateFileStorageChannel(r.Context(), r.PathValue("channelID"), input) + if err != nil { + if store.IsNotFound(err) { + writeError(w, http.StatusNotFound, "file storage channel not found") + return + } + if store.IsUniqueViolation(err) { + writeError(w, http.StatusConflict, "file storage channel key already exists") + return + } + s.logger.Error("update file storage channel failed", "error", err) + writeError(w, http.StatusInternalServerError, "update file storage channel failed") + return + } + writeJSON(w, http.StatusOK, item) +} + +func (s *Server) deleteFileStorageChannel(w http.ResponseWriter, r *http.Request) { + if err := s.store.DeleteFileStorageChannel(r.Context(), r.PathValue("channelID")); err != nil { + if store.IsNotFound(err) { + writeError(w, http.StatusNotFound, "file storage channel not found") + return + } + s.logger.Error("delete file storage channel failed", "error", err) + writeError(w, http.StatusInternalServerError, "delete file storage channel failed") + return + } + w.WriteHeader(http.StatusNoContent) +} + +func validateFileStorageChannelInput(input store.FileStorageChannelInput, existing *store.FileStorageChannel) string { + provider := strings.ToLower(strings.TrimSpace(input.Provider)) + if provider == "" { + provider = "server_main_openapi" + } + status := strings.ToLower(strings.TrimSpace(input.Status)) + if status == "" { + status = "disabled" + } + if strings.TrimSpace(input.ChannelKey) == "" || strings.TrimSpace(input.Name) == "" { + return "channelKey and name are required" + } + if status != "enabled" && status != "disabled" { + return "status must be enabled or disabled" + } + if provider == "server_main_openapi" { + hasAPIKey := false + if input.APIKey != nil { + hasAPIKey = strings.TrimSpace(*input.APIKey) != "" + } else if existing != nil { + hasAPIKey = strings.TrimSpace(existing.APIKey) != "" + } + if status == "enabled" && !hasAPIKey { + return "server-main OpenAPI channel requires API key before enabling" + } + } + return "" +} diff --git a/apps/api/internal/runner/limits.go b/apps/api/internal/runner/limits.go index f7ddfc1..9a47484 100644 --- a/apps/api/internal/runner/limits.go +++ b/apps/api/internal/runner/limits.go @@ -52,9 +52,31 @@ func isLocalRateLimitError(err error) bool { func (s *Service) rateLimitReservations(ctx context.Context, user *auth.User, candidate store.RuntimeModelCandidate, body map[string]any) []store.RateLimitReservation { out := make([]store.RateLimitReservation, 0) - out = append(out, reservationsFromPolicy("platform_model", candidate.PlatformModelID, effectiveRateLimitPolicy(candidate), body)...) + out = append(out, reservationsFromPolicy( + "platform_model", + candidate.PlatformModelID, + firstNonEmptyString(candidate.DisplayName, candidate.ModelAlias, candidate.ModelName), + map[string]any{ + "platformId": candidate.PlatformID, + "platformName": candidate.PlatformName, + "modelAlias": candidate.ModelAlias, + "modelName": candidate.ModelName, + }, + effectiveRateLimitPolicy(candidate), + body, + )...) if group, err := s.store.ResolveUserGroupPolicy(ctx, user); err == nil && group.ID != "" { - out = append(out, reservationsFromPolicy("user_group", group.ID, group.RateLimitPolicy, body)...) + out = append(out, reservationsFromPolicy( + "user_group", + group.ID, + firstNonEmptyString(group.Name, group.GroupKey), + map[string]any{ + "groupKey": group.GroupKey, + "name": group.Name, + }, + group.RateLimitPolicy, + body, + )...) } return out } @@ -90,7 +112,7 @@ func effectiveRetryPolicy(candidate store.RuntimeModelCandidate) map[string]any return policy } -func reservationsFromPolicy(scopeType string, scopeKey string, policy map[string]any, body map[string]any) []store.RateLimitReservation { +func reservationsFromPolicy(scopeType string, scopeKey string, scopeName string, scopeMetadata map[string]any, policy map[string]any, body map[string]any) []store.RateLimitReservation { if scopeKey == "" || !hasRules(policy) { return nil } @@ -108,11 +130,14 @@ func reservationsFromPolicy(scopeType string, scopeKey string, policy map[string out = append(out, store.RateLimitReservation{ ScopeType: scopeType, ScopeKey: scopeKey, + ScopeName: scopeName, + ScopeMetadata: scopeMetadata, Metric: metric, Limit: limit, Amount: amount, WindowSeconds: int(floatFromAny(rule["windowSeconds"])), LeaseTTLSeconds: int(floatFromAny(rule["leaseTtlSeconds"])), + Policy: policy, }) } return out @@ -131,6 +156,11 @@ func estimateRequestTokens(body map[string]any) int { if input := stringFromMap(body, "input"); input != "" { text += input } + for _, item := range contentItems(body["content"]) { + if stringFromAny(item["type"]) == "text" { + text += stringFromAny(item["text"]) + } + } if messages, ok := body["messages"].([]any); ok { for _, raw := range messages { message, _ := raw.(map[string]any) diff --git a/apps/api/internal/runner/param_processor.go b/apps/api/internal/runner/param_processor.go index 26cdfcd..b366723 100644 --- a/apps/api/internal/runner/param_processor.go +++ b/apps/api/internal/runner/param_processor.go @@ -1,20 +1,19 @@ package runner import ( - "fmt" - "math" - "strconv" "strings" "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" ) type paramProcessContext struct { + kind string modelCapability map[string]any candidate store.RuntimeModelCandidate log *parameterPreprocessingLog aspectRatio string resolution string + err error } type paramProcessor interface { @@ -30,6 +29,7 @@ type ParamProcessorChain struct { type parameterPreprocessResult struct { Body map[string]any Log parameterPreprocessingLog + Err error } type parameterPreprocessingLog struct { @@ -57,6 +57,7 @@ func NewParamProcessorChain() ParamProcessorChain { processors: []paramProcessor{ resolutionNormalizeProcessor{}, aspectRatioProcessor{}, + messageContentProcessor{}, contentFilterProcessor{}, inputAudioProcessor{}, durationProcessor{}, @@ -90,6 +91,7 @@ func preprocessRequestWithLog(kind string, body map[string]any, candidate store. }, } context := ¶mProcessContext{ + kind: kind, modelCapability: effectiveModelCapability(candidate), candidate: candidate, log: &log, @@ -101,7 +103,7 @@ func preprocessRequestWithLog(kind string, body map[string]any, candidate store. processed := chain.Process(params, modelType, context) log.Output = cloneMap(processed) log.Changed = len(log.Changes) > 0 - return parameterPreprocessResult{Body: processed, Log: log} + return parameterPreprocessResult{Body: processed, Log: log, Err: context.err} } func (chain ParamProcessorChain) Process(params map[string]any, modelType string, context *paramProcessContext) map[string]any { @@ -115,6 +117,9 @@ func (chain ParamProcessorChain) Process(params map[string]any, modelType string if !processor.Process(params, modelType, context) { break } + if context != nil && context.err != nil { + break + } } return params } @@ -135,6 +140,20 @@ func (context *paramProcessContext) recordChange(processor string, action string }) } +func (context *paramProcessContext) reject(processor string, path string, before any, reason string, capabilityPath string, capabilityValue any) bool { + if context != nil { + context.recordChange(processor, "reject", path, before, nil, reason, capabilityPath, capabilityValue) + context.err = parameterValidationError(reason) + } + return false +} + +type parameterValidationError string + +func (e parameterValidationError) Error() string { + return string(e) +} + func parameterPreprocessingMetrics(log parameterPreprocessingLog) map[string]any { return map[string]any{ "parameterPreprocessingSummary": parameterPreprocessingSummary(log), @@ -165,1319 +184,3 @@ func parameterPreprocessingSummary(log parameterPreprocessingLog) map[string]any } return summary } - -type resolutionNormalizeProcessor struct{} - -func (resolutionNormalizeProcessor) Name() string { return "ResolutionNormalizeProcessor" } - -func (resolutionNormalizeProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { - if stringFromAny(params["resolution"]) != "" { - return false - } - size := stringFromAny(params["size"]) - if size == "" { - return false - } - return isImageResolution(modelType, size) || isVideoResolution(modelType, size) -} - -func (resolutionNormalizeProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { - size := stringFromAny(params["size"]) - if stringFromAny(params["resolution"]) == "" && (isImageResolution(modelType, size) || isVideoResolution(modelType, size)) { - _, capabilityValue := capabilityEvidence(context.modelCapability, modelType, "output_resolutions") - params["resolution"] = size - context.resolution = size - context.recordChange( - "ResolutionNormalizeProcessor", - "set", - "resolution", - nil, - size, - "size 使用分辨率格式,归一到 resolution 供后续能力校验和计费使用。", - capabilityPath(modelType, "output_resolutions"), - capabilityValue, - ) - } - return true -} - -type aspectRatioProcessor struct{} - -func (aspectRatioProcessor) Name() string { return "AspectRatioProcessor" } - -func (aspectRatioProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { - return modelType != "text_generate" && (stringFromAny(params["aspect_ratio"]) != "" || stringFromAny(params["size"]) != "") -} - -func (aspectRatioProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { - capability := capabilityForType(context.modelCapability, modelType) - if capability == nil { - return true - } - - aspectRatio := stringFromAny(params["aspect_ratio"]) - if isEmptyParamString(aspectRatio) { - before := params["aspect_ratio"] - delete(params, "aspect_ratio") - context.aspectRatio = "" - context.recordChange( - "AspectRatioProcessor", - "remove", - "aspect_ratio", - before, - nil, - "aspect_ratio 是空值字符串,不能作为有效比例传给上游。", - "", - nil, - ) - return true - } - - resolution := firstNonEmptyString(stringFromAny(params["resolution"]), context.resolution) - if resolution == "" { - if values := stringListFromAny(capability["output_resolutions"]); len(values) > 0 { - resolution = values[0] - } else if size := stringFromAny(params["size"]); strings.HasSuffix(size, "K") || strings.HasSuffix(size, "p") { - resolution = size - } - } - - allowed := aspectRatioAllowed(capability["aspect_ratio_allowed"], resolution) - if allowed != nil && len(allowed) == 1 && allowed[0] == "adaptive" { - before := params["aspect_ratio"] - params["aspect_ratio"] = "adaptive" - context.aspectRatio = "adaptive" - if before != "adaptive" { - context.recordChange( - "AspectRatioProcessor", - "adjust", - "aspect_ratio", - before, - "adaptive", - "模型当前分辨率只允许 adaptive 宽高比。", - capabilityPath(modelType, "aspect_ratio_allowed"), - capability["aspect_ratio_allowed"], - ) - } - return true - } - if allowed != nil && len(allowed) == 0 { - before := params["aspect_ratio"] - delete(params, "aspect_ratio") - context.aspectRatio = "" - context.recordChange( - "AspectRatioProcessor", - "remove", - "aspect_ratio", - before, - nil, - "模型能力配置不允许传入任何 aspect_ratio。", - capabilityPath(modelType, "aspect_ratio_allowed"), - capability["aspect_ratio_allowed"], - ) - return true - } - if aspectRatio == "" { - return true - } - if allowed == nil && validAspectRatio(aspectRatio) { - params["aspect_ratio"] = aspectRatio - context.aspectRatio = aspectRatio - return true - } - - processed, ok := validateAndAdjustAspectRatio(aspectRatio, capability, allowed) - if !ok { - before := params["aspect_ratio"] - delete(params, "aspect_ratio") - context.aspectRatio = "" - context.recordChange( - "AspectRatioProcessor", - "remove", - "aspect_ratio", - before, - nil, - "传入的 aspect_ratio 不在模型允许范围内,且没有可用替代值。", - capabilityPath(modelType, "aspect_ratio_allowed"), - capability["aspect_ratio_allowed"], - ) - return true - } - if processed != "" { - before := params["aspect_ratio"] - params["aspect_ratio"] = processed - context.aspectRatio = processed - if before != processed { - path := capabilityPath(modelType, "aspect_ratio_allowed") - value := capability["aspect_ratio_allowed"] - if ratioRange, ok := numberPair(capability["aspect_ratio_range"]); ok { - ratio, valid := aspectRatioNumber(aspectRatio) - if !valid || ratio < ratioRange[0] || ratio > ratioRange[1] { - path = capabilityPath(modelType, "aspect_ratio_range") - value = capability["aspect_ratio_range"] - } - } - context.recordChange( - "AspectRatioProcessor", - "adjust", - "aspect_ratio", - before, - processed, - "传入的 aspect_ratio 不符合模型能力配置,已调整为允许值。", - path, - value, - ) - } - } - return true -} - -type contentFilterProcessor struct{} - -func (contentFilterProcessor) Name() string { return "ContentFilterProcessor" } - -func (contentFilterProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { - _, ok := params["content"] - return ok -} - -func (contentFilterProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { - content := contentItems(params["content"]) - if len(content) == 0 { - return true - } - - if isOmniVideoLike(context) { - filtered := filterUnsupportedOmniVideoContent(content, context) - params["content"] = mapsToAnySlice(filtered) - syncVideoConvenienceFields(params, filtered, context) - return true - } - - downgradeReferenceImageIfNeeded(params, content, modelType, context) - if modelType == "video_generate" || modelType == "text_to_video" { - next := make([]map[string]any, 0, len(content)) - for index, item := range content { - if isImageContent(item) { - reason, path, value := imageContentRemovalEvidence(item, modelType, context) - context.recordChange( - "ContentFilterProcessor", - "remove", - fmt.Sprintf("content[%d]", index), - item, - nil, - reason, - path, - value, - ) - continue - } - next = append(next, item) - } - content = next - } - if modelType == "image_to_video" || modelType == "omni_video" || modelType == "omni" { - if !supportsFirstAndLastFrame(context.modelCapability, modelType) { - next := make([]map[string]any, 0, len(content)) - for index, item := range content { - if stringFromAny(item["role"]) == "last_frame" { - context.recordChange( - "ContentFilterProcessor", - "remove", - fmt.Sprintf("content[%d]", index), - item, - nil, - "模型不支持首尾帧输入,已移除 last_frame。", - capabilityPath(modelType, "input_first_last_frame"), - map[string]any{ - "input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"), - "max_images_for_last_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_last_frame"), - }, - ) - continue - } - next = append(next, item) - } - content = next - deleteFieldsWithLog(params, context, "ContentFilterProcessor", []string{"last_frame", "lastFrame"}, "模型不支持首尾帧输入,已移除快捷字段。", capabilityPath(modelType, "input_first_last_frame"), map[string]any{ - "input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"), - "max_images_for_last_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_last_frame"), - }) - } - } - params["content"] = mapsToAnySlice(content) - return true -} - -func imageContentRemovalEvidence(item map[string]any, modelType string, context *paramProcessContext) (string, string, any) { - role := stringFromAny(item["role"]) - switch role { - case "first_frame": - return "模型能力未开启首帧输入,已移除 first_frame。", capabilityPath(modelType, "input_first_frame"), map[string]any{ - "input_first_frame": capabilityValue(context.modelCapability, modelType, "input_first_frame"), - "input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"), - } - case "last_frame": - return "模型能力未开启尾帧或首尾帧输入,已移除 last_frame。", capabilityPath(modelType, "input_first_last_frame"), map[string]any{ - "input_last_frame": capabilityValue(context.modelCapability, modelType, "input_last_frame"), - "input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"), - "max_images_for_last_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_last_frame"), - "max_images_for_first_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_first_frame"), - "max_images_for_middle_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_middle_frame"), - } - case "reference_image": - return "模型能力未开启参考图输入,已移除 reference_image。", capabilityPath(modelType, "input_reference_generate_single"), map[string]any{ - "input_reference_generate_single": capabilityValue(context.modelCapability, modelType, "input_reference_generate_single"), - "input_reference_generate_multiple": capabilityValue(context.modelCapability, modelType, "input_reference_generate_multiple"), - "max_images": capabilityValue(context.modelCapability, modelType, "max_images"), - } - default: - return "当前模型能力未开启图像输入,已移除 image_url。", capabilityPath(modelType, "input_first_frame"), map[string]any{ - "input_first_frame": capabilityValue(context.modelCapability, modelType, "input_first_frame"), - "input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"), - "input_reference_generate_single": capabilityValue(context.modelCapability, modelType, "input_reference_generate_single"), - "input_reference_generate_multiple": capabilityValue(context.modelCapability, modelType, "input_reference_generate_multiple"), - } - } -} - -type inputAudioProcessor struct{} - -func (inputAudioProcessor) Name() string { return "InputAudioProcessor" } - -func (inputAudioProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { - if !isVideoModelType(modelType) { - return false - } - content := contentItems(params["content"]) - for _, item := range content { - if isAudioContent(item) { - return true - } - } - return false -} - -func (inputAudioProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { - content := contentItems(params["content"]) - if len(content) == 0 { - return true - } - supportsInputAudio := false - if len(context.modelCapability) > 0 { - if isOmniVideoLike(context) { - supportsInputAudio = supportsOmniAudioReference(context) - } else if capability := capabilityForType(context.modelCapability, modelType); capability != nil { - supportsInputAudio = boolFromAny(capability["input_audio"]) - } - } - if supportsInputAudio { - return true - } - next := make([]map[string]any, 0, len(content)) - for index, item := range content { - if isAudioContent(item) { - path, value := audioInputCapabilityEvidence(context, modelType) - context.recordChange( - "InputAudioProcessor", - "remove", - fmt.Sprintf("content[%d]", index), - item, - nil, - "模型能力未开启输入音频,已移除 audio_url。", - path, - value, - ) - continue - } - next = append(next, item) - } - params["content"] = mapsToAnySlice(next) - path, value := audioInputCapabilityEvidence(context, modelType) - deleteFieldsWithLog(params, context, "InputAudioProcessor", []string{"audio_url", "audioUrl", "reference_audio", "referenceAudio"}, "模型能力未开启输入音频,已移除音频参考快捷字段。", path, value) - return true -} - -type durationProcessor struct{} - -func (durationProcessor) Name() string { return "DurationProcessor" } - -func (durationProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { - return isVideoModelType(modelType) && params["duration"] != nil -} - -func (durationProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { - capability := capabilityForType(context.modelCapability, modelType) - if capability == nil { - return true - } - duration := floatFromAny(params["duration"]) - if duration <= 0 { - return true - } - resolution := firstNonEmptyString(stringFromAny(params["resolution"]), context.resolution) - modeKey := videoModeKey(params) - if options := scopedNumberList(capability["duration_options"], resolution, modeKey); len(options) > 0 { - normalized := closestNumber(duration, options) - params["duration"] = normalized - syncDurationSeconds(params) - if normalized != duration { - context.recordChange( - "DurationProcessor", - "adjust", - "duration", - duration, - normalized, - "duration 不在模型固定时长选项内,已调整为最近的允许值。", - capabilityPath(modelType, "duration_options"), - capability["duration_options"], - ) - } - return true - } - if minValue, maxValue, ok := scopedRange(capability["duration_range"], resolution, modeKey); ok { - step := durationStep(capability["duration_step"], resolution, modeKey) - normalized := normalizeDurationByRange(duration, minValue, maxValue, step) - params["duration"] = normalized - syncDurationSeconds(params) - if normalized != duration { - context.recordChange( - "DurationProcessor", - "adjust", - "duration", - duration, - normalized, - "duration 超出模型时长范围或步进配置,已按能力配置归一。", - capabilityPath(modelType, "duration_range"), - map[string]any{ - "duration_range": capability["duration_range"], - "duration_step": capability["duration_step"], - }, - ) - } - } - return true -} - -type audioProcessor struct{} - -func (audioProcessor) Name() string { return "AudioProcessor" } - -func (audioProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { - return isVideoModelType(modelType) && (params["audio"] != nil || params["output_audio"] != nil) -} - -func (audioProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { - capability := capabilityForType(context.modelCapability, modelType) - if capability == nil || !boolFromAny(capability["output_audio"]) { - for _, key := range []string{"audio", "output_audio"} { - if before, ok := params[key]; ok { - delete(params, key) - context.recordChange( - "AudioProcessor", - "remove", - key, - before, - nil, - "模型能力未开启输出音频,已移除音频输出参数。", - capabilityPath(modelType, "output_audio"), - capabilityValue(context.modelCapability, modelType, "output_audio"), - ) - } - } - } - return true -} - -type imageCountProcessor struct{} - -func (imageCountProcessor) Name() string { return "ImageCountProcessor" } - -func (imageCountProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { - return modelType == "image_generate" || modelType == "image_edit" -} - -func (imageCountProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { - capability := capabilityForType(context.modelCapability, modelType) - if capability == nil || !boolFromAny(capability["output_multiple_images"]) { - return true - } - maxCount := int(math.Round(floatFromAny(capability["output_max_images_count"]))) - if maxCount <= 0 { - return true - } - count := int(math.Round(floatFromAny(params["n"]))) - if count <= 0 { - count = int(math.Round(floatFromAny(params["batch_size"]))) - } - if count <= 0 { - count = 1 - } - if count > maxCount { - before := count - count = maxCount - context.recordChange( - "ImageCountProcessor", - "adjust", - "n", - before, - count, - "请求图片数量超过模型输出上限,已按 output_max_images_count 截断。", - capabilityPath(modelType, "output_max_images_count"), - capability["output_max_images_count"], - ) - } - params["n"] = count - return true -} - -func ensureVideoContent(params map[string]any, context *paramProcessContext) { - if len(contentItems(params["content"])) > 0 { - return - } - content := make([]map[string]any, 0) - if prompt := firstNonEmptyString(stringFromAny(params["prompt"]), stringFromAny(params["input"])); prompt != "" { - content = append(content, map[string]any{"type": "text", "text": prompt}) - } - appendURL := func(kind string, role string, url string) { - url = strings.TrimSpace(url) - if url == "" { - return - } - item := map[string]any{"type": kind, "role": role} - switch kind { - case "image_url": - item["image_url"] = map[string]any{"url": url} - case "video_url": - item["video_url"] = map[string]any{"url": url} - case "audio_url": - item["audio_url"] = map[string]any{"url": url} - } - content = append(content, item) - } - - firstFrame := firstNonEmptyStringValue(params, "first_frame", "firstFrame") - appendURL("image_url", "first_frame", firstFrame) - appendURL("image_url", "last_frame", firstNonEmptyStringValue(params, "last_frame", "lastFrame")) - imageURLs := firstNonEmptyStringListFromAny(params["image"], params["images"], params["image_url"], params["imageUrl"], params["image_urls"], params["imageUrls"]) - if firstFrame == "" && len(imageURLs) > 0 { - appendURL("image_url", "first_frame", imageURLs[0]) - imageURLs = imageURLs[1:] - } - for _, url := range imageURLs { - appendURL("image_url", "reference_image", url) - } - for _, url := range firstNonEmptyStringListFromAny(params["reference_image"], params["referenceImage"]) { - appendURL("image_url", "reference_image", url) - } - for _, url := range firstNonEmptyStringListFromAny(params["video"], params["video_url"], params["videoUrl"], params["reference_video"], params["referenceVideo"]) { - appendURL("video_url", "reference_video", url) - } - for _, url := range firstNonEmptyStringListFromAny(params["audio_url"], params["audioUrl"], params["reference_audio"], params["referenceAudio"]) { - appendURL("audio_url", "reference_audio", url) - } - if len(content) > 0 { - params["content"] = mapsToAnySlice(content) - context.recordChange( - "ContentBuildProcessor", - "set", - "content", - nil, - params["content"], - "将 prompt/first_frame/reference_* 等快捷字段转换为 content 数组,后续处理器可按模型能力逐项过滤。", - "", - nil, - ) - } -} - -func effectiveModelCapability(candidate store.RuntimeModelCandidate) map[string]any { - base := cloneMap(candidate.Capabilities) - for key, value := range candidate.CapabilityOverride { - if baseChild, ok := base[key].(map[string]any); ok { - if overrideChild, ok := value.(map[string]any); ok { - base[key] = mergeMap(baseChild, overrideChild) - continue - } - } - base[key] = cloneAny(value) - } - return base -} - -func filterUnsupportedOmniVideoContent(content []map[string]any, context *paramProcessContext) []map[string]any { - capability := omniVideoCapability(context) - maxVideos := math.Inf(1) - if capability != nil { - if value, ok := numericField(capability, "max_videos"); ok { - maxVideos = value - } - } - maxAudios := 0.0 - if capability != nil { - if value, ok := numericField(capability, "max_audios"); ok { - maxAudios = value - } else if supportsOmniAudioReference(context) { - maxAudios = math.Inf(1) - } - } - - videoCount := 0.0 - audioCount := 0.0 - out := make([]map[string]any, 0, len(content)) - for index, item := range content { - if isVideoContent(item) { - if !supportsOmniVideoReference(item, capability) { - path, value := omniCapabilityEvidence(context, "supported_modes") - context.recordChange( - "ContentFilterProcessor", - "remove", - fmt.Sprintf("content[%d]", index), - item, - nil, - "视频参考类型不在 omni_video.supported_modes 允许范围内。", - path, - value, - ) - continue - } - if videoCount >= maxVideos { - path, value := omniCapabilityEvidence(context, "max_videos") - context.recordChange( - "ContentFilterProcessor", - "remove", - fmt.Sprintf("content[%d]", index), - item, - nil, - "视频参考数量超过 omni_video.max_videos 限制。", - path, - value, - ) - continue - } - videoCount++ - out = append(out, item) - continue - } - if isAudioContent(item) { - if !supportsOmniAudioReference(context) { - path, value := omniCapabilityEvidence(context, "input_audio") - context.recordChange( - "ContentFilterProcessor", - "remove", - fmt.Sprintf("content[%d]", index), - item, - nil, - "模型能力不支持音频参考,已移除 audio_url。", - path, - mergeMetrics(map[string]any{"input_audio": value}, omniCapabilityBundle(context, "max_audios")), - ) - continue - } - if audioCount >= maxAudios { - path, value := omniCapabilityEvidence(context, "max_audios") - context.recordChange( - "ContentFilterProcessor", - "remove", - fmt.Sprintf("content[%d]", index), - item, - nil, - "音频参考数量超过 omni_video.max_audios 限制。", - path, - value, - ) - continue - } - audioCount++ - out = append(out, item) - continue - } - out = append(out, item) - } - return out -} - -func isOmniVideoLike(context *paramProcessContext) bool { - modelType := strings.TrimSpace(context.candidate.ModelType) - return modelType == "omni_video" || - modelType == "omni" || - context.modelCapability["omni_video"] != nil || - context.modelCapability["omni"] != nil -} - -func omniVideoCapability(context *paramProcessContext) map[string]any { - if capability := capabilityForType(context.modelCapability, "omni_video"); capability != nil { - return capability - } - return capabilityForType(context.modelCapability, "omni") -} - -func supportsOmniAudioReference(context *paramProcessContext) bool { - capability := omniVideoCapability(context) - return capability != nil && (boolFromAny(capability["input_audio"]) || floatFromAny(capability["max_audios"]) > 0) -} - -func supportsOmniVideoReference(item map[string]any, capability map[string]any) bool { - if capability == nil { - return true - } - if value, ok := numericField(capability, "max_videos"); ok && value == 0 { - return false - } - supportedModes := stringListFromAny(capability["supported_modes"]) - supportsReference := containsString(supportedModes, "video_reference") - supportsEdit := containsString(supportedModes, "video_edit") - video, _ := item["video_url"].(map[string]any) - referType := stringFromAny(video["refer_type"]) - isEditVideo := stringFromAny(item["role"]) == "video_base" || referType == "base" - isReferenceVideo := stringFromAny(item["role"]) == "video_feature" || - stringFromAny(item["role"]) == "reference_video" || - referType == "feature" - if isEditVideo { - return supportsEdit - } - if isReferenceVideo { - return supportsReference - } - return supportsReference || supportsEdit -} - -func downgradeReferenceImageIfNeeded(params map[string]any, content []map[string]any, modelType string, context *paramProcessContext) { - if modelType != "image_to_video" && modelType != "video_generate" && modelType != "video_edit" && modelType != "omni_video" && modelType != "omni" { - return - } - if supportsReferenceImage(context.modelCapability, modelType) { - return - } - count := 0 - for index, item := range content { - if stringFromAny(item["type"]) == "image_url" && stringFromAny(item["role"]) == "reference_image" { - before := cloneMap(item) - item["role"] = "first_frame" - context.recordChange( - "ContentFilterProcessor", - "adjust", - fmt.Sprintf("content[%d].role", index), - before, - item, - "模型不支持 reference_image,已降级为 first_frame。", - capabilityPath(modelType, "input_reference_generate_single"), - map[string]any{ - "input_reference_generate_single": capabilityValue(context.modelCapability, modelType, "input_reference_generate_single"), - "input_reference_generate_multiple": capabilityValue(context.modelCapability, modelType, "input_reference_generate_multiple"), - "max_images": capabilityValue(context.modelCapability, modelType, "max_images"), - }, - ) - count++ - } - } - if count > 0 { - appendParamWarning(params, "reference_image is unsupported by the selected model and was downgraded to first_frame") - } -} - -func supportsReferenceImage(modelCapability map[string]any, modelType string) bool { - candidates := []map[string]any{} - if capability := capabilityForType(modelCapability, modelType); capability != nil { - candidates = append(candidates, capability) - } - if modelType != "image_to_video" { - if capability := capabilityForType(modelCapability, "image_to_video"); capability != nil { - candidates = append(candidates, capability) - } - } - if len(candidates) == 0 { - return true - } - for _, capability := range candidates { - _, hasSingle := capability["input_reference_generate_single"] - _, hasMultiple := capability["input_reference_generate_multiple"] - if hasSingle || hasMultiple { - if boolFromAny(capability["input_reference_generate_single"]) || boolFromAny(capability["input_reference_generate_multiple"]) { - return true - } - continue - } - if value, ok := numericField(capability, "max_images"); ok { - if value > 1 { - return true - } - continue - } - } - return false -} - -func supportsFirstAndLastFrame(modelCapability map[string]any, modelType string) bool { - capability := capabilityForType(modelCapability, modelType) - if capability == nil { - return false - } - return boolFromAny(capability["input_first_last_frame"]) || floatFromAny(capability["max_images_for_last_frame"]) > 0 -} - -func validateAndAdjustAspectRatio(aspectRatio string, capability map[string]any, allowed []string) (string, bool) { - if !isMediaModelTypeWithAspectRatio(capability) { - return "", false - } - if ratioRange, ok := numberPair(capability["aspect_ratio_range"]); ok { - ratio, valid := aspectRatioNumber(aspectRatio) - if !valid || ratio < ratioRange[0] || ratio > ratioRange[1] { - return adjustAspectRatioToRange(aspectRatio, ratioRange[0], ratioRange[1], allowed), true - } - } - if allowed == nil { - return aspectRatio, true - } - if len(allowed) == 0 { - return "", false - } - if (aspectRatio == "adaptive" || aspectRatio == "keep_ratio") && !containsString(allowed, aspectRatio) { - return "", false - } - if containsString(allowed, aspectRatio) { - return aspectRatio, true - } - return allowed[0], true -} - -func isMediaModelTypeWithAspectRatio(capability map[string]any) bool { - return capability != nil -} - -func aspectRatioAllowed(value any, resolution string) []string { - switch typed := value.(type) { - case []any: - return stringListFromAny(typed) - case []string: - return typed - case map[string]any: - if resolution != "" { - if values := stringListFromAny(typed[resolution]); len(values) > 0 { - return values - } - } - return nil - default: - return nil - } -} - -func scopedNumberList(value any, scopes ...string) []float64 { - switch typed := value.(type) { - case []any: - out := make([]float64, 0, len(typed)) - for _, item := range typed { - if number := floatFromAny(item); number > 0 { - out = append(out, number) - } - } - return out - case []float64: - return typed - case []int: - out := make([]float64, 0, len(typed)) - for _, item := range typed { - out = append(out, float64(item)) - } - return out - case map[string]any: - for _, scope := range scopes { - if scope == "" { - continue - } - if values := scopedNumberList(typed[scope]); len(values) > 0 { - return values - } - } - for _, item := range typed { - if values := scopedNumberList(item); len(values) > 0 { - return values - } - } - } - return nil -} - -func scopedRange(value any, scopes ...string) (float64, float64, bool) { - if pair, ok := numberPair(value); ok { - return pair[0], pair[1], true - } - if typed, ok := value.(map[string]any); ok { - for _, scope := range scopes { - if scope == "" { - continue - } - if minValue, maxValue, ok := scopedRange(typed[scope]); ok { - return minValue, maxValue, true - } - } - for _, item := range typed { - if minValue, maxValue, ok := scopedRange(item); ok { - return minValue, maxValue, true - } - } - } - return 0, 0, false -} - -func durationStep(value any, scopes ...string) float64 { - if step := floatFromAny(value); step > 0 { - return step - } - if typed, ok := value.(map[string]any); ok { - for _, scope := range scopes { - if scope == "" { - continue - } - if step := durationStep(typed[scope]); step > 0 { - return step - } - } - for _, item := range typed { - if step := durationStep(item); step > 0 { - return step - } - } - } - return 0 -} - -func normalizeDurationByRange(target float64, minValue float64, maxValue float64, step float64) float64 { - clamped := math.Min(math.Max(target, minValue), maxValue) - if step <= 0 { - return clamped - } - snapped := math.Round((clamped-minValue)/step)*step + minValue - return math.Round(snapped*1_000_000) / 1_000_000 -} - -func closestNumber(target float64, values []float64) float64 { - if len(values) == 0 { - return target - } - closest := values[0] - minDiff := math.Abs(target - closest) - for _, value := range values[1:] { - diff := math.Abs(target - value) - if diff < minDiff { - minDiff = diff - closest = value - } - } - return closest -} - -func videoModeKey(params map[string]any) string { - content := contentItems(params["content"]) - hasFirstFrame := false - hasLastFrame := false - for _, item := range content { - switch stringFromAny(item["role"]) { - case "first_frame": - hasFirstFrame = true - case "last_frame": - hasLastFrame = true - } - } - switch { - case hasFirstFrame && hasLastFrame: - return "input_first_last_frame" - case hasFirstFrame: - return "input_first_frame" - case hasLastFrame: - return "input_last_frame" - default: - return "" - } -} - -func syncDurationSeconds(params map[string]any) { - if params["duration_seconds"] != nil { - params["duration_seconds"] = params["duration"] - } -} - -func syncVideoConvenienceFields(params map[string]any, content []map[string]any, context *paramProcessContext) { - hasVideo := false - hasAudio := false - for _, item := range content { - hasVideo = hasVideo || isVideoContent(item) - hasAudio = hasAudio || isAudioContent(item) - } - if !hasVideo { - path, value := omniCapabilityEvidence(context, "supported_modes") - deleteFieldsWithLog(params, context, "ContentFilterProcessor", []string{"video", "video_url", "videoUrl", "reference_video", "referenceVideo"}, "对应视频 content 已被模型能力过滤,移除视频参考快捷字段。", path, value) - } - if !hasAudio { - path, value := omniCapabilityEvidence(context, "input_audio") - deleteFieldsWithLog(params, context, "ContentFilterProcessor", []string{"audio_url", "audioUrl", "reference_audio", "referenceAudio"}, "对应音频 content 已被模型能力过滤,移除音频参考快捷字段。", path, mergeMetrics(map[string]any{"input_audio": value}, omniCapabilityBundle(context, "max_audios"))) - } -} - -func deleteFieldsWithLog(params map[string]any, context *paramProcessContext, processor string, keys []string, reason string, capabilityPath string, capabilityValue any) { - for _, key := range keys { - if before, ok := params[key]; ok { - delete(params, key) - context.recordChange(processor, "remove", key, before, nil, reason, capabilityPath, capabilityValue) - } - } -} - -func appendParamWarning(params map[string]any, warning string) { - warnings, _ := params["_param_warnings"].([]any) - for _, item := range warnings { - if stringFromAny(item) == warning { - return - } - } - params["_param_warnings"] = append(warnings, warning) -} - -func filterContent(content []map[string]any, keep func(map[string]any) bool) []map[string]any { - out := make([]map[string]any, 0, len(content)) - for _, item := range content { - if keep(item) { - out = append(out, item) - } - } - return out -} - -func contentItems(value any) []map[string]any { - switch typed := value.(type) { - case []any: - out := make([]map[string]any, 0, len(typed)) - for _, item := range typed { - if object, ok := item.(map[string]any); ok { - out = append(out, cloneMap(object)) - } - } - return out - case []map[string]any: - out := make([]map[string]any, 0, len(typed)) - for _, item := range typed { - out = append(out, cloneMap(item)) - } - return out - default: - return nil - } -} - -func mapsToAnySlice(values []map[string]any) []any { - out := make([]any, 0, len(values)) - for _, value := range values { - out = append(out, value) - } - return out -} - -func isImageContent(item map[string]any) bool { - return stringFromAny(item["type"]) == "image_url" || item["image_url"] != nil -} - -func isVideoContent(item map[string]any) bool { - return stringFromAny(item["type"]) == "video_url" || item["video_url"] != nil -} - -func isAudioContent(item map[string]any) bool { - return stringFromAny(item["type"]) == "audio_url" || item["audio_url"] != nil -} - -func capabilityForType(capabilities map[string]any, modelType string) map[string]any { - if capabilities == nil { - return nil - } - if typed, ok := capabilities[modelType].(map[string]any); ok { - return typed - } - return nil -} - -func capabilityPath(modelType string, key string) string { - modelType = strings.TrimSpace(modelType) - if modelType == "" { - return "" - } - if strings.TrimSpace(key) == "" { - return "capabilities." + modelType - } - return "capabilities." + modelType + "." + key -} - -func capabilityValue(capabilities map[string]any, modelType string, key string) any { - capability := capabilityForType(capabilities, modelType) - if capability == nil { - return nil - } - return cloneAny(capability[key]) -} - -func capabilityEvidence(capabilities map[string]any, modelType string, key string) (string, any) { - return capabilityPath(modelType, key), capabilityValue(capabilities, modelType, key) -} - -func audioInputCapabilityEvidence(context *paramProcessContext, modelType string) (string, any) { - if isOmniVideoLike(context) { - path, value := omniCapabilityEvidence(context, "input_audio") - return path, mergeMetrics(map[string]any{"input_audio": value}, omniCapabilityBundle(context, "max_audios")) - } - return capabilityEvidence(context.modelCapability, modelType, "input_audio") -} - -func omniCapabilityType(context *paramProcessContext) string { - if context != nil && capabilityForType(context.modelCapability, "omni_video") != nil { - return "omni_video" - } - if context != nil && capabilityForType(context.modelCapability, "omni") != nil { - return "omni" - } - return "omni_video" -} - -func omniCapabilityEvidence(context *paramProcessContext, key string) (string, any) { - modelType := omniCapabilityType(context) - var capabilities map[string]any - if context != nil { - capabilities = context.modelCapability - } - return capabilityPath(modelType, key), capabilityValue(capabilities, modelType, key) -} - -func omniCapabilityBundle(context *paramProcessContext, keys ...string) map[string]any { - modelType := omniCapabilityType(context) - var capabilities map[string]any - if context != nil { - capabilities = context.modelCapability - } - out := map[string]any{} - for _, key := range keys { - out[key] = capabilityValue(capabilities, modelType, key) - } - return out -} - -func numericField(values map[string]any, key string) (float64, bool) { - if values == nil { - return 0, false - } - if _, ok := values[key]; !ok { - return 0, false - } - return floatFromAny(values[key]), true -} - -func boolFromAny(value any) bool { - typed, _ := value.(bool) - return typed -} - -func firstNonEmptyStringValue(values map[string]any, keys ...string) string { - for _, key := range keys { - if value := stringFromAny(values[key]); value != "" { - return value - } - } - return "" -} - -func firstNonEmptyStringListFromAny(values ...any) []string { - for _, value := range values { - items := stringListFromAny(value) - if len(items) > 0 { - return items - } - } - return nil -} - -func stringListFromAny(value any) []string { - switch typed := value.(type) { - case []string: - out := make([]string, 0, len(typed)) - for _, item := range typed { - if text := strings.TrimSpace(item); text != "" { - out = append(out, text) - } - } - return out - case []any: - out := make([]string, 0, len(typed)) - for _, item := range typed { - if text := stringFromAny(item); text != "" { - out = append(out, text) - } - } - return out - case string: - if strings.TrimSpace(typed) == "" { - return nil - } - return []string{strings.TrimSpace(typed)} - default: - return nil - } -} - -func containsString(values []string, target string) bool { - for _, value := range values { - if value == target { - return true - } - } - return false -} - -func appendUniqueString(values *[]string, value string) { - value = strings.TrimSpace(value) - if value == "" { - return - } - for _, existing := range *values { - if existing == value { - return - } - } - *values = append(*values, value) -} - -func numberPair(value any) ([2]float64, bool) { - switch typed := value.(type) { - case []any: - if len(typed) < 2 { - return [2]float64{}, false - } - return [2]float64{floatFromAny(typed[0]), floatFromAny(typed[1])}, true - case []float64: - if len(typed) < 2 { - return [2]float64{}, false - } - return [2]float64{typed[0], typed[1]}, true - case []int: - if len(typed) < 2 { - return [2]float64{}, false - } - return [2]float64{float64(typed[0]), float64(typed[1])}, true - default: - return [2]float64{}, false - } -} - -func validAspectRatio(value string) bool { - if value == "adaptive" || value == "keep_ratio" { - return true - } - _, ok := aspectRatioNumber(value) - return ok -} - -func aspectRatioNumber(value string) (float64, bool) { - parts := strings.Split(value, ":") - if len(parts) != 2 { - return 0, false - } - width := parsePositiveFloat(parts[0]) - height := parsePositiveFloat(parts[1]) - if width <= 0 || height <= 0 { - return 0, false - } - return width / height, true -} - -func adjustAspectRatioToRange(value string, minValue float64, maxValue float64, allowed []string) string { - current, ok := aspectRatioNumber(value) - if !ok { - if len(allowed) > 0 { - return allowed[0] - } - return "1:1" - } - if len(allowed) > 0 { - closest := "" - minDiff := math.Inf(1) - for _, candidate := range allowed { - ratio, ok := aspectRatioNumber(candidate) - if !ok || ratio < minValue || ratio > maxValue { - continue - } - diff := math.Abs(ratio - current) - if diff < minDiff { - minDiff = diff - closest = candidate - } - } - if closest != "" { - return closest - } - } - if current < minValue { - return ratioString(minValue) - } - return ratioString(maxValue) -} - -func ratioString(value float64) string { - if value <= 0 { - return "1:1" - } - return strings.TrimRight(strings.TrimRight(strconv.FormatFloat(value, 'f', 6, 64), "0"), ".") + ":1" -} - -func parsePositiveFloat(value string) float64 { - for _, r := range strings.TrimSpace(value) { - if r < '0' || r > '9' { - if r != '.' { - return 0 - } - } - } - out, _ := strconv.ParseFloat(strings.TrimSpace(value), 64) - return out -} - -func isEmptyParamString(value string) bool { - normalized := strings.ToLower(strings.TrimSpace(value)) - return normalized == "null" || normalized == "undefined" -} - -func isImageResolution(modelType string, value string) bool { - return (modelType == "image_generate" || modelType == "image_edit") && containsString([]string{"1K", "2K", "4K", "8K"}, value) -} - -func isVideoResolution(modelType string, value string) bool { - return isVideoModelType(modelType) && containsString([]string{"480p", "720p", "1080p", "1440p", "2160p"}, value) -} - -func isVideoModelType(modelType string) bool { - return modelType == "video_generate" || modelType == "text_to_video" || modelType == "image_to_video" || modelType == "video_edit" || modelType == "omni_video" || modelType == "omni" -} - -func cloneMap(values map[string]any) map[string]any { - out := map[string]any{} - for key, value := range values { - out[key] = cloneAny(value) - } - return out -} - -func cloneAny(value any) any { - switch typed := value.(type) { - case map[string]any: - return cloneMap(typed) - case []any: - out := make([]any, 0, len(typed)) - for _, item := range typed { - out = append(out, cloneAny(item)) - } - return out - case []map[string]any: - out := make([]any, 0, len(typed)) - for _, item := range typed { - out = append(out, cloneMap(item)) - } - return out - default: - return value - } -} diff --git a/apps/api/internal/runner/param_processor_media.go b/apps/api/internal/runner/param_processor_media.go new file mode 100644 index 0000000..10b94cd --- /dev/null +++ b/apps/api/internal/runner/param_processor_media.go @@ -0,0 +1,380 @@ +package runner + +import ( + "fmt" + "math" + "strings" +) + +type resolutionNormalizeProcessor struct{} + +func (resolutionNormalizeProcessor) Name() string { return "ResolutionNormalizeProcessor" } + +func (resolutionNormalizeProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { + if stringFromAny(params["resolution"]) != "" { + return false + } + size := stringFromAny(params["size"]) + if size == "" { + return false + } + return isImageResolution(modelType, size) || isVideoResolution(modelType, size) +} + +func (resolutionNormalizeProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { + size := stringFromAny(params["size"]) + if stringFromAny(params["resolution"]) == "" && (isImageResolution(modelType, size) || isVideoResolution(modelType, size)) { + _, capabilityValue := capabilityEvidence(context.modelCapability, modelType, "output_resolutions") + params["resolution"] = size + context.resolution = size + context.recordChange( + "ResolutionNormalizeProcessor", + "set", + "resolution", + nil, + size, + "size 使用分辨率格式,归一到 resolution 供后续能力校验和计费使用。", + capabilityPath(modelType, "output_resolutions"), + capabilityValue, + ) + } + return true +} + +type aspectRatioProcessor struct{} + +func (aspectRatioProcessor) Name() string { return "AspectRatioProcessor" } + +func (aspectRatioProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { + return modelType != "text_generate" && (stringFromAny(params["aspect_ratio"]) != "" || stringFromAny(params["size"]) != "") +} + +func (aspectRatioProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { + capability := capabilityForType(context.modelCapability, modelType) + if capability == nil { + return true + } + + aspectRatio := stringFromAny(params["aspect_ratio"]) + if isEmptyParamString(aspectRatio) { + before := params["aspect_ratio"] + delete(params, "aspect_ratio") + context.aspectRatio = "" + context.recordChange( + "AspectRatioProcessor", + "remove", + "aspect_ratio", + before, + nil, + "aspect_ratio 是空值字符串,不能作为有效比例传给上游。", + "", + nil, + ) + return true + } + + resolution := firstNonEmptyString(stringFromAny(params["resolution"]), context.resolution) + if resolution == "" { + if values := stringListFromAny(capability["output_resolutions"]); len(values) > 0 { + resolution = values[0] + } else if size := stringFromAny(params["size"]); strings.HasSuffix(size, "K") || strings.HasSuffix(size, "p") { + resolution = size + } + } + + allowed := aspectRatioAllowed(capability["aspect_ratio_allowed"], resolution) + if allowed != nil && len(allowed) == 1 && allowed[0] == "adaptive" { + before := params["aspect_ratio"] + params["aspect_ratio"] = "adaptive" + context.aspectRatio = "adaptive" + if before != "adaptive" { + context.recordChange( + "AspectRatioProcessor", + "adjust", + "aspect_ratio", + before, + "adaptive", + "模型当前分辨率只允许 adaptive 宽高比。", + capabilityPath(modelType, "aspect_ratio_allowed"), + capability["aspect_ratio_allowed"], + ) + } + return true + } + if allowed != nil && len(allowed) == 0 { + before := params["aspect_ratio"] + delete(params, "aspect_ratio") + context.aspectRatio = "" + context.recordChange( + "AspectRatioProcessor", + "remove", + "aspect_ratio", + before, + nil, + "模型能力配置不允许传入任何 aspect_ratio。", + capabilityPath(modelType, "aspect_ratio_allowed"), + capability["aspect_ratio_allowed"], + ) + return true + } + if aspectRatio == "" { + return true + } + if allowed == nil && validAspectRatio(aspectRatio) { + params["aspect_ratio"] = aspectRatio + context.aspectRatio = aspectRatio + return true + } + + processed, ok := validateAndAdjustAspectRatio(aspectRatio, capability, allowed) + if !ok { + before := params["aspect_ratio"] + delete(params, "aspect_ratio") + context.aspectRatio = "" + context.recordChange( + "AspectRatioProcessor", + "remove", + "aspect_ratio", + before, + nil, + "传入的 aspect_ratio 不在模型允许范围内,且没有可用替代值。", + capabilityPath(modelType, "aspect_ratio_allowed"), + capability["aspect_ratio_allowed"], + ) + return true + } + if processed != "" { + before := params["aspect_ratio"] + params["aspect_ratio"] = processed + context.aspectRatio = processed + if before != processed { + path := capabilityPath(modelType, "aspect_ratio_allowed") + value := capability["aspect_ratio_allowed"] + if ratioRange, ok := numberPair(capability["aspect_ratio_range"]); ok { + ratio, valid := aspectRatioNumber(aspectRatio) + if !valid || ratio < ratioRange[0] || ratio > ratioRange[1] { + path = capabilityPath(modelType, "aspect_ratio_range") + value = capability["aspect_ratio_range"] + } + } + context.recordChange( + "AspectRatioProcessor", + "adjust", + "aspect_ratio", + before, + processed, + "传入的 aspect_ratio 不符合模型能力配置,已调整为允许值。", + path, + value, + ) + } + } + return true +} + +type inputAudioProcessor struct{} + +func (inputAudioProcessor) Name() string { return "InputAudioProcessor" } + +func (inputAudioProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { + if !isVideoModelType(modelType) { + return false + } + content := contentItems(params["content"]) + for _, item := range content { + if isAudioContent(item) { + return true + } + } + return false +} + +func (inputAudioProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { + content := contentItems(params["content"]) + if len(content) == 0 { + return true + } + supportsInputAudio := false + if len(context.modelCapability) > 0 { + if isOmniVideoLike(context) { + supportsInputAudio = supportsOmniAudioReference(context) + } else if capability := capabilityForType(context.modelCapability, modelType); capability != nil { + supportsInputAudio = boolFromAny(capability["input_audio"]) + } + } + if supportsInputAudio { + return true + } + next := make([]map[string]any, 0, len(content)) + for index, item := range content { + if isAudioContent(item) { + path, value := audioInputCapabilityEvidence(context, modelType) + context.recordChange( + "InputAudioProcessor", + "remove", + fmt.Sprintf("content[%d]", index), + item, + nil, + "模型能力未开启输入音频,已移除 audio_url。", + path, + value, + ) + continue + } + next = append(next, item) + } + params["content"] = mapsToAnySlice(next) + path, value := audioInputCapabilityEvidence(context, modelType) + deleteFieldsWithLog(params, context, "InputAudioProcessor", []string{"audio_url", "audioUrl", "reference_audio", "referenceAudio"}, "模型能力未开启输入音频,已移除音频参考快捷字段。", path, value) + return true +} + +type durationProcessor struct{} + +func (durationProcessor) Name() string { return "DurationProcessor" } + +func (durationProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { + return isVideoModelType(modelType) && params["duration"] != nil +} + +func (durationProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { + capability := capabilityForType(context.modelCapability, modelType) + if capability == nil { + return true + } + duration := floatFromAny(params["duration"]) + if duration <= 0 { + return true + } + resolution := firstNonEmptyString(stringFromAny(params["resolution"]), context.resolution) + modeKey := videoModeKey(params) + if options := scopedNumberList(capability["duration_options"], resolution, modeKey); len(options) > 0 { + normalized := nextAllowedNumber(duration, options) + params["duration"] = normalized + syncDurationSeconds(params) + if normalized != duration { + context.recordChange( + "DurationProcessor", + "adjust", + "duration", + duration, + normalized, + "duration 不在模型固定时长选项内,已向上调整为允许值。", + capabilityPath(modelType, "duration_options"), + capability["duration_options"], + ) + } + return true + } + if minValue, maxValue, ok := scopedRange(capability["duration_range"], resolution, modeKey); ok { + step := durationStep(capability["duration_step"], resolution, modeKey) + normalized := normalizeDurationByRange(duration, minValue, maxValue, step) + params["duration"] = normalized + syncDurationSeconds(params) + if normalized != duration { + context.recordChange( + "DurationProcessor", + "adjust", + "duration", + duration, + normalized, + "duration 超出模型时长范围或步进配置,已按能力配置归一。", + capabilityPath(modelType, "duration_range"), + map[string]any{ + "duration_range": capability["duration_range"], + "duration_step": capability["duration_step"], + }, + ) + } + return true + } + step := durationStep(capability["duration_step"], resolution, modeKey) + normalized := normalizeDurationByStep(duration, step) + params["duration"] = normalized + syncDurationSeconds(params) + if normalized != duration { + context.recordChange( + "DurationProcessor", + "adjust", + "duration", + duration, + normalized, + "duration 不符合模型时长步进,已按步进向上归一。", + capabilityPath(modelType, "duration_step"), + capability["duration_step"], + ) + } + return true +} + +type audioProcessor struct{} + +func (audioProcessor) Name() string { return "AudioProcessor" } + +func (audioProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { + return isVideoModelType(modelType) && (params["audio"] != nil || params["output_audio"] != nil) +} + +func (audioProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { + capability := capabilityForType(context.modelCapability, modelType) + if capability == nil || !boolFromAny(capability["output_audio"]) { + for _, key := range []string{"audio", "output_audio"} { + if before, ok := params[key]; ok { + delete(params, key) + context.recordChange( + "AudioProcessor", + "remove", + key, + before, + nil, + "模型能力未开启输出音频,已移除音频输出参数。", + capabilityPath(modelType, "output_audio"), + capabilityValue(context.modelCapability, modelType, "output_audio"), + ) + } + } + } + return true +} + +type imageCountProcessor struct{} + +func (imageCountProcessor) Name() string { return "ImageCountProcessor" } + +func (imageCountProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { + return modelType == "image_generate" || modelType == "image_edit" +} + +func (imageCountProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { + capability := capabilityForType(context.modelCapability, modelType) + if capability == nil || !boolFromAny(capability["output_multiple_images"]) { + return true + } + maxCount := int(math.Round(floatFromAny(capability["output_max_images_count"]))) + if maxCount <= 0 { + return true + } + count := int(math.Round(floatFromAny(params["n"]))) + if count <= 0 { + count = int(math.Round(floatFromAny(params["batch_size"]))) + } + if count <= 0 { + count = 1 + } + if count > maxCount { + before := count + count = maxCount + context.recordChange( + "ImageCountProcessor", + "adjust", + "n", + before, + count, + "请求图片数量超过模型输出上限,已按 output_max_images_count 截断。", + capabilityPath(modelType, "output_max_images_count"), + capability["output_max_images_count"], + ) + } + params["n"] = count + return true +} diff --git a/apps/api/internal/runner/param_processor_message.go b/apps/api/internal/runner/param_processor_message.go new file mode 100644 index 0000000..8c18666 --- /dev/null +++ b/apps/api/internal/runner/param_processor_message.go @@ -0,0 +1,190 @@ +package runner + +import "fmt" + +type messageContentProcessor struct{} + +func (messageContentProcessor) Name() string { return "MessageContentProcessor" } + +func (messageContentProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { + return isTextGenerationKind(context.kind) && params["messages"] != nil +} + +func (messageContentProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { + messages, changed := processMessageListContent(params["messages"], context) + if changed { + params["messages"] = messages + } + return true +} + +func processMessageListContent(value any, context *paramProcessContext) ([]any, bool) { + rawMessages, ok := value.([]any) + if !ok { + return nil, false + } + out := make([]any, 0, len(rawMessages)) + changed := false + for messageIndex, rawMessage := range rawMessages { + message, ok := rawMessage.(map[string]any) + if !ok { + out = append(out, rawMessage) + continue + } + nextMessage := cloneMap(message) + if contentParts, ok := message["content"].([]any); ok { + nextContent, contentChanged := processMessageContentParts( + contentParts, + fmt.Sprintf("messages[%d].content", messageIndex), + context, + ) + if contentChanged { + nextMessage["content"] = nextContent + changed = true + } + } + out = append(out, nextMessage) + } + return out, changed +} + +func processMessageContentParts(parts []any, basePath string, context *paramProcessContext) ([]any, bool) { + out := make([]any, 0, len(parts)) + changed := false + for partIndex, rawPart := range parts { + part, ok := rawPart.(map[string]any) + if !ok { + out = append(out, rawPart) + continue + } + if replacement, replacementChanged := messageContentPartReplacement(part, context); replacementChanged { + out = append(out, replacement) + context.recordChange( + "MessageContentProcessor", + "convert", + fmt.Sprintf("%s[%d]", basePath, partIndex), + part, + replacement, + messageContentConversionReason(part), + messageContentCapabilityPath(part), + messageContentCapabilityValue(part, context), + ) + changed = true + continue + } + out = append(out, cloneMap(part)) + } + return out, changed +} + +func messageContentPartReplacement(part map[string]any, context *paramProcessContext) (map[string]any, bool) { + switch { + case isImageContent(part): + if modelSupportsMessageModality(context, "image_analysis") { + return nil, false + } + if url := imageURLFromContentPart(part); url != "" { + return map[string]any{"type": "text", "text": "Image link: " + url}, true + } + case isVideoContent(part): + if modelSupportsMessageModality(context, "video_understanding") { + return nil, false + } + if url := videoURLFromContentPart(part); url != "" { + return map[string]any{"type": "text", "text": "video URL: " + url}, true + } + case isAudioContent(part) || stringFromAny(part["type"]) == "input_audio": + if modelSupportsMessageModality(context, "audio_understanding") { + return nil, false + } + if url := audioURLFromContentPart(part); url != "" { + return map[string]any{"type": "text", "text": "audio URL: " + url}, true + } + } + return nil, false +} + +func messageContentConversionReason(part map[string]any) string { + switch { + case isImageContent(part): + return "模型不支持图像理解,已将 image_url 转为文本链接。" + case isVideoContent(part): + return "模型不支持视频理解,已将 video_url 转为文本链接。" + default: + return "模型不支持音频理解,已将音频输入转为文本链接。" + } +} + +func messageContentCapabilityPath(part map[string]any) string { + switch { + case isImageContent(part): + return "capabilities.image_analysis" + case isVideoContent(part): + return "capabilities.video_understanding" + default: + return "capabilities.audio_understanding" + } +} + +func messageContentCapabilityValue(part map[string]any, context *paramProcessContext) any { + if context == nil { + return nil + } + switch { + case isImageContent(part): + return capabilityValue(context.modelCapability, "image_analysis", "") + case isVideoContent(part): + return capabilityValue(context.modelCapability, "video_understanding", "") + default: + return capabilityValue(context.modelCapability, "audio_understanding", "") + } +} + +func modelSupportsMessageModality(context *paramProcessContext, capabilityName string) bool { + if context == nil { + return false + } + capabilities := context.modelCapability + if capabilityForType(capabilities, capabilityName) != nil { + return true + } + if capabilityForType(capabilities, "omni") != nil { + return true + } + originalTypes := stringListFromAny(capabilities["originalTypes"]) + return containsString(originalTypes, capabilityName) || containsString(originalTypes, "omni") +} + +func imageURLFromContentPart(part map[string]any) string { + return urlFromNestedContentPart(part, "image_url", "url", "imageUrl") +} + +func videoURLFromContentPart(part map[string]any) string { + return urlFromNestedContentPart(part, "video_url", "url", "videoUrl") +} + +func audioURLFromContentPart(part map[string]any) string { + if stringFromAny(part["type"]) == "input_audio" { + if audio, ok := part["input_audio"].(map[string]any); ok { + if url := firstNonEmptyString(stringFromAny(audio["data"]), stringFromAny(audio["url"])); url != "" { + return url + } + } + } + return urlFromNestedContentPart(part, "audio_url", "url", "audioUrl") +} + +func urlFromNestedContentPart(part map[string]any, keys ...string) string { + for _, key := range keys { + value := part[key] + if url := stringFromAny(value); url != "" { + return url + } + if nested, ok := value.(map[string]any); ok { + if url := stringFromAny(nested["url"]); url != "" { + return url + } + } + } + return "" +} diff --git a/apps/api/internal/runner/param_processor_test.go b/apps/api/internal/runner/param_processor_test.go index 1804ea3..ffe873d 100644 --- a/apps/api/internal/runner/param_processor_test.go +++ b/apps/api/internal/runner/param_processor_test.go @@ -6,6 +6,50 @@ import ( "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" ) +func TestVideoModelTypeInferenceReadsContentArray(t *testing.T) { + imageToVideo := modelTypeFromKind("videos.generations", map[string]any{ + "model": "demo-video", + "content": []any{ + map[string]any{"type": "text", "text": "animate it"}, + map[string]any{"type": "image_url", "role": "first_frame", "image_url": map[string]any{"url": "https://example.com/frame.png"}}, + }, + }) + if imageToVideo != "image_to_video" { + t.Fatalf("image content should infer image_to_video, got %s", imageToVideo) + } + + omniVideo := modelTypeFromKind("videos.generations", map[string]any{ + "model": "demo-video", + "content": []any{ + map[string]any{"type": "text", "text": "edit it"}, + map[string]any{"type": "video_url", "role": "reference_video", "video_url": map[string]any{"url": "https://example.com/ref.mp4"}}, + }, + }) + if omniVideo != "omni_video" { + t.Fatalf("video content should infer omni_video, got %s", omniVideo) + } + + textToVideo := modelTypeFromKind("videos.generations", map[string]any{ + "model": "demo-video", + "content": []any{map[string]any{"type": "text", "text": "make a clip"}}, + }) + if textToVideo != "video_generate" { + t.Fatalf("text-only content should infer video_generate, got %s", textToVideo) + } +} + +func TestVideoContentTextContributesToTokenEstimate(t *testing.T) { + tokens := estimateRequestTokens(map[string]any{ + "model": "demo-video", + "content": []any{ + map[string]any{"type": "text", "text": "a cinematic product reveal"}, + }, + }) + if tokens <= 1 { + t.Fatalf("content text should contribute to token estimate, got %d", tokens) + } +} + func TestParamProcessorOmniFiltersUnsupportedVideoAndAudioContent(t *testing.T) { body := map[string]any{ "model": "可灵O1", @@ -123,6 +167,163 @@ func TestParamProcessorOmniCapabilityLogUsesActualCapabilityKey(t *testing.T) { t.Fatalf("expected log to reference capabilities.omni.input_audio, got %+v", result.Log.Changes) } +func TestParamProcessorChatConvertsUnsupportedMediaMessageContentToText(t *testing.T) { + body := map[string]any{ + "model": "text-only", + "messages": []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "describe these"}, + map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}}, + map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}}, + map[string]any{"type": "audio_url", "audio_url": map[string]any{"url": "https://example.com/audio.mp3"}}, + map[string]any{"type": "input_audio", "input_audio": map[string]any{"data": "https://example.com/input.wav"}}, + }, + }, + }, + } + candidate := store.RuntimeModelCandidate{ + ModelType: "text_generate", + Capabilities: map[string]any{ + "text_generate": map[string]any{}, + "originalTypes": []any{"text_generate"}, + }, + } + + result := preprocessRequestWithLog("chat.completions", body, candidate) + messages, _ := result.Body["messages"].([]any) + if len(messages) != 1 { + t.Fatalf("expected one message, got %+v", result.Body["messages"]) + } + message, _ := messages[0].(map[string]any) + content, _ := message["content"].([]any) + if len(content) != 5 { + t.Fatalf("expected five content parts, got %+v", message["content"]) + } + expectedText := []string{ + "describe these", + "Image link: https://example.com/image.png", + "video URL: https://example.com/video.mp4", + "audio URL: https://example.com/audio.mp3", + "audio URL: https://example.com/input.wav", + } + for index, expected := range expectedText { + part, _ := content[index].(map[string]any) + if stringFromAny(part["text"]) != expected { + t.Fatalf("content[%d] text = %q, want %q; all=%+v", index, stringFromAny(part["text"]), expected, content) + } + } + if len(result.Log.Changes) != 4 { + t.Fatalf("expected four media conversion changes, got %+v", result.Log.Changes) + } + expectedCapabilityPaths := map[string]bool{ + "capabilities.image_analysis": false, + "capabilities.video_understanding": false, + "capabilities.audio_understanding": false, + } + for _, change := range result.Log.Changes { + if _, ok := expectedCapabilityPaths[change.CapabilityPath]; ok { + expectedCapabilityPaths[change.CapabilityPath] = true + } + } + for path, found := range expectedCapabilityPaths { + if !found { + t.Fatalf("expected conversion log for %s, got %+v", path, result.Log.Changes) + } + } +} + +func TestParamProcessorChatKeepsOmniMessageContent(t *testing.T) { + body := map[string]any{ + "model": "omni", + "messages": []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}}, + map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}}, + map[string]any{"type": "audio_url", "audio_url": map[string]any{"url": "https://example.com/audio.mp3"}}, + }, + }, + }, + } + candidate := store.RuntimeModelCandidate{ + ModelType: "text_generate", + Capabilities: map[string]any{ + "text_generate": map[string]any{}, + "omni": map[string]any{}, + "originalTypes": []any{"text_generate", "omni"}, + }, + } + + result := preprocessRequestWithLog("chat.completions", body, candidate) + if result.Log.Changed { + t.Fatalf("omni model should keep message media content unchanged, got %+v", result.Log.Changes) + } + messages, _ := result.Body["messages"].([]any) + message, _ := messages[0].(map[string]any) + content, _ := message["content"].([]any) + for _, item := range content { + part, _ := item.(map[string]any) + if stringFromAny(part["type"]) == "text" { + t.Fatalf("media content should not be converted for omni model: %+v", content) + } + } +} + +func TestParamProcessorChatConvertsOnlyUnsupportedModalities(t *testing.T) { + body := map[string]any{ + "model": "vision-only", + "messages": []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "image_url", "image_url": map[string]any{"url": "https://example.com/image.png"}}, + map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/video.mp4"}}, + }, + }, + }, + } + candidate := store.RuntimeModelCandidate{ + ModelType: "text_generate", + Capabilities: map[string]any{ + "text_generate": map[string]any{}, + "image_analysis": map[string]any{}, + "originalTypes": []any{"text_generate", "image_analysis"}, + }, + } + + result := preprocessRequestWithLog("chat.completions", body, candidate) + messages, _ := result.Body["messages"].([]any) + message, _ := messages[0].(map[string]any) + content, _ := message["content"].([]any) + first, _ := content[0].(map[string]any) + second, _ := content[1].(map[string]any) + if stringFromAny(first["type"]) != "image_url" { + t.Fatalf("image content should be kept when image_analysis is supported: %+v", content) + } + if stringFromAny(second["text"]) != "video URL: https://example.com/video.mp4" { + t.Fatalf("video content should be converted, got %+v", second) + } + if len(result.Log.Changes) != 1 || result.Log.Changes[0].CapabilityPath != "capabilities.video_understanding" { + t.Fatalf("expected only video conversion to be logged, got %+v", result.Log.Changes) + } +} + +func TestSkipTaskParameterPreprocessingLogForTextModelTypes(t *testing.T) { + for _, modelType := range []string{"text_generate", "chat", "responses", "text"} { + if !skipTaskParameterPreprocessingLog(modelType) { + t.Fatalf("%s should skip task parameter preprocessing log", modelType) + } + } + for _, modelType := range []string{"image_generate", "image_edit", "video_generate", "omni_video"} { + if skipTaskParameterPreprocessingLog(modelType) { + t.Fatalf("%s should keep task parameter preprocessing log", modelType) + } + } +} + func TestParamProcessorVideoCapabilitiesNormalizeAndFilter(t *testing.T) { body := map[string]any{ "model": "Seedance", @@ -180,6 +381,222 @@ func TestParamProcessorVideoCapabilitiesNormalizeAndFilter(t *testing.T) { } } +func TestParamProcessorDowngradesReferenceImagesToFrames(t *testing.T) { + candidate := store.RuntimeModelCandidate{ + ModelType: "image_to_video", + Capabilities: map[string]any{ + "image_to_video": map[string]any{ + "input_first_frame": true, + "input_first_last_frame": true, + "input_reference_generate_single": false, + "input_reference_generate_multiple": false, + }, + }, + } + body := map[string]any{ + "model": "Seedance", + "prompt": "animate it", + "content": []any{ + map[string]any{"type": "text", "text": "animate it"}, + map[string]any{"type": "image_url", "role": "reference_image", "image_url": map[string]any{"url": "https://example.com/first.png"}}, + map[string]any{"type": "image_url", "role": "reference_image", "image_url": map[string]any{"url": "https://example.com/last.png"}}, + }, + } + + result := preprocessRequestWithLog("videos.generations", body, candidate) + if result.Err != nil { + t.Fatalf("two image references should downgrade to first/last frames: %v", result.Err) + } + content := contentItems(result.Body["content"]) + if stringFromAny(content[1]["role"]) != "first_frame" || stringFromAny(content[2]["role"]) != "last_frame" { + t.Fatalf("expected first/last frame downgrade, got %+v", content) + } +} + +func TestParamProcessorDowngradesSingleReferenceImageToFirstFrame(t *testing.T) { + candidate := store.RuntimeModelCandidate{ + ModelType: "image_to_video", + Capabilities: map[string]any{ + "image_to_video": map[string]any{ + "input_first_frame": true, + "input_first_last_frame": true, + "input_reference_generate_single": false, + "input_reference_generate_multiple": false, + }, + }, + } + body := map[string]any{ + "model": "Seedance", + "prompt": "animate it", + "content": []any{ + map[string]any{"type": "text", "text": "animate it"}, + map[string]any{"type": "image_url", "role": "reference_image", "image_url": map[string]any{"url": "https://example.com/first.png"}}, + }, + } + + result := preprocessRequestWithLog("videos.generations", body, candidate) + if result.Err != nil { + t.Fatalf("single image reference should downgrade to first frame: %v", result.Err) + } + content := contentItems(result.Body["content"]) + if stringFromAny(content[1]["role"]) != "first_frame" { + t.Fatalf("expected first frame downgrade, got %+v", content) + } +} + +func TestParamProcessorRejectsUnsafeReferenceImageDowngrade(t *testing.T) { + candidate := store.RuntimeModelCandidate{ + ModelType: "image_to_video", + Capabilities: map[string]any{ + "image_to_video": map[string]any{ + "input_first_frame": true, + "input_first_last_frame": false, + "input_reference_generate_single": false, + "input_reference_generate_multiple": false, + }, + }, + } + body := map[string]any{ + "model": "Seedance", + "prompt": "animate it", + "content": []any{ + map[string]any{"type": "text", "text": "animate it"}, + map[string]any{"type": "image_url", "role": "reference_image", "image_url": map[string]any{"url": "https://example.com/first.png"}}, + map[string]any{"type": "image_url", "role": "reference_image", "image_url": map[string]any{"url": "https://example.com/last.png"}}, + }, + } + + result := preprocessRequestWithLog("videos.generations", body, candidate) + if result.Err == nil { + t.Fatalf("two image references should be rejected when first/last frame is unsupported") + } + if len(result.Log.Changes) == 0 || result.Log.Changes[len(result.Log.Changes)-1].Action != "reject" { + t.Fatalf("expected reject preprocessing log, got %+v", result.Log.Changes) + } +} + +func TestParamProcessorRejectsVideoOrAudioReferenceDowngrade(t *testing.T) { + candidate := store.RuntimeModelCandidate{ + ModelType: "image_to_video", + Capabilities: map[string]any{ + "image_to_video": map[string]any{ + "input_first_frame": true, + "input_first_last_frame": true, + "input_reference_generate_single": false, + "input_reference_generate_multiple": false, + }, + }, + } + body := map[string]any{ + "model": "Seedance", + "prompt": "animate it", + "content": []any{ + map[string]any{"type": "text", "text": "animate it"}, + map[string]any{"type": "image_url", "role": "reference_image", "image_url": map[string]any{"url": "https://example.com/first.png"}}, + map[string]any{"type": "video_url", "role": "reference_video", "video_url": map[string]any{"url": "https://example.com/ref.mp4"}}, + }, + } + + result := preprocessRequestWithLog("videos.generations", body, candidate) + if result.Err == nil { + t.Fatalf("video reference should be rejected instead of downgraded") + } +} + +func TestParamProcessorDurationRangeRoundsFractionalSecondsUp(t *testing.T) { + body := map[string]any{ + "model": "Seedance", + "prompt": "animate it", + "duration": 5.5, + } + candidate := store.RuntimeModelCandidate{ + ModelType: "video_generate", + Capabilities: map[string]any{ + "video_generate": map[string]any{ + "duration_range": []any{3, 12}, + }, + }, + } + + result := preprocessRequestWithLog("videos.generations", body, candidate) + if result.Body["duration"] != float64(6) && result.Body["duration"] != 6 { + t.Fatalf("fractional duration should be rounded up to default 1s step, got %+v", result.Body["duration"]) + } +} + +func TestParamProcessorDurationWithoutRangeStillRoundsUp(t *testing.T) { + body := map[string]any{ + "model": "Seedance", + "prompt": "animate it", + "duration": 5.2, + } + candidate := store.RuntimeModelCandidate{ + ModelType: "video_generate", + Capabilities: map[string]any{ + "video_generate": map[string]any{}, + }, + } + + result := preprocessRequestWithLog("videos.generations", body, candidate) + if result.Body["duration"] != float64(6) && result.Body["duration"] != 6 { + t.Fatalf("duration should default to a 1s upward step without range, got %+v", result.Body["duration"]) + } +} + +func TestParamProcessorDurationRangeUsesStepCeilingAndRange(t *testing.T) { + body := map[string]any{ + "model": "Seedance", + "prompt": "animate it", + "duration": 6.1, + "duration_seconds": 6.1, + } + candidate := store.RuntimeModelCandidate{ + ModelType: "image_to_video", + Capabilities: map[string]any{ + "image_to_video": map[string]any{ + "duration_range": []any{5, 10}, + "duration_step": 2, + }, + }, + } + + result := preprocessRequestWithLog("videos.generations", body, candidate) + if result.Body["duration"] != float64(7) && result.Body["duration"] != 7 { + t.Fatalf("duration should be rounded up by configured step, got %+v", result.Body["duration"]) + } + if result.Body["duration_seconds"] != result.Body["duration"] { + t.Fatalf("duration_seconds should sync with normalized duration, got %+v", result.Body) + } + + body["duration"] = 10.1 + body["duration_seconds"] = 10.1 + result = preprocessRequestWithLog("videos.generations", body, candidate) + if result.Body["duration"] != float64(10) && result.Body["duration"] != 10 { + t.Fatalf("duration should be capped by range max, got %+v", result.Body["duration"]) + } +} + +func TestParamProcessorDurationOptionsChooseNextAllowedValue(t *testing.T) { + body := map[string]any{ + "model": "Seedance", + "prompt": "animate it", + "duration": 8.1, + } + candidate := store.RuntimeModelCandidate{ + ModelType: "image_to_video", + Capabilities: map[string]any{ + "image_to_video": map[string]any{ + "duration_options": []any{4, 8, 12}, + }, + }, + } + + result := preprocessRequestWithLog("videos.generations", body, candidate) + if result.Body["duration"] != float64(12) && result.Body["duration"] != 12 { + t.Fatalf("duration should use next allowed option, got %+v", result.Body["duration"]) + } +} + func TestParamProcessorVideoGenerateLogsFirstFrameRemoval(t *testing.T) { body := map[string]any{ "model": "Seedance T2V", diff --git a/apps/api/internal/runner/param_processor_utils.go b/apps/api/internal/runner/param_processor_utils.go new file mode 100644 index 0000000..2d6ea87 --- /dev/null +++ b/apps/api/internal/runner/param_processor_utils.go @@ -0,0 +1,511 @@ +package runner + +import ( + "math" + "sort" + "strconv" + "strings" +) + +func validateAndAdjustAspectRatio(aspectRatio string, capability map[string]any, allowed []string) (string, bool) { + if !isMediaModelTypeWithAspectRatio(capability) { + return "", false + } + if ratioRange, ok := numberPair(capability["aspect_ratio_range"]); ok { + ratio, valid := aspectRatioNumber(aspectRatio) + if !valid || ratio < ratioRange[0] || ratio > ratioRange[1] { + return adjustAspectRatioToRange(aspectRatio, ratioRange[0], ratioRange[1], allowed), true + } + } + if allowed == nil { + return aspectRatio, true + } + if len(allowed) == 0 { + return "", false + } + if (aspectRatio == "adaptive" || aspectRatio == "keep_ratio") && !containsString(allowed, aspectRatio) { + return "", false + } + if containsString(allowed, aspectRatio) { + return aspectRatio, true + } + return allowed[0], true +} + +func isMediaModelTypeWithAspectRatio(capability map[string]any) bool { + return capability != nil +} + +func aspectRatioAllowed(value any, resolution string) []string { + switch typed := value.(type) { + case []any: + return stringListFromAny(typed) + case []string: + return typed + case map[string]any: + if resolution != "" { + if values := stringListFromAny(typed[resolution]); len(values) > 0 { + return values + } + } + return nil + default: + return nil + } +} + +func scopedNumberList(value any, scopes ...string) []float64 { + switch typed := value.(type) { + case []any: + out := make([]float64, 0, len(typed)) + for _, item := range typed { + if number := floatFromAny(item); number > 0 { + out = append(out, number) + } + } + return out + case []float64: + return typed + case []int: + out := make([]float64, 0, len(typed)) + for _, item := range typed { + out = append(out, float64(item)) + } + return out + case map[string]any: + for _, scope := range scopes { + if scope == "" { + continue + } + if values := scopedNumberList(typed[scope]); len(values) > 0 { + return values + } + } + for _, item := range typed { + if values := scopedNumberList(item); len(values) > 0 { + return values + } + } + } + return nil +} + +func scopedRange(value any, scopes ...string) (float64, float64, bool) { + if pair, ok := numberPair(value); ok { + return pair[0], pair[1], true + } + if typed, ok := value.(map[string]any); ok { + for _, scope := range scopes { + if scope == "" { + continue + } + if minValue, maxValue, ok := scopedRange(typed[scope]); ok { + return minValue, maxValue, true + } + } + for _, item := range typed { + if minValue, maxValue, ok := scopedRange(item); ok { + return minValue, maxValue, true + } + } + } + return 0, 0, false +} + +func durationStep(value any, scopes ...string) float64 { + if step := floatFromAny(value); step > 0 { + return step + } + if typed, ok := value.(map[string]any); ok { + for _, scope := range scopes { + if scope == "" { + continue + } + if step := durationStep(typed[scope]); step > 0 { + return step + } + } + for _, item := range typed { + if step := durationStep(item); step > 0 { + return step + } + } + } + return 0 +} + +func normalizeDurationByRange(target float64, minValue float64, maxValue float64, step float64) float64 { + if minValue > maxValue { + minValue, maxValue = maxValue, minValue + } + if step <= 0 { + step = 1 + } + clamped := math.Min(math.Max(target, minValue), maxValue) + snapped := math.Ceil(((clamped-minValue)/step)-1e-9)*step + minValue + snapped = math.Min(math.Max(snapped, minValue), maxValue) + return math.Round(snapped*1_000_000) / 1_000_000 +} + +func normalizeDurationByStep(target float64, step float64) float64 { + if step <= 0 { + step = 1 + } + snapped := math.Ceil((target/step)-1e-9) * step + return math.Round(snapped*1_000_000) / 1_000_000 +} + +func nextAllowedNumber(target float64, values []float64) float64 { + if len(values) == 0 { + return target + } + sorted := append([]float64(nil), values...) + sort.Float64s(sorted) + for _, value := range sorted { + if value >= target || math.Abs(value-target) < 1e-9 { + return value + } + } + return sorted[len(sorted)-1] +} + +func contentItems(value any) []map[string]any { + switch typed := value.(type) { + case []any: + out := make([]map[string]any, 0, len(typed)) + for _, item := range typed { + if object, ok := item.(map[string]any); ok { + out = append(out, cloneMap(object)) + } + } + return out + case []map[string]any: + out := make([]map[string]any, 0, len(typed)) + for _, item := range typed { + out = append(out, cloneMap(item)) + } + return out + default: + return nil + } +} + +func mapsToAnySlice(values []map[string]any) []any { + out := make([]any, 0, len(values)) + for _, value := range values { + out = append(out, value) + } + return out +} + +func isImageContent(item map[string]any) bool { + return stringFromAny(item["type"]) == "image_url" || item["image_url"] != nil +} + +func isVideoContent(item map[string]any) bool { + return stringFromAny(item["type"]) == "video_url" || item["video_url"] != nil +} + +func isAudioContent(item map[string]any) bool { + return stringFromAny(item["type"]) == "audio_url" || item["audio_url"] != nil +} + +func capabilityForType(capabilities map[string]any, modelType string) map[string]any { + if capabilities == nil { + return nil + } + if typed, ok := capabilities[modelType].(map[string]any); ok { + return typed + } + return nil +} + +func capabilityPath(modelType string, key string) string { + modelType = strings.TrimSpace(modelType) + if modelType == "" { + return "" + } + if strings.TrimSpace(key) == "" { + return "capabilities." + modelType + } + return "capabilities." + modelType + "." + key +} + +func capabilityValue(capabilities map[string]any, modelType string, key string) any { + capability := capabilityForType(capabilities, modelType) + if capability == nil { + return nil + } + if strings.TrimSpace(key) == "" { + return cloneMap(capability) + } + return cloneAny(capability[key]) +} + +func capabilityEvidence(capabilities map[string]any, modelType string, key string) (string, any) { + return capabilityPath(modelType, key), capabilityValue(capabilities, modelType, key) +} + +func audioInputCapabilityEvidence(context *paramProcessContext, modelType string) (string, any) { + if isOmniVideoLike(context) { + path, value := omniCapabilityEvidence(context, "input_audio") + return path, mergeMetrics(map[string]any{"input_audio": value}, omniCapabilityBundle(context, "max_audios")) + } + return capabilityEvidence(context.modelCapability, modelType, "input_audio") +} + +func omniCapabilityType(context *paramProcessContext) string { + if context != nil && capabilityForType(context.modelCapability, "omni_video") != nil { + return "omni_video" + } + if context != nil && capabilityForType(context.modelCapability, "omni") != nil { + return "omni" + } + return "omni_video" +} + +func omniCapabilityEvidence(context *paramProcessContext, key string) (string, any) { + modelType := omniCapabilityType(context) + var capabilities map[string]any + if context != nil { + capabilities = context.modelCapability + } + return capabilityPath(modelType, key), capabilityValue(capabilities, modelType, key) +} + +func omniCapabilityBundle(context *paramProcessContext, keys ...string) map[string]any { + modelType := omniCapabilityType(context) + var capabilities map[string]any + if context != nil { + capabilities = context.modelCapability + } + out := map[string]any{} + for _, key := range keys { + out[key] = capabilityValue(capabilities, modelType, key) + } + return out +} + +func numericField(values map[string]any, key string) (float64, bool) { + if values == nil { + return 0, false + } + if _, ok := values[key]; !ok { + return 0, false + } + return floatFromAny(values[key]), true +} + +func boolFromAny(value any) bool { + typed, _ := value.(bool) + return typed +} + +func firstNonEmptyStringValue(values map[string]any, keys ...string) string { + for _, key := range keys { + if value := stringFromAny(values[key]); value != "" { + return value + } + } + return "" +} + +func firstNonEmptyStringListFromAny(values ...any) []string { + for _, value := range values { + items := stringListFromAny(value) + if len(items) > 0 { + return items + } + } + return nil +} + +func stringListFromAny(value any) []string { + switch typed := value.(type) { + case []string: + out := make([]string, 0, len(typed)) + for _, item := range typed { + if text := strings.TrimSpace(item); text != "" { + out = append(out, text) + } + } + return out + case []any: + out := make([]string, 0, len(typed)) + for _, item := range typed { + if text := stringFromAny(item); text != "" { + out = append(out, text) + } + } + return out + case string: + if strings.TrimSpace(typed) == "" { + return nil + } + return []string{strings.TrimSpace(typed)} + default: + return nil + } +} + +func containsString(values []string, target string) bool { + for _, value := range values { + if value == target { + return true + } + } + return false +} + +func appendUniqueString(values *[]string, value string) { + value = strings.TrimSpace(value) + if value == "" { + return + } + for _, existing := range *values { + if existing == value { + return + } + } + *values = append(*values, value) +} + +func numberPair(value any) ([2]float64, bool) { + switch typed := value.(type) { + case []any: + if len(typed) < 2 { + return [2]float64{}, false + } + return [2]float64{floatFromAny(typed[0]), floatFromAny(typed[1])}, true + case []float64: + if len(typed) < 2 { + return [2]float64{}, false + } + return [2]float64{typed[0], typed[1]}, true + case []int: + if len(typed) < 2 { + return [2]float64{}, false + } + return [2]float64{float64(typed[0]), float64(typed[1])}, true + default: + return [2]float64{}, false + } +} + +func validAspectRatio(value string) bool { + if value == "adaptive" || value == "keep_ratio" { + return true + } + _, ok := aspectRatioNumber(value) + return ok +} + +func aspectRatioNumber(value string) (float64, bool) { + parts := strings.Split(value, ":") + if len(parts) != 2 { + return 0, false + } + width := parsePositiveFloat(parts[0]) + height := parsePositiveFloat(parts[1]) + if width <= 0 || height <= 0 { + return 0, false + } + return width / height, true +} + +func adjustAspectRatioToRange(value string, minValue float64, maxValue float64, allowed []string) string { + current, ok := aspectRatioNumber(value) + if !ok { + if len(allowed) > 0 { + return allowed[0] + } + return "1:1" + } + if len(allowed) > 0 { + closest := "" + minDiff := math.Inf(1) + for _, candidate := range allowed { + ratio, ok := aspectRatioNumber(candidate) + if !ok || ratio < minValue || ratio > maxValue { + continue + } + diff := math.Abs(ratio - current) + if diff < minDiff { + minDiff = diff + closest = candidate + } + } + if closest != "" { + return closest + } + } + if current < minValue { + return ratioString(minValue) + } + return ratioString(maxValue) +} + +func ratioString(value float64) string { + if value <= 0 { + return "1:1" + } + return strings.TrimRight(strings.TrimRight(strconv.FormatFloat(value, 'f', 6, 64), "0"), ".") + ":1" +} + +func parsePositiveFloat(value string) float64 { + for _, r := range strings.TrimSpace(value) { + if r < '0' || r > '9' { + if r != '.' { + return 0 + } + } + } + out, _ := strconv.ParseFloat(strings.TrimSpace(value), 64) + return out +} + +func isEmptyParamString(value string) bool { + normalized := strings.ToLower(strings.TrimSpace(value)) + return normalized == "null" || normalized == "undefined" +} + +func isImageResolution(modelType string, value string) bool { + return (modelType == "image_generate" || modelType == "image_edit") && containsString([]string{"1K", "2K", "4K", "8K"}, value) +} + +func isVideoResolution(modelType string, value string) bool { + return isVideoModelType(modelType) && containsString([]string{"480p", "720p", "1080p", "1440p", "2160p"}, value) +} + +func isVideoModelType(modelType string) bool { + return modelType == "video_generate" || modelType == "text_to_video" || modelType == "image_to_video" || modelType == "video_edit" || modelType == "video_reference" || modelType == "video_first_last_frame" || modelType == "omni_video" || modelType == "omni" +} + +func cloneMap(values map[string]any) map[string]any { + out := map[string]any{} + for key, value := range values { + out[key] = cloneAny(value) + } + return out +} + +func cloneAny(value any) any { + switch typed := value.(type) { + case map[string]any: + return cloneMap(typed) + case []any: + out := make([]any, 0, len(typed)) + for _, item := range typed { + out = append(out, cloneAny(item)) + } + return out + case []map[string]any: + out := make([]any, 0, len(typed)) + for _, item := range typed { + out = append(out, cloneMap(item)) + } + return out + default: + return value + } +} diff --git a/apps/api/internal/runner/param_processor_video_content.go b/apps/api/internal/runner/param_processor_video_content.go new file mode 100644 index 0000000..438ebb3 --- /dev/null +++ b/apps/api/internal/runner/param_processor_video_content.go @@ -0,0 +1,663 @@ +package runner + +import ( + "fmt" + "math" + "strings" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +type contentFilterProcessor struct{} + +func (contentFilterProcessor) Name() string { return "ContentFilterProcessor" } + +func (contentFilterProcessor) ShouldProcess(params map[string]any, modelType string, context *paramProcessContext) bool { + _, ok := params["content"] + return ok +} + +func (contentFilterProcessor) Process(params map[string]any, modelType string, context *paramProcessContext) bool { + content := contentItems(params["content"]) + if len(content) == 0 { + return true + } + + if isOmniVideoLike(context) { + filtered := filterUnsupportedOmniVideoContent(content, context) + params["content"] = mapsToAnySlice(filtered) + syncVideoConvenienceFields(params, filtered, context) + return true + } + + if err := downgradeReferenceImageIfNeeded(params, content, modelType, context); err != nil { + return false + } + if modelType == "video_generate" || modelType == "text_to_video" { + next := make([]map[string]any, 0, len(content)) + for index, item := range content { + if isImageContent(item) { + reason, path, value := imageContentRemovalEvidence(item, modelType, context) + context.recordChange( + "ContentFilterProcessor", + "remove", + fmt.Sprintf("content[%d]", index), + item, + nil, + reason, + path, + value, + ) + continue + } + next = append(next, item) + } + content = next + } + if modelType == "image_to_video" || modelType == "omni_video" || modelType == "omni" { + if !supportsFirstAndLastFrame(context.modelCapability, modelType) { + next := make([]map[string]any, 0, len(content)) + for index, item := range content { + if stringFromAny(item["role"]) == "last_frame" { + context.recordChange( + "ContentFilterProcessor", + "remove", + fmt.Sprintf("content[%d]", index), + item, + nil, + "模型不支持首尾帧输入,已移除 last_frame。", + capabilityPath(modelType, "input_first_last_frame"), + map[string]any{ + "input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"), + "max_images_for_last_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_last_frame"), + }, + ) + continue + } + next = append(next, item) + } + content = next + deleteFieldsWithLog(params, context, "ContentFilterProcessor", []string{"last_frame", "lastFrame"}, "模型不支持首尾帧输入,已移除快捷字段。", capabilityPath(modelType, "input_first_last_frame"), map[string]any{ + "input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"), + "max_images_for_last_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_last_frame"), + }) + } + } + params["content"] = mapsToAnySlice(content) + return true +} + +func imageContentRemovalEvidence(item map[string]any, modelType string, context *paramProcessContext) (string, string, any) { + role := stringFromAny(item["role"]) + switch role { + case "first_frame": + return "模型能力未开启首帧输入,已移除 first_frame。", capabilityPath(modelType, "input_first_frame"), map[string]any{ + "input_first_frame": capabilityValue(context.modelCapability, modelType, "input_first_frame"), + "input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"), + } + case "last_frame": + return "模型能力未开启尾帧或首尾帧输入,已移除 last_frame。", capabilityPath(modelType, "input_first_last_frame"), map[string]any{ + "input_last_frame": capabilityValue(context.modelCapability, modelType, "input_last_frame"), + "input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"), + "max_images_for_last_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_last_frame"), + "max_images_for_first_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_first_frame"), + "max_images_for_middle_frame": capabilityValue(context.modelCapability, modelType, "max_images_for_middle_frame"), + } + case "reference_image": + return "模型能力未开启参考图输入,已移除 reference_image。", capabilityPath(modelType, "input_reference_generate_single"), map[string]any{ + "input_reference_generate_single": capabilityValue(context.modelCapability, modelType, "input_reference_generate_single"), + "input_reference_generate_multiple": capabilityValue(context.modelCapability, modelType, "input_reference_generate_multiple"), + "max_images": capabilityValue(context.modelCapability, modelType, "max_images"), + } + default: + return "当前模型能力未开启图像输入,已移除 image_url。", capabilityPath(modelType, "input_first_frame"), map[string]any{ + "input_first_frame": capabilityValue(context.modelCapability, modelType, "input_first_frame"), + "input_first_last_frame": capabilityValue(context.modelCapability, modelType, "input_first_last_frame"), + "input_reference_generate_single": capabilityValue(context.modelCapability, modelType, "input_reference_generate_single"), + "input_reference_generate_multiple": capabilityValue(context.modelCapability, modelType, "input_reference_generate_multiple"), + } + } +} + +func ensureVideoContent(params map[string]any, context *paramProcessContext) { + if len(contentItems(params["content"])) > 0 { + return + } + content := make([]map[string]any, 0) + if prompt := firstNonEmptyString(stringFromAny(params["prompt"]), stringFromAny(params["input"])); prompt != "" { + content = append(content, map[string]any{"type": "text", "text": prompt}) + } + appendURL := func(kind string, role string, url string) { + url = strings.TrimSpace(url) + if url == "" { + return + } + item := map[string]any{"type": kind, "role": role} + switch kind { + case "image_url": + item["image_url"] = map[string]any{"url": url} + case "video_url": + item["video_url"] = map[string]any{"url": url} + case "audio_url": + item["audio_url"] = map[string]any{"url": url} + } + content = append(content, item) + } + + firstFrame := firstNonEmptyStringValue(params, "first_frame", "firstFrame") + appendURL("image_url", "first_frame", firstFrame) + appendURL("image_url", "last_frame", firstNonEmptyStringValue(params, "last_frame", "lastFrame")) + imageURLs := firstNonEmptyStringListFromAny(params["image"], params["images"], params["image_url"], params["imageUrl"], params["image_urls"], params["imageUrls"]) + if firstFrame == "" && len(imageURLs) > 0 { + appendURL("image_url", "first_frame", imageURLs[0]) + imageURLs = imageURLs[1:] + } + for _, url := range imageURLs { + appendURL("image_url", "reference_image", url) + } + for _, url := range firstNonEmptyStringListFromAny(params["reference_image"], params["referenceImage"]) { + appendURL("image_url", "reference_image", url) + } + for _, url := range firstNonEmptyStringListFromAny(params["video"], params["video_url"], params["videoUrl"], params["reference_video"], params["referenceVideo"]) { + appendURL("video_url", "reference_video", url) + } + for _, url := range firstNonEmptyStringListFromAny(params["audio_url"], params["audioUrl"], params["reference_audio"], params["referenceAudio"]) { + appendURL("audio_url", "reference_audio", url) + } + if len(content) > 0 { + params["content"] = mapsToAnySlice(content) + context.recordChange( + "ContentBuildProcessor", + "set", + "content", + nil, + params["content"], + "将 prompt/first_frame/reference_* 等快捷字段转换为 content 数组,后续处理器可按模型能力逐项过滤。", + "", + nil, + ) + } +} + +func effectiveModelCapability(candidate store.RuntimeModelCandidate) map[string]any { + base := cloneMap(candidate.Capabilities) + for key, value := range candidate.CapabilityOverride { + if baseChild, ok := base[key].(map[string]any); ok { + if overrideChild, ok := value.(map[string]any); ok { + base[key] = mergeMap(baseChild, overrideChild) + continue + } + } + base[key] = cloneAny(value) + } + return base +} + +func filterUnsupportedOmniVideoContent(content []map[string]any, context *paramProcessContext) []map[string]any { + capability := omniVideoCapability(context) + maxVideos := math.Inf(1) + if capability != nil { + if value, ok := numericField(capability, "max_videos"); ok { + maxVideos = value + } + } + maxAudios := 0.0 + if capability != nil { + if value, ok := numericField(capability, "max_audios"); ok { + maxAudios = value + } else if supportsOmniAudioReference(context) { + maxAudios = math.Inf(1) + } + } + + videoCount := 0.0 + audioCount := 0.0 + out := make([]map[string]any, 0, len(content)) + for index, item := range content { + if isVideoContent(item) { + if !supportsOmniVideoReference(item, capability) { + path, value := omniCapabilityEvidence(context, "supported_modes") + context.recordChange( + "ContentFilterProcessor", + "remove", + fmt.Sprintf("content[%d]", index), + item, + nil, + "视频参考类型不在 omni_video.supported_modes 允许范围内。", + path, + value, + ) + continue + } + if videoCount >= maxVideos { + path, value := omniCapabilityEvidence(context, "max_videos") + context.recordChange( + "ContentFilterProcessor", + "remove", + fmt.Sprintf("content[%d]", index), + item, + nil, + "视频参考数量超过 omni_video.max_videos 限制。", + path, + value, + ) + continue + } + videoCount++ + out = append(out, item) + continue + } + if isAudioContent(item) { + if !supportsOmniAudioReference(context) { + path, value := omniCapabilityEvidence(context, "input_audio") + context.recordChange( + "ContentFilterProcessor", + "remove", + fmt.Sprintf("content[%d]", index), + item, + nil, + "模型能力不支持音频参考,已移除 audio_url。", + path, + mergeMetrics(map[string]any{"input_audio": value}, omniCapabilityBundle(context, "max_audios")), + ) + continue + } + if audioCount >= maxAudios { + path, value := omniCapabilityEvidence(context, "max_audios") + context.recordChange( + "ContentFilterProcessor", + "remove", + fmt.Sprintf("content[%d]", index), + item, + nil, + "音频参考数量超过 omni_video.max_audios 限制。", + path, + value, + ) + continue + } + audioCount++ + out = append(out, item) + continue + } + out = append(out, item) + } + return out +} + +func isOmniVideoLike(context *paramProcessContext) bool { + modelType := strings.TrimSpace(context.candidate.ModelType) + return modelType == "omni_video" || + modelType == "omni" || + context.modelCapability["omni_video"] != nil || + context.modelCapability["omni"] != nil +} + +func omniVideoCapability(context *paramProcessContext) map[string]any { + if capability := capabilityForType(context.modelCapability, "omni_video"); capability != nil { + return capability + } + return capabilityForType(context.modelCapability, "omni") +} + +func supportsOmniAudioReference(context *paramProcessContext) bool { + capability := omniVideoCapability(context) + return capability != nil && (boolFromAny(capability["input_audio"]) || floatFromAny(capability["max_audios"]) > 0) +} + +func supportsOmniVideoReference(item map[string]any, capability map[string]any) bool { + if capability == nil { + return true + } + if value, ok := numericField(capability, "max_videos"); ok && value == 0 { + return false + } + supportedModes := stringListFromAny(capability["supported_modes"]) + supportsReference := containsString(supportedModes, "video_reference") + supportsEdit := containsString(supportedModes, "video_edit") + video, _ := item["video_url"].(map[string]any) + referType := stringFromAny(video["refer_type"]) + isEditVideo := stringFromAny(item["role"]) == "video_base" || referType == "base" + isReferenceVideo := stringFromAny(item["role"]) == "video_feature" || + stringFromAny(item["role"]) == "reference_video" || + referType == "feature" + if isEditVideo { + return supportsEdit + } + if isReferenceVideo { + return supportsReference + } + return supportsReference || supportsEdit +} + +func downgradeReferenceImageIfNeeded(params map[string]any, content []map[string]any, modelType string, context *paramProcessContext) error { + if !isVideoModelType(modelType) { + return nil + } + if supportsReferenceImage(context.modelCapability, modelType) { + return nil + } + + imageIndexes := make([]int, 0) + referenceIndexes := make([]int, 0) + hasVideoOrAudioReference := false + for index, item := range content { + if isVideoContent(item) || isAudioContent(item) { + hasVideoOrAudioReference = true + continue + } + if !isImageContent(item) { + continue + } + imageIndexes = append(imageIndexes, index) + role := stringFromAny(item["role"]) + if role == "" || role == "reference_image" { + referenceIndexes = append(referenceIndexes, index) + } + } + if len(referenceIndexes) == 0 { + return nil + } + + evidence := referenceImageDowngradeCapabilityEvidence(context.modelCapability, modelType) + if hasVideoOrAudioReference { + context.reject( + "ContentFilterProcessor", + "content", + content, + "当前模型不支持多模态参考,不能将视频或音频参考降级为首尾帧,请移除视频/音频参考或选择支持多模态参考的模型。", + evidence.path, + evidence.value, + ) + return context.err + } + if len(imageIndexes) > 2 { + context.reject( + "ContentFilterProcessor", + "content", + content, + "当前模型不支持多参考图输入,最多只允许 2 张图片降级为首尾帧。", + evidence.path, + evidence.value, + ) + return context.err + } + if len(imageIndexes) == 2 && !supportsFirstAndLastFrame(context.modelCapability, modelType) { + context.reject( + "ContentFilterProcessor", + "content", + content, + "当前模型不支持首尾帧输入,不能将 2 张参考图降级为首尾帧。", + evidence.path, + evidence.value, + ) + return context.err + } + if len(imageIndexes) == 1 && !supportsFirstFrame(context.modelCapability, modelType) { + context.reject( + "ContentFilterProcessor", + "content", + content, + "当前模型不支持首帧输入,不能将参考图降级为首帧。", + evidence.path, + evidence.value, + ) + return context.err + } + + if len(imageIndexes) == 1 { + adjustImageContentRole(content, imageIndexes[0], "first_frame", context, modelType, "模型不支持 reference_image,且只有 1 张图片,已降级为 first_frame。") + appendParamWarning(params, "reference_image is unsupported by the selected model and was downgraded to first_frame") + return nil + } + + firstIndex, lastIndex := firstLastFrameIndexes(content, imageIndexes) + adjustImageContentRole(content, firstIndex, "first_frame", context, modelType, "模型不支持 reference_image,2 张图片已降级为首尾帧的 first_frame。") + adjustImageContentRole(content, lastIndex, "last_frame", context, modelType, "模型不支持 reference_image,2 张图片已降级为首尾帧的 last_frame。") + appendParamWarning(params, "reference_image is unsupported by the selected model and was downgraded to first/last frame") + return nil +} + +type capabilityEvidenceValue struct { + path string + value any +} + +func referenceImageDowngradeCapabilityEvidence(modelCapability map[string]any, modelType string) capabilityEvidenceValue { + actualType, capability := firstVideoInputCapability(modelCapability, modelType) + if actualType == "" { + actualType = modelType + } + value := map[string]any{} + if capability != nil { + for _, key := range []string{ + "input_reference_generate_single", + "input_reference_generate_multiple", + "max_images", + "input_first_frame", + "input_first_last_frame", + "max_images_for_last_frame", + } { + value[key] = cloneAny(capability[key]) + } + } + return capabilityEvidenceValue{path: capabilityPath(actualType, ""), value: value} +} + +func adjustImageContentRole(content []map[string]any, index int, role string, context *paramProcessContext, modelType string, reason string) { + if index < 0 || index >= len(content) { + return + } + item := content[index] + if stringFromAny(item["role"]) == role { + return + } + before := cloneMap(item) + item["role"] = role + context.recordChange( + "ContentFilterProcessor", + "adjust", + fmt.Sprintf("content[%d].role", index), + before, + item, + reason, + capabilityPath(modelType, "input_reference_generate_single"), + referenceImageDowngradeCapabilityEvidence(context.modelCapability, modelType).value, + ) +} + +func firstLastFrameIndexes(content []map[string]any, imageIndexes []int) (int, int) { + firstIndex := -1 + lastIndex := -1 + for _, index := range imageIndexes { + switch stringFromAny(content[index]["role"]) { + case "first_frame": + if firstIndex == -1 { + firstIndex = index + } + case "last_frame": + if lastIndex == -1 { + lastIndex = index + } + } + } + if firstIndex == -1 && lastIndex == -1 { + return imageIndexes[0], imageIndexes[1] + } + if firstIndex == -1 { + for _, index := range imageIndexes { + if index != lastIndex { + firstIndex = index + break + } + } + } + if lastIndex == -1 { + for _, index := range imageIndexes { + if index != firstIndex { + lastIndex = index + break + } + } + } + if firstIndex == lastIndex { + return imageIndexes[0], imageIndexes[1] + } + return firstIndex, lastIndex +} + +type videoInputCapabilityValue struct { + modelType string + capability map[string]any +} + +func firstVideoInputCapability(modelCapability map[string]any, modelType string) (string, map[string]any) { + for _, candidate := range videoInputCapabilityCandidates(modelCapability, modelType) { + return candidate.modelType, candidate.capability + } + return "", nil +} + +func videoInputCapabilityCandidates(modelCapability map[string]any, modelType string) []videoInputCapabilityValue { + keys := []string{modelType, "image_to_video", "video_first_last_frame"} + if modelType == "omni_video" || modelType == "omni" { + keys = append(keys, "omni_video", "omni") + } + seen := map[string]bool{} + out := make([]videoInputCapabilityValue, 0, len(keys)) + for _, key := range keys { + key = strings.TrimSpace(key) + if key == "" || seen[key] { + continue + } + seen[key] = true + if capability := capabilityForType(modelCapability, key); capability != nil { + out = append(out, videoInputCapabilityValue{modelType: key, capability: capability}) + } + } + return out +} + +func supportsReferenceImage(modelCapability map[string]any, modelType string) bool { + candidates := videoInputCapabilityCandidates(modelCapability, modelType) + if len(candidates) == 0 { + return true + } + for _, candidate := range candidates { + capability := candidate.capability + _, hasSingle := capability["input_reference_generate_single"] + _, hasMultiple := capability["input_reference_generate_multiple"] + if hasSingle || hasMultiple { + if boolFromAny(capability["input_reference_generate_single"]) || boolFromAny(capability["input_reference_generate_multiple"]) { + return true + } + continue + } + if value, ok := numericField(capability, "max_images"); ok { + if value > 1 { + return true + } + continue + } + } + return false +} + +func supportsFirstFrame(modelCapability map[string]any, modelType string) bool { + for _, candidate := range videoInputCapabilityCandidates(modelCapability, modelType) { + capability := candidate.capability + if boolFromAny(capability["input_first_frame"]) || + boolFromAny(capability["input_first_last_frame"]) || + floatFromAny(capability["max_images_for_first_frame"]) > 0 || + floatFromAny(capability["max_images_for_last_frame"]) > 0 { + return true + } + } + return false +} + +func supportsFirstAndLastFrame(modelCapability map[string]any, modelType string) bool { + for _, candidate := range videoInputCapabilityCandidates(modelCapability, modelType) { + capability := candidate.capability + if boolFromAny(capability["input_first_last_frame"]) || floatFromAny(capability["max_images_for_last_frame"]) > 0 { + return true + } + } + return false +} + +func videoModeKey(params map[string]any) string { + content := contentItems(params["content"]) + hasFirstFrame := false + hasLastFrame := false + for _, item := range content { + switch stringFromAny(item["role"]) { + case "first_frame": + hasFirstFrame = true + case "last_frame": + hasLastFrame = true + } + } + switch { + case hasFirstFrame && hasLastFrame: + return "input_first_last_frame" + case hasFirstFrame: + return "input_first_frame" + case hasLastFrame: + return "input_last_frame" + default: + return "" + } +} + +func syncDurationSeconds(params map[string]any) { + if params["duration_seconds"] != nil { + params["duration_seconds"] = params["duration"] + } +} + +func syncVideoConvenienceFields(params map[string]any, content []map[string]any, context *paramProcessContext) { + hasVideo := false + hasAudio := false + for _, item := range content { + hasVideo = hasVideo || isVideoContent(item) + hasAudio = hasAudio || isAudioContent(item) + } + if !hasVideo { + path, value := omniCapabilityEvidence(context, "supported_modes") + deleteFieldsWithLog(params, context, "ContentFilterProcessor", []string{"video", "video_url", "videoUrl", "reference_video", "referenceVideo"}, "对应视频 content 已被模型能力过滤,移除视频参考快捷字段。", path, value) + } + if !hasAudio { + path, value := omniCapabilityEvidence(context, "input_audio") + deleteFieldsWithLog(params, context, "ContentFilterProcessor", []string{"audio_url", "audioUrl", "reference_audio", "referenceAudio"}, "对应音频 content 已被模型能力过滤,移除音频参考快捷字段。", path, mergeMetrics(map[string]any{"input_audio": value}, omniCapabilityBundle(context, "max_audios"))) + } +} + +func deleteFieldsWithLog(params map[string]any, context *paramProcessContext, processor string, keys []string, reason string, capabilityPath string, capabilityValue any) { + for _, key := range keys { + if before, ok := params[key]; ok { + delete(params, key) + context.recordChange(processor, "remove", key, before, nil, reason, capabilityPath, capabilityValue) + } + } +} + +func appendParamWarning(params map[string]any, warning string) { + warnings, _ := params["_param_warnings"].([]any) + for _, item := range warnings { + if stringFromAny(item) == warning { + return + } + } + params["_param_warnings"] = append(warnings, warning) +} + +func filterContent(content []map[string]any, keep func(map[string]any) bool) []map[string]any { + out := make([]map[string]any, 0, len(content)) + for _, item := range content { + if keep(item) { + out = append(out, item) + } + } + return out +} diff --git a/apps/api/internal/runner/pricing.go b/apps/api/internal/runner/pricing.go index 0eaf684..5c8ac6f 100644 --- a/apps/api/internal/runner/pricing.go +++ b/apps/api/internal/runner/pricing.go @@ -11,8 +11,10 @@ import ( ) type EstimateResult struct { - Items []any `json:"items"` - Resolver string `json:"resolver"` + Items []any `json:"items"` + Resolver string `json:"resolver"` + TotalAmount float64 `json:"totalAmount"` + Currency string `json:"currency"` } func (s *Service) Estimate(ctx context.Context, kind string, model string, body map[string]any, user *auth.User) (EstimateResult, error) { @@ -23,9 +25,12 @@ func (s *Service) Estimate(ctx context.Context, kind string, model string, body } candidate := candidates[0] body = preprocessRequest(kind, body, candidate) + items := s.estimatedBillings(ctx, user, kind, body, candidate) return EstimateResult{ - Items: s.estimatedBillings(ctx, user, kind, body, candidate), - Resolver: "effective-pricing-v1", + Items: items, + Resolver: "effective-pricing-v1", + TotalAmount: totalBillingAmount(items), + Currency: billingCurrency(items), }, nil } @@ -60,10 +65,7 @@ func (s *Service) billings(ctx context.Context, user *auth.User, kind string, bo billingLine(candidate, "text_output", "1k_tokens", outputTokens, outputAmount, discount, simulated), } } - count := int(floatFromAny(body["n"])) - if count <= 0 { - count = 1 - } + count := requestOutputCount(body) resource := "image" unit := "image" baseKey := "imageBase" @@ -73,8 +75,24 @@ func (s *Service) billings(ctx context.Context, user *auth.User, kind string, bo } if kind == "videos.generations" { resource = "video" - unit = "video" + unit = "5s_video" baseKey = "videoBase" + duration := requestDurationSeconds(body) + durationUnits := math.Max(1, math.Ceil(duration/5)) + amount := float64(count) * + durationUnits * + resourcePrice(config, resource, baseKey, "basePrice") * + resourceWeight(config, resource, "resolutionWeights", firstNonEmptyString(stringFromMap(body, "resolution"), stringFromMap(body, "size"))) * + resourceWeight(config, resource, "audioWeights", boolWeightKey(boolishValue(body["audio"]))) * + resourceWeight(config, resource, "referenceVideoWeights", boolWeightKey(requestHasReferenceVideo(body))) * + resourceWeight(config, resource, "voiceSpecifiedWeights", boolWeightKey(requestHasVoiceID(body))) * + discount + return []any{billingLineWithDetails(candidate, resource, unit, count*int(durationUnits), roundPrice(amount), discount, simulated, map[string]any{ + "count": count, + "durationSeconds": duration, + "durationUnit": "5s", + "durationUnitCount": durationUnits, + })} } amount := float64(count) * resourcePrice(config, resource, baseKey, "basePrice") * resourceWeight(config, resource, "qualityWeights", stringFromMap(body, "quality")) * resourceWeight(config, resource, "sizeWeights", stringFromMap(body, "size")) * resourceWeight(config, resource, "resolutionWeights", firstNonEmptyString(stringFromMap(body, "resolution"), stringFromMap(body, "size"))) * discount return []any{billingLine(candidate, resource, unit, count, roundPrice(amount), discount, simulated)} @@ -109,17 +127,23 @@ func effectiveDiscount(ctx context.Context, db *store.Store, user *auth.User, ca if discount <= 0 { discount = 1 } - if group, err := db.ResolveUserGroupPolicy(ctx, user); err == nil { - groupDiscount := floatFromAny(group.BillingDiscountPolicy["discountFactor"]) - if groupDiscount > 0 { - discount *= groupDiscount + if db != nil { + if group, err := db.ResolveUserGroupPolicy(ctx, user); err == nil { + groupDiscount := floatFromAny(group.BillingDiscountPolicy["discountFactor"]) + if groupDiscount > 0 { + discount *= groupDiscount + } } } return discount } func billingLine(candidate store.RuntimeModelCandidate, resourceType string, unit string, quantity any, amount float64, discount float64, simulated bool) map[string]any { - return map[string]any{ + return billingLineWithDetails(candidate, resourceType, unit, quantity, amount, discount, simulated, nil) +} + +func billingLineWithDetails(candidate store.RuntimeModelCandidate, resourceType string, unit string, quantity any, amount float64, discount float64, simulated bool, details map[string]any) map[string]any { + line := map[string]any{ "model": candidate.ModelName, "modelAlias": candidate.ModelAlias, "provider": candidate.Provider, @@ -133,6 +157,10 @@ func billingLine(candidate store.RuntimeModelCandidate, resourceType string, uni "discountFactor": discount, "simulated": simulated, } + for key, value := range details { + line[key] = value + } + return line } func price(config map[string]any, key string) float64 { @@ -177,7 +205,16 @@ func weighted(config map[string]any, key string, name string) float64 { } func resourceWeight(config map[string]any, resource string, key string, name string) float64 { - if value := weighted(config, key, name); value != 1 { + keys := weightKeyAliases(key) + names := weightValueAliases(key, name) + for _, candidateKey := range keys { + for _, candidateName := range names { + if value := weighted(config, candidateKey, candidateName); value != 1 { + return value + } + } + } + if value := dynamicWeight(config["dynamicWeight"], keys, names); value != 1 { return value } if strings.TrimSpace(name) == "" { @@ -187,19 +224,201 @@ func resourceWeight(config map[string]any, resource string, key string, name str if len(resourceConfig) == 0 && resource == "image_edit" { resourceConfig, _ = config["image"].(map[string]any) } - if weights, ok := resourceConfig["dynamicWeight"].(map[string]any); ok { - if value := floatFromAny(weights[name]); value > 0 { - return value + if value := dynamicWeight(resourceConfig["dynamicWeight"], keys, names); value != 1 { + return value + } + for _, candidateKey := range keys { + if weights, ok := resourceConfig[candidateKey].(map[string]any); ok { + for _, candidateName := range names { + if value := floatFromAny(weights[candidateName]); value > 0 { + return value + } + } } } - if weights, ok := resourceConfig[key].(map[string]any); ok { - if value := floatFromAny(weights[name]); value > 0 { + return 1 +} + +func dynamicWeight(value any, keys []string, names []string) float64 { + if len(names) == 0 { + return 1 + } + weights, _ := value.(map[string]any) + if len(weights) == 0 { + return 1 + } + for _, name := range names { + if direct := floatFromAny(weights[name]); direct > 0 { + return direct + } + } + for _, key := range keys { + if nested, ok := weights[key].(map[string]any); ok { + for _, name := range names { + if nestedValue := floatFromAny(nested[name]); nestedValue > 0 { + return nestedValue + } + } + } + } + return 1 +} + +func weightKeyAliases(key string) []string { + switch key { + case "qualityWeights": + return []string{"qualityWeights", "qualityFactors"} + case "resolutionWeights": + return []string{"resolutionWeights", "resolutionFactors"} + case "audioWeights": + return []string{"audioWeights", "audioFactors"} + case "referenceVideoWeights": + return []string{"referenceVideoWeights", "referenceVideoFactors"} + case "voiceSpecifiedWeights": + return []string{"voiceSpecifiedWeights", "voiceSpecifiedFactors"} + default: + return []string{key} + } +} + +func weightValueAliases(key string, name string) []string { + name = strings.TrimSpace(name) + if name == "" { + return nil + } + switch key { + case "audioWeights": + return []string{name, "audio-" + name} + case "referenceVideoWeights": + return []string{name, "reference-video-" + name} + case "voiceSpecifiedWeights": + return []string{name, "voice-specified-" + name} + default: + return []string{name} + } +} + +func requestOutputCount(body map[string]any) int { + for _, key := range []string{"n", "count", "batch_size", "batchSize"} { + if value := int(math.Ceil(floatFromAny(body[key]))); value > 0 { return value } } return 1 } +func requestDurationSeconds(body map[string]any) float64 { + for _, key := range []string{"duration", "durationSeconds", "duration_seconds"} { + if value := floatFromAny(body[key]); value > 0 { + return value + } + } + for _, value := range body { + items, ok := value.([]any) + if !ok || len(items) == 0 { + continue + } + total := 0.0 + allDurationItems := true + for _, item := range items { + record, ok := item.(map[string]any) + if !ok { + allDurationItems = false + break + } + duration := floatFromAny(record["duration"]) + if duration <= 0 { + allDurationItems = false + break + } + total += duration + } + if allDurationItems && total > 0 { + return total + } + } + return 5 +} + +func requestHasReferenceVideo(body map[string]any) bool { + if hasNonEmptyArray(body["video_list"]) || hasNonEmptyArray(body["videoList"]) { + return true + } + if firstNonEmptyStringValue(body, "video", "video_url", "videoUrl", "reference_video", "referenceVideo") != "" { + return true + } + content, _ := body["content"].([]any) + for _, item := range content { + record, _ := item.(map[string]any) + if len(record) == 0 { + continue + } + itemType := strings.TrimSpace(stringFromAny(record["type"])) + role := strings.TrimSpace(stringFromAny(record["role"])) + if itemType == "video_url" || role == "video_feature" || role == "video_base" || role == "reference_video" { + return true + } + } + return false +} + +func requestHasVoiceID(body map[string]any) bool { + return boolishValue(body["audio"]) && firstNonEmptyStringValue(body, "voice_id", "voiceId") != "" +} + +func boolWeightKey(value bool) string { + if value { + return "true" + } + return "false" +} + +func boolishValue(value any) bool { + switch typed := value.(type) { + case bool: + return typed + case string: + switch strings.ToLower(strings.TrimSpace(typed)) { + case "true", "1", "yes", "on": + return true + default: + return false + } + case int: + return typed != 0 + case int64: + return typed != 0 + case float64: + return typed != 0 + default: + return false + } +} + +func hasNonEmptyArray(value any) bool { + items, ok := value.([]any) + return ok && len(items) > 0 +} + +func totalBillingAmount(items []any) float64 { + total := 0.0 + for _, raw := range items { + line, _ := raw.(map[string]any) + total += floatFromAny(line["amount"]) + } + return roundPrice(total) +} + +func billingCurrency(items []any) string { + for _, raw := range items { + line, _ := raw.(map[string]any) + if currency := stringFromAny(line["currency"]); currency != "" { + return currency + } + } + return "resource" +} + func firstNonEmptyString(values ...string) string { for _, value := range values { if strings.TrimSpace(value) != "" { diff --git a/apps/api/internal/runner/pricing_test.go b/apps/api/internal/runner/pricing_test.go new file mode 100644 index 0000000..45280ea --- /dev/null +++ b/apps/api/internal/runner/pricing_test.go @@ -0,0 +1,132 @@ +package runner + +import ( + "context" + "testing" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func TestImageBillingEstimateUsesCountResolutionAndQuality(t *testing.T) { + service := &Service{} + candidate := store.RuntimeModelCandidate{ + ModelName: "image-model", + BaseBillingConfig: map[string]any{ + "image": map[string]any{ + "basePrice": 10, + "dynamicWeight": map[string]any{ + "resolutionFactors": map[string]any{"2K": 1.5}, + "qualityFactors": map[string]any{"high": 1.5}, + }, + }, + }, + } + + items := service.billings(context.Background(), nil, "images.generations", map[string]any{ + "count": 2, + "quality": "high", + "resolution": "2K", + }, candidate, clients.Response{}, true) + + line := firstBillingLine(t, items) + if got, want := floatFromAny(line["amount"]), 45.0; got != want { + t.Fatalf("image estimated amount = %v, want %v", got, want) + } + if got, want := line["quantity"], 2; got != want { + t.Fatalf("image quantity = %v, want %v", got, want) + } +} + +func TestVideoBillingEstimateUsesFiveSecondUnitsAndDynamicWeights(t *testing.T) { + service := &Service{} + candidate := store.RuntimeModelCandidate{ + ModelName: "video-model", + BaseBillingConfig: map[string]any{ + "video": map[string]any{ + "basePrice": 100, + "dynamicWeight": map[string]any{ + "resolutionWeights": map[string]any{"1080p": 1.5}, + "audioWeights": map[string]any{"true": 2}, + "referenceVideoWeights": map[string]any{"true": 1.5}, + "voiceSpecifiedWeights": map[string]any{"true": 1.2}, + "unusedCompatibilityField": map[string]any{"true": 99}, + }, + }, + }, + } + + items := service.billings(context.Background(), nil, "videos.generations", map[string]any{ + "audio": true, + "duration": 12, + "resolution": "1080p", + "voice_id": "voice-a", + "content": []any{ + map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/reference.mp4"}}, + }, + }, candidate, clients.Response{}, true) + + line := firstBillingLine(t, items) + if got, want := floatFromAny(line["amount"]), 1620.0; got != want { + t.Fatalf("video estimated amount = %v, want %v", got, want) + } + if got, want := floatFromAny(line["durationUnitCount"]), 3.0; got != want { + t.Fatalf("video duration units = %v, want %v", got, want) + } + if got, want := line["quantity"], 3; got != want { + t.Fatalf("video quantity = %v, want %v", got, want) + } +} + +func TestVideoBillingEstimateSupportsServerMainStyleDynamicKeys(t *testing.T) { + service := &Service{} + candidate := store.RuntimeModelCandidate{ + ModelName: "legacy-video-model", + BaseBillingConfig: map[string]any{ + "videoBase": 100, + "video": map[string]any{ + "dynamicWeight": map[string]any{ + "720p": 1.25, + "audio-true": 2, + "reference-video-true": 1.5, + }, + }, + }, + } + + items := service.billings(context.Background(), nil, "videos.generations", map[string]any{ + "audio": "true", + "duration": 5, + "resolution": "720p", + "video_list": []any{map[string]any{"url": "https://example.com/reference.mp4"}}, + }, candidate, clients.Response{}, true) + + line := firstBillingLine(t, items) + if got, want := floatFromAny(line["amount"]), 375.0; got != want { + t.Fatalf("legacy video estimated amount = %v, want %v", got, want) + } +} + +func TestVideoDurationEstimateSumsMultiShotDurations(t *testing.T) { + duration := requestDurationSeconds(map[string]any{ + "multi_prompt": []any{ + map[string]any{"prompt": "shot 1", "duration": 3}, + map[string]any{"prompt": "shot 2", "duration": 7}, + }, + }) + if duration != 10 { + t.Fatalf("multi-shot duration = %v, want 10", duration) + } +} + +func firstBillingLine(t *testing.T, items []any) map[string]any { + t.Helper() + if len(items) != 1 { + t.Fatalf("items length = %d, want 1: %+v", len(items), items) + } + line, ok := items[0].(map[string]any) + if !ok { + t.Fatalf("item type = %T, want map[string]any", items[0]) + } + return line +} diff --git a/apps/api/internal/runner/recording.go b/apps/api/internal/runner/recording.go index cf6078c..37facb4 100644 --- a/apps/api/internal/runner/recording.go +++ b/apps/api/internal/runner/recording.go @@ -86,7 +86,7 @@ func taskMetrics(task store.GatewayTask, user *auth.User, body map[string]any, c copyIfPresent(metrics, body, "style") case "videos.generations": metrics["hasReferenceImage"] = imageInputCount(body) > 0 - metrics["hasReferenceVideo"] = hasAnyString(body, "video", "video_url", "videoUrl", "reference_video", "referenceVideo") + metrics["hasReferenceVideo"] = hasAnyString(body, "video", "video_url", "videoUrl", "reference_video", "referenceVideo") || hasVideoContent(body) copyIfPresent(metrics, body, "duration") copyIfPresent(metrics, body, "resolution") copyIfPresent(metrics, body, "size") @@ -303,9 +303,23 @@ func imageInputCount(body map[string]any) int { count += len(values) } } + for _, item := range contentItems(body["content"]) { + if isImageContent(item) { + count++ + } + } return count } +func hasVideoContent(body map[string]any) bool { + for _, item := range contentItems(body["content"]) { + if isVideoContent(item) { + return true + } + } + return false +} + func hasAnyString(body map[string]any, keys ...string) bool { for _, key := range keys { if stringFromMap(body, key) != "" { diff --git a/apps/api/internal/runner/service.go b/apps/api/internal/runner/service.go index ade0144..8fb2c9e 100644 --- a/apps/api/internal/runner/service.go +++ b/apps/api/internal/runner/service.go @@ -104,6 +104,17 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut firstCandidateBody = preprocessing.Body firstPreprocessing = preprocessing.Log normalizedModelType = candidates[0].ModelType + if preprocessing.Err != nil { + clientErr := parameterPreprocessClientError(preprocessing.Err) + if logErr := s.recordTaskParameterPreprocessing(ctx, task.ID, "", 0, candidates[0], firstPreprocessing); logErr != nil { + return Result{}, logErr + } + failed, finishErr := s.failTask(ctx, task.ID, clients.ErrorCode(clientErr), clientErr.Error(), task.RunMode == "simulation", clientErr, parameterPreprocessingMetrics(firstPreprocessing)) + if finishErr != nil { + return Result{}, finishErr + } + return Result{Task: failed, Output: failed.Result}, clientErr + } if err := s.store.MarkTaskRunning(ctx, task.ID, candidates[0].ModelType, firstCandidateBody); err != nil { return Result{}, err } @@ -149,6 +160,10 @@ candidatesLoop: preprocessing := preprocessRequestWithLog(task.Kind, body, candidate) preprocessingLog := preprocessing.Log lastPreprocessing = &preprocessingLog + if preprocessing.Err != nil { + lastErr = parameterPreprocessClientError(preprocessing.Err) + break candidatesLoop + } candidateBody := preprocessing.Body response, err := s.runCandidate(ctx, task, user, candidateBody, preprocessing.Log, candidate, nextAttemptNo, onDelta) if err == nil { @@ -481,7 +496,7 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user s.applyCandidateFailurePolicies(ctx, task.ID, candidate, err, simulated) return clients.Response{}, err } - uploadedResult, err := s.uploadGeneratedAssets(ctx, response.Result) + uploadedResult, err := s.uploadGeneratedAssets(ctx, task.ID, task.Kind, response.Result) if err != nil { metrics := mergeMetrics(taskMetrics(task, user, body, candidate, response, simulated), parameterPreprocessingMetrics(preprocessing), map[string]any{ "error": err.Error(), @@ -531,6 +546,9 @@ func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user } func (s *Service) recordTaskParameterPreprocessing(ctx context.Context, taskID string, attemptID string, attemptNo int, candidate store.RuntimeModelCandidate, log parameterPreprocessingLog) error { + if skipTaskParameterPreprocessingLog(log.ModelType) { + return nil + } _, err := s.store.CreateTaskParamPreprocessingLog(ctx, store.CreateTaskParamPreprocessingLogInput{ TaskID: taskID, AttemptID: attemptID, @@ -549,6 +567,15 @@ func (s *Service) recordTaskParameterPreprocessing(ctx context.Context, taskID s return err } +func skipTaskParameterPreprocessingLog(modelType string) bool { + switch strings.TrimSpace(modelType) { + case "text_generate", "chat", "responses", "text": + return true + default: + return false + } +} + func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated bool) clients.Client { if simulated { return s.clients["simulation"] @@ -687,7 +714,7 @@ func requestedModelTypeFromBody(body map[string]any) string { func isKnownModelType(value string) bool { switch value { - case "text_generate", "image_generate", "image_edit", "video_generate", "image_to_video", "text_to_video", "video_edit", "omni_video", "omni": + case "text_generate", "image_generate", "image_edit", "video_generate", "image_to_video", "text_to_video", "video_edit", "video_reference", "video_first_last_frame", "omni_video", "omni": return true default: return false @@ -706,6 +733,11 @@ func videoRequestHasReferenceImage(body map[string]any) bool { return true } } + for _, item := range contentItems(body["content"]) { + if isImageContent(item) { + return true + } + } return false } @@ -851,3 +883,15 @@ func validateRequest(kind string, body map[string]any) error { } return nil } + +func parameterPreprocessClientError(err error) *clients.ClientError { + if err == nil { + return nil + } + return &clients.ClientError{ + Code: "invalid_parameter", + Message: err.Error(), + StatusCode: 400, + Retryable: false, + } +} diff --git a/apps/api/internal/runner/upload.go b/apps/api/internal/runner/upload.go index 2972137..1f57f82 100644 --- a/apps/api/internal/runner/upload.go +++ b/apps/api/internal/runner/upload.go @@ -3,24 +3,115 @@ package runner import ( "bytes" "context" + "crypto/rand" "encoding/base64" + "encoding/hex" "encoding/json" "fmt" + "io" + "mime" "mime/multipart" "net/http" + "net/textproto" + "net/url" + "os" + "path/filepath" "strings" + "time" "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/config" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" ) -func (s *Service) uploadGeneratedAssets(ctx context.Context, result map[string]any) (map[string]any, error) { - if s.cfg.ServerMainBaseURL == "" || s.cfg.ServerMainInternalToken == "" { - return result, nil +const defaultServerMainOpenAPIUploadURL = "http://127.0.0.1:3001/v1/files/upload" +const maxGeneratedAssetFetchBytes = 256 << 20 + +const ( + localStaticGeneratedPathPrefix = "/static/generated/" + localStaticUploadedPathPrefix = "/static/uploaded/" +) + +type FileUploadPayload struct { + ContentType string + FileName string + Scene string + Source string + Bytes []byte +} + +type generatedAssetUploadPolicy struct { + UploadInlineMedia bool + UploadURLMedia bool + StoreInlineMediaLocally bool +} + +type generatedAssetDecision struct { + Inline *generatedInlineAsset + URL *generatedURLAsset + StripKeys []string +} + +type generatedInlineAsset struct { + Bytes []byte + ContentType string + Kind string + SourceKey string +} + +type generatedURLAsset struct { + URL string + ContentType string + Kind string + SourceKey string +} + +func defaultGeneratedAssetUploadPolicy() generatedAssetUploadPolicy { + return generatedAssetUploadPolicy{ + UploadInlineMedia: true, + UploadURLMedia: false, } +} + +func (s *Service) uploadGeneratedAssets(ctx context.Context, taskID string, taskKind string, result map[string]any) (map[string]any, error) { data, _ := result["data"].([]any) if len(data) == 0 { return result, nil } + policy, err := s.generatedAssetUploadPolicy(ctx) + if err != nil { + return nil, &clients.ClientError{Code: "upload_config_failed", Message: err.Error(), Retryable: true} + } + decisions := make([]generatedAssetDecision, len(data)) + needsUpload := false + changed := false + for index, rawItem := range data { + item, _ := rawItem.(map[string]any) + if item == nil { + continue + } + decision, err := generatedAssetDecisionForItem(taskKind, item, policy) + if err != nil { + return nil, err + } + decisions[index] = decision + if decision.Inline != nil || decision.URL != nil { + needsUpload = true + } + if len(decision.StripKeys) > 0 { + changed = true + } + } + if !needsUpload && !changed { + return result, nil + } + var channels []store.FileStorageChannel + if needsUpload && generatedAssetNeedsChannelLookup(policy, decisions) { + channels, err = s.activeFileStorageChannels(ctx, store.FileStorageSceneImageResult) + if err != nil { + return nil, &clients.ClientError{Code: "upload_config_failed", Message: err.Error(), Retryable: true} + } + } next := map[string]any{} for key, value := range result { next[key] = value @@ -32,24 +123,55 @@ func (s *Service) uploadGeneratedAssets(ctx context.Context, result map[string]a nextData = append(nextData, rawItem) continue } - b64 := stringFromMap(item, "b64_json") - if b64 == "" { - nextData = append(nextData, rawItem) - continue - } - upload, err := s.uploadBase64Image(ctx, b64, index) - if err != nil { - return nil, err - } + decision := decisions[index] merged := map[string]any{} for key, value := range item { - if key != "b64_json" { - merged[key] = value - } + merged[key] = value } - merged["upload"] = upload - if urlValue, ok := upload["url"].(string); ok && urlValue != "" { - merged["url"] = urlValue + for _, key := range decision.StripKeys { + delete(merged, key) + } + if decision.Inline != nil || decision.URL != nil { + var upload map[string]any + var sourceKey string + var strategy string + var kind string + var contentType string + var err error + if decision.Inline != nil { + upload, contentType, kind, strategy, err = s.uploadGeneratedAsset(ctx, taskID, decision.Inline, index, channels, policy.StoreInlineMediaLocally) + sourceKey = decision.Inline.SourceKey + } else { + upload, contentType, kind, strategy, err = s.uploadGeneratedURLAsset(ctx, taskID, decision.URL, index, channels) + sourceKey = decision.URL.SourceKey + } + if err != nil { + return nil, err + } + merged["upload"] = upload + merged["assetStorage"] = map[string]any{ + "scene": store.FileStorageSceneImageResult, + "source": sourceKey, + "strategy": strategy, + } + if contentType != "" { + merged["assetStorage"].(map[string]any)["contentType"] = contentType + } + if urlValue := stringFromAny(upload["url"]); urlValue != "" { + merged["url"] = urlValue + if kind == "video" { + merged["video_url"] = urlValue + } + if kind == "image" { + merged["image_url"] = urlValue + } + } + if kind != "" && stringFromAny(merged["type"]) == "" { + merged["type"] = kind + } + if contentType != "" && stringFromAny(merged["mime_type"]) == "" { + merged["mime_type"] = contentType + } } nextData = append(nextData, merged) } @@ -57,43 +179,421 @@ func (s *Service) uploadGeneratedAssets(ctx context.Context, result map[string]a return next, nil } -func (s *Service) uploadBase64Image(ctx context.Context, b64 string, index int) (map[string]any, error) { - payload, err := base64.StdEncoding.DecodeString(stripDataURLPrefix(b64)) +func (s *Service) generatedAssetUploadPolicy(ctx context.Context) (generatedAssetUploadPolicy, error) { + settings, err := s.store.GetFileStorageSettings(ctx) if err != nil { - return nil, &clients.ClientError{Code: "upload_decode_failed", Message: err.Error(), Retryable: false} + if store.IsUndefinedDatabaseObject(err) { + return defaultGeneratedAssetUploadPolicy(), nil + } + return generatedAssetUploadPolicy{}, err + } + return generatedAssetUploadPolicyFromName(settings.ResultUploadPolicy), nil +} + +func generatedAssetUploadPolicyFromName(policyName string) generatedAssetUploadPolicy { + policyName = store.NormalizeFileStorageResultUploadPolicy(policyName) + switch policyName { + case store.FileStorageResultUploadPolicyUploadAll: + return generatedAssetUploadPolicy{UploadInlineMedia: true, UploadURLMedia: true} + case store.FileStorageResultUploadPolicyUploadNone: + return generatedAssetUploadPolicy{UploadInlineMedia: true, UploadURLMedia: false, StoreInlineMediaLocally: true} + default: + return defaultGeneratedAssetUploadPolicy() + } +} + +func generatedAssetNeedsChannelLookup(policy generatedAssetUploadPolicy, decisions []generatedAssetDecision) bool { + for _, decision := range decisions { + if decision.URL != nil { + return true + } + if decision.Inline != nil && !policy.StoreInlineMediaLocally { + return true + } + } + return false +} + +func (s *Service) uploadGeneratedAsset(ctx context.Context, taskID string, asset *generatedInlineAsset, index int, channels []store.FileStorageChannel, forceLocal bool) (map[string]any, string, string, string, error) { + contentType := resolvedGeneratedAssetContentType(asset.ContentType, asset.Kind, asset.Bytes) + kind := generatedAssetKindFromContentType(asset.Kind, contentType) + payload := FileUploadPayload{ + Bytes: asset.Bytes, + ContentType: contentType, + FileName: generatedAssetFileName(taskID, index, contentType, kind), + Scene: store.FileStorageSceneImageResult, + Source: "ai-gateway", + } + if forceLocal || len(channels) == 0 { + upload, err := s.storeFileLocally(payload, s.cfg.LocalGeneratedStorageDir, config.DefaultLocalGeneratedStorageDir, localStaticGeneratedPathPrefix) + return upload, contentType, kind, "local_static_inline_media", err + } + upload, err := s.uploadFileWithFailover(ctx, payload, channels) + return upload, contentType, kind, "upload_inline_media", err +} + +func (s *Service) uploadGeneratedURLAsset(ctx context.Context, taskID string, asset *generatedURLAsset, index int, channels []store.FileStorageChannel) (map[string]any, string, string, string, error) { + payload, contentType, err := s.readGeneratedURLAsset(ctx, asset) + if err != nil { + return nil, "", "", "", err + } + contentType = resolvedGeneratedAssetContentType(firstNonEmptyString(contentType, asset.ContentType), asset.Kind, payload) + kind := generatedAssetKindFromContentType(asset.Kind, contentType) + uploadPayload := FileUploadPayload{ + Bytes: payload, + ContentType: contentType, + FileName: generatedAssetFileName(taskID, index, contentType, kind), + Scene: store.FileStorageSceneImageResult, + Source: "ai-gateway", + } + if len(channels) == 0 { + upload, err := s.storeFileLocally(uploadPayload, s.cfg.LocalGeneratedStorageDir, config.DefaultLocalGeneratedStorageDir, localStaticGeneratedPathPrefix) + return upload, contentType, kind, "local_static_url_media", err + } + upload, err := s.uploadFileWithFailover(ctx, uploadPayload, channels) + return upload, contentType, kind, "upload_url_media", err +} + +func (s *Service) storeFileLocally(payload FileUploadPayload, storageDir string, fallbackStorageDir string, pathPrefix string) (map[string]any, error) { + storageDir = strings.TrimSpace(storageDir) + if storageDir == "" { + storageDir = fallbackStorageDir + } + if err := os.MkdirAll(storageDir, 0o755); err != nil { + return nil, &clients.ClientError{Code: "local_static_store_failed", Message: err.Error(), Retryable: true} + } + fileName := filepath.Base(strings.TrimSpace(payload.FileName)) + if fileName == "" || fileName == "." || fileName == ".." || fileName == string(filepath.Separator) { + kind := generatedAssetKindFromContentType("", payload.ContentType) + fileName = generatedAssetFileName("generated", 0, payload.ContentType, kind) + } + targetPath := filepath.Join(storageDir, fileName) + file, err := os.OpenFile(targetPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o644) + if err != nil { + return nil, &clients.ClientError{Code: "local_static_store_failed", Message: err.Error(), Retryable: true} + } + _, writeErr := file.Write(payload.Bytes) + closeErr := file.Close() + if writeErr != nil { + _ = os.Remove(targetPath) + return nil, &clients.ClientError{Code: "local_static_store_failed", Message: writeErr.Error(), Retryable: true} + } + if closeErr != nil { + _ = os.Remove(targetPath) + return nil, &clients.ClientError{Code: "local_static_store_failed", Message: closeErr.Error(), Retryable: true} + } + return map[string]any{ + "url": s.localStaticFileURL(fileName, pathPrefix), + "fileName": fileName, + "contentType": payload.ContentType, + "size": len(payload.Bytes), + "storageChannel": map[string]any{ + "id": "local-static", + "channelKey": "local-static", + "name": "AI Gateway local static storage", + "provider": "local_static", + }, + }, nil +} + +func (s *Service) localStaticFileURL(fileName string, pathPrefix string) string { + if strings.TrimSpace(pathPrefix) == "" { + pathPrefix = localStaticUploadedPathPrefix + } + path := pathPrefix + url.PathEscape(filepath.Base(fileName)) + baseURL := strings.TrimRight(strings.TrimSpace(s.cfg.PublicBaseURL), "/") + if baseURL == "" { + return path + } + return baseURL + path +} + +func localStaticUploadFileName(originalName string, contentType string) string { + baseName := filepath.Base(strings.TrimSpace(originalName)) + originalExt := strings.ToLower(filepath.Ext(baseName)) + namePart := strings.TrimSuffix(baseName, originalExt) + namePart = sanitizeGeneratedAssetNamePart(namePart) + if namePart == "" { + namePart = "gateway-upload" + } + if len(namePart) > 48 { + namePart = namePart[:48] + } + return fmt.Sprintf("%s-%s%s", namePart, randomHexSuffix(6), uploadFileExtension(contentType, originalExt)) +} + +func uploadFileExtension(contentType string, fallbackExt string) string { + normalized := normalizeGeneratedContentType(contentType) + if generatedContentTypeIsMedia(normalized) { + return fileExtensionForContentType(normalized, generatedAssetKindFromContentType("", normalized)) + } + if normalized != "" && normalized != "application/octet-stream" { + if extensions, err := mime.ExtensionsByType(normalized); err == nil && len(extensions) > 0 { + if ext := sanitizeFileExtension(extensions[0]); ext != "" { + return ext + } + } + } + if ext := sanitizeFileExtension(fallbackExt); ext != "" { + return ext + } + if normalized == "application/json" { + return ".json" + } + if strings.HasPrefix(normalized, "text/") { + return ".txt" + } + return ".bin" +} + +func sanitizeFileExtension(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + if value == "" { + return "" + } + if !strings.HasPrefix(value, ".") { + value = "." + value + } + if len(value) > 16 { + return "" + } + for _, item := range value[1:] { + if (item >= 'a' && item <= 'z') || (item >= '0' && item <= '9') { + continue + } + return "" + } + return value +} + +func (s *Service) readGeneratedURLAsset(ctx context.Context, asset *generatedURLAsset) ([]byte, string, error) { + fetchURL, err := s.generatedAssetFetchURL(asset.URL) + if err != nil { + return nil, "", err + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fetchURL, nil) + if err != nil { + return nil, "", err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, "", &clients.ClientError{Code: "upload_source_fetch_failed", Message: err.Error(), Retryable: true} + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + message := strings.TrimSpace(string(body)) + if message == "" { + message = "generated media source fetch failed" + } + return nil, "", &clients.ClientError{ + Code: "upload_source_fetch_failed", + Message: message, + StatusCode: resp.StatusCode, + Retryable: clients.HTTPRetryable(resp.StatusCode), + } + } + payload, err := io.ReadAll(io.LimitReader(resp.Body, maxGeneratedAssetFetchBytes+1)) + if err != nil { + return nil, "", &clients.ClientError{Code: "upload_source_read_failed", Message: err.Error(), StatusCode: resp.StatusCode, Retryable: clients.HTTPRetryable(resp.StatusCode)} + } + if len(payload) > maxGeneratedAssetFetchBytes { + return nil, "", &clients.ClientError{Code: "upload_source_too_large", Message: "generated media source exceeds upload fetch limit", StatusCode: resp.StatusCode, Retryable: false} + } + contentType := firstNonEmptyString(resp.Header.Get("Content-Type"), asset.ContentType) + return payload, strings.TrimSpace(strings.Split(contentType, ";")[0]), nil +} + +func (s *Service) generatedAssetFetchURL(raw string) (string, error) { + value := strings.TrimSpace(raw) + if value == "" { + return "", &clients.ClientError{Code: "upload_source_invalid_url", Message: "generated media source URL is empty", Retryable: false} + } + parsed, err := url.Parse(value) + if err != nil { + return "", &clients.ClientError{Code: "upload_source_invalid_url", Message: err.Error(), Retryable: false} + } + if parsed.IsAbs() { + if parsed.Scheme == "http" || parsed.Scheme == "https" { + return value, nil + } + return "", &clients.ClientError{Code: "upload_source_unsupported_url", Message: "unsupported generated media source URL scheme: " + parsed.Scheme, Retryable: false} + } + if strings.HasPrefix(value, "/") { + baseURL := generatedAssetLocalBaseURL(s.cfg.HTTPAddr) + if baseURL == "" { + return "", &clients.ClientError{Code: "upload_source_invalid_url", Message: "generated media source uses a relative URL without a local HTTP address", Retryable: false} + } + return baseURL + value, nil + } + return "", &clients.ClientError{Code: "upload_source_invalid_url", Message: "generated media source URL must be absolute or root-relative", Retryable: false} +} + +func generatedAssetLocalBaseURL(httpAddr string) string { + addr := strings.TrimSpace(httpAddr) + if addr == "" { + return "http://127.0.0.1:8088" + } + if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") { + return strings.TrimRight(addr, "/") + } + if strings.HasPrefix(addr, ":") { + return "http://127.0.0.1" + addr + } + if strings.Contains(addr, "://") { + return "" + } + return "http://" + strings.TrimRight(addr, "/") +} + +func (s *Service) UploadFile(ctx context.Context, payload FileUploadPayload) (map[string]any, error) { + if strings.TrimSpace(payload.Scene) == "" { + payload.Scene = store.FileStorageSceneUpload + } + channels, err := s.activeFileStorageChannels(ctx, payload.Scene) + if err != nil { + return nil, &clients.ClientError{Code: "upload_config_failed", Message: err.Error(), Retryable: true} + } + if len(channels) == 0 { + payload.FileName = localStaticUploadFileName(payload.FileName, payload.ContentType) + upload, err := s.storeFileLocally(payload, s.cfg.LocalUploadedStorageDir, config.DefaultLocalUploadedStorageDir, localStaticUploadedPathPrefix) + if err != nil { + return nil, err + } + upload["assetStorage"] = map[string]any{ + "scene": payload.Scene, + "source": firstNonEmptyString(payload.Source, "ai-gateway-openapi"), + "strategy": "local_static_upload", + } + return upload, nil + } + return s.uploadFileWithFailover(ctx, payload, channels) +} + +func (s *Service) activeFileStorageChannels(ctx context.Context, scene string) ([]store.FileStorageChannel, error) { + if s.store == nil { + return nil, nil + } + channels, err := s.store.ListEnabledFileStorageChannelsForScene(ctx, scene) + if err != nil && !store.IsUndefinedDatabaseObject(err) { + return nil, err + } + if len(channels) > 0 { + return channels, nil + } + return nil, nil +} + +func (s *Service) uploadFileWithFailover(ctx context.Context, payload FileUploadPayload, channels []store.FileStorageChannel) (map[string]any, error) { + var lastErr error + for _, channel := range channels { + upload, err := s.uploadWithChannelRetries(ctx, payload, channel) + if err == nil { + if s.store != nil { + _ = s.store.MarkFileStorageChannelSuccess(context.WithoutCancel(ctx), channel.ID) + } + return upload, nil + } + lastErr = err + if s.store != nil { + _ = s.store.MarkFileStorageChannelFailure(context.WithoutCancel(ctx), channel.ID, err.Error()) + } + } + if lastErr != nil { + return nil, lastErr + } + return nil, &clients.ClientError{Code: "upload_no_channel", Message: "no enabled file storage channel", Retryable: false} +} + +func (s *Service) uploadWithChannelRetries(ctx context.Context, payload FileUploadPayload, channel store.FileStorageChannel) (map[string]any, error) { + maxRetries, delays := uploadRetrySchedule(channel.RetryPolicy) + var lastErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + upload, err := s.uploadOnce(ctx, payload, channel) + if err == nil { + return upload, nil + } + lastErr = err + if attempt >= maxRetries || !clients.IsRetryable(err) { + break + } + delay := retryDelayForAttempt(attempt, delays) + if err := sleepWithContext(ctx, delay); err != nil { + return nil, err + } + } + return nil, lastErr +} + +func (s *Service) uploadOnce(ctx context.Context, payload FileUploadPayload, channel store.FileStorageChannel) (map[string]any, error) { + if strings.ToLower(strings.TrimSpace(channel.Provider)) != "server_main_openapi" { + return nil, &clients.ClientError{Code: "upload_unsupported_channel", Message: "unsupported file storage channel: " + channel.Provider, Retryable: false} + } + uploadURL := strings.TrimSpace(channel.UploadURL) + if uploadURL == "" { + uploadURL = defaultServerMainOpenAPIUploadURL + } + apiKey := strings.TrimSpace(channel.APIKey) + if apiKey == "" { + return nil, &clients.ClientError{Code: "missing_credentials", Message: "file storage channel API key is required", Retryable: false} } var body bytes.Buffer writer := multipart.NewWriter(&body) - fileWriter, err := writer.CreateFormFile("file", fmt.Sprintf("gateway-result-%d.png", index+1)) + fileWriter, err := createUploadFormFile(writer, "file", firstNonEmptyString(payload.FileName, "upload.bin"), payload.ContentType) if err != nil { return nil, err } - if _, err := fileWriter.Write(payload); err != nil { + if _, err := fileWriter.Write(payload.Bytes); err != nil { return nil, err } - _ = writer.WriteField("source", "ai-gateway") + _ = writer.WriteField("source", firstNonEmptyString(payload.Source, "ai-gateway")) + _ = writer.WriteField("scene", firstNonEmptyString(payload.Scene, store.FileStorageSceneUpload)) if err := writer.Close(); err != nil { return nil, err } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(s.cfg.ServerMainBaseURL, "/")+"/v1/files/upload", &body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadURL, &body) if err != nil { return nil, err } req.Header.Set("Content-Type", writer.FormDataContentType()) - req.Header.Set("Authorization", "Bearer "+s.cfg.ServerMainInternalToken) + if payload.ContentType != "" { + req.Header.Set("X-Upload-Content-Type", payload.ContentType) + } + req.Header.Set("Authorization", "Bearer "+apiKey) resp, err := http.DefaultClient.Do(req) if err != nil { return nil, &clients.ClientError{Code: "upload_network", Message: err.Error(), Retryable: true} } defer resp.Body.Close() - var decoded map[string]any - if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil { - return nil, &clients.ClientError{Code: "upload_invalid_response", Message: err.Error(), Retryable: false} + responseBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, &clients.ClientError{Code: "upload_read_failed", Message: readErr.Error(), StatusCode: resp.StatusCode, Retryable: clients.HTTPRetryable(resp.StatusCode)} } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, &clients.ClientError{Code: "upload_failed", Message: "server-main upload failed", StatusCode: resp.StatusCode, Retryable: resp.StatusCode >= 500} + message := strings.TrimSpace(string(responseBody)) + if message == "" { + message = "file upload failed" + } + return nil, &clients.ClientError{Code: "upload_failed", Message: message, StatusCode: resp.StatusCode, Retryable: clients.HTTPRetryable(resp.StatusCode)} } - return decoded, nil + var decoded map[string]any + if err := json.Unmarshal(responseBody, &decoded); err != nil { + return nil, &clients.ClientError{Code: "upload_invalid_response", Message: err.Error(), Retryable: false} + } + return normalizeUploadResponse(decoded, channel), nil +} + +func createUploadFormFile(writer *multipart.Writer, fieldName string, fileName string, contentType string) (io.Writer, error) { + header := make(textproto.MIMEHeader) + header.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, escapeMultipartValue(fieldName), escapeMultipartValue(fileName))) + if strings.TrimSpace(contentType) != "" { + header.Set("Content-Type", strings.TrimSpace(contentType)) + } + return writer.CreatePart(header) +} + +func escapeMultipartValue(value string) string { + return strings.NewReplacer("\\", "\\\\", `"`, "\\\"").Replace(value) } func stripDataURLPrefix(value string) string { @@ -102,3 +602,594 @@ func stripDataURLPrefix(value string) string { } return value } + +func generatedAssetDecisionForItem(taskKind string, item map[string]any, policy generatedAssetUploadPolicy) (generatedAssetDecision, error) { + decision := generatedAssetDecision{} + urlKey, mediaURL := mediaURLSourceFromItem(item) + if mediaURL != "" { + if !policy.UploadURLMedia { + decision.StripKeys = inlineMediaKeys(item) + return decision, nil + } + contentType := mediaContentTypeFromItem(item) + kind := mediaKindForAsset(taskKind, item, urlKey, contentType) + decision.URL = &generatedURLAsset{ + URL: mediaURL, + ContentType: contentType, + Kind: kind, + SourceKey: urlKey, + } + decision.StripKeys = uniqueStringList(append(mediaURLKeys(item), inlineMediaKeys(item)...)) + return decision, nil + } + if !policy.UploadInlineMedia { + return decision, nil + } + asset, keys, err := inlineAssetFromItem(taskKind, item) + if err != nil { + return decision, err + } + if asset == nil { + return decision, nil + } + decision.Inline = asset + decision.StripKeys = keys + return decision, nil +} + +func inlineAssetFromItem(taskKind string, item map[string]any) (*generatedInlineAsset, []string, error) { + for _, key := range inlineMediaCandidateKeys() { + value, ok := item[key] + if !ok || value == nil { + continue + } + strictBase64 := inlineMediaKeyIsStrictBase64(key) + payload, contentType, ok, err := inlineMediaPayload(value, strictBase64) + if err != nil { + return nil, nil, err + } + if !ok { + continue + } + contentType = firstNonEmptyString(contentType, mediaContentTypeFromItem(item), defaultContentTypeForGeneratedAsset(mediaKindForAsset(taskKind, item, key, contentType))) + kind := mediaKindForAsset(taskKind, item, key, contentType) + return &generatedInlineAsset{ + Bytes: payload, + ContentType: contentType, + Kind: kind, + SourceKey: key, + }, inlineMediaKeys(item), nil + } + return nil, nil, nil +} + +func inlineMediaPayload(value any, strictBase64 bool) ([]byte, string, bool, error) { + switch typed := value.(type) { + case []byte: + if len(typed) == 0 { + return nil, "", false, nil + } + payload := make([]byte, len(typed)) + copy(payload, typed) + return payload, "", true, nil + case []any: + payload, ok := bytesFromNumberArray(typed) + return payload, "", ok, nil + case map[string]any: + if data, ok := typed["data"].([]any); ok { + payload, ok := bytesFromNumberArray(data) + return payload, firstNonEmptyString(stringFromAny(typed["mime_type"]), stringFromAny(typed["mimeType"])), ok, nil + } + if data, ok := typed["data"].([]byte); ok && len(data) > 0 { + payload := make([]byte, len(data)) + copy(payload, data) + return payload, firstNonEmptyString(stringFromAny(typed["mime_type"]), stringFromAny(typed["mimeType"])), true, nil + } + return nil, "", false, nil + case string: + return inlineMediaPayloadFromString(typed, strictBase64) + default: + return nil, "", false, nil + } +} + +func inlineMediaPayloadFromString(value string, strictBase64 bool) ([]byte, string, bool, error) { + raw := strings.TrimSpace(value) + if raw == "" || mediaURLString(raw) { + return nil, "", false, nil + } + if strings.HasPrefix(strings.ToLower(raw), "data:") { + contentType, encoded, ok, err := parseBase64DataURL(raw) + if err != nil || !ok { + return nil, "", false, err + } + payload, err := decodeBase64Payload(encoded) + if err != nil { + return nil, "", false, &clients.ClientError{Code: "upload_decode_failed", Message: err.Error(), Retryable: false} + } + return payload, contentType, true, nil + } + if !strictBase64 && len(raw) < 64 { + return nil, "", false, nil + } + payload, err := decodeBase64Payload(raw) + if err != nil { + if strictBase64 { + return nil, "", false, &clients.ClientError{Code: "upload_decode_failed", Message: err.Error(), Retryable: false} + } + return nil, "", false, nil + } + return payload, "", true, nil +} + +func parseBase64DataURL(value string) (string, string, bool, error) { + prefix, payload, ok := strings.Cut(value, ",") + if !ok { + return "", "", false, &clients.ClientError{Code: "upload_decode_failed", Message: "invalid data URL media payload", Retryable: false} + } + meta := strings.TrimPrefix(prefix, "data:") + meta = strings.TrimPrefix(meta, "DATA:") + parts := strings.Split(meta, ";") + contentType := strings.TrimSpace(parts[0]) + isBase64 := false + for _, part := range parts[1:] { + if strings.EqualFold(strings.TrimSpace(part), "base64") { + isBase64 = true + break + } + } + if !isBase64 { + return "", "", false, &clients.ClientError{Code: "upload_decode_failed", Message: "data URL media payload is not base64 encoded", Retryable: false} + } + return contentType, payload, true, nil +} + +func decodeBase64Payload(value string) ([]byte, error) { + normalized := strings.Map(func(r rune) rune { + switch r { + case '\n', '\r', '\t', ' ': + return -1 + default: + return r + } + }, stripDataURLPrefix(value)) + encodings := []*base64.Encoding{ + base64.StdEncoding, + base64.RawStdEncoding, + base64.URLEncoding, + base64.RawURLEncoding, + } + var lastErr error + for _, encoding := range encodings { + payload, err := encoding.DecodeString(normalized) + if err == nil && len(payload) > 0 { + return payload, nil + } + if err != nil { + lastErr = err + } + } + if lastErr == nil { + lastErr = fmt.Errorf("empty base64 payload") + } + return nil, lastErr +} + +func bytesFromNumberArray(values []any) ([]byte, bool) { + if len(values) == 0 { + return nil, false + } + payload := make([]byte, 0, len(values)) + for _, value := range values { + next, ok := byteFromAny(value) + if !ok { + return nil, false + } + payload = append(payload, next) + } + return payload, true +} + +func byteFromAny(value any) (byte, bool) { + switch typed := value.(type) { + case byte: + return typed, true + case int: + if typed >= 0 && typed <= 255 { + return byte(typed), true + } + case int64: + if typed >= 0 && typed <= 255 { + return byte(typed), true + } + case float64: + asInt := int(typed) + if typed == float64(asInt) && asInt >= 0 && asInt <= 255 { + return byte(asInt), true + } + } + return 0, false +} + +func inlineMediaKeys(item map[string]any) []string { + keys := []string{} + for _, key := range inlineMediaCandidateKeys() { + value, ok := item[key] + if !ok || value == nil { + continue + } + strictBase64 := inlineMediaKeyIsStrictBase64(key) + if strictBase64 && stringFromAny(value) != "" { + keys = append(keys, key) + continue + } + if _, _, ok, _ := inlineMediaPayload(value, strictBase64); ok { + keys = append(keys, key) + } + } + return uniqueStringList(keys) +} + +func inlineMediaCandidateKeys() []string { + return []string{ + "b64_json", + "image_base64", + "image_b64", + "video_base64", + "video_b64", + "base64", + "b64", + "url", + "image_url", + "imageUrl", + "video_url", + "videoUrl", + "output_url", + "outputUrl", + "output_video_url", + "outputVideoUrl", + "image", + "video", + "image_buffer", + "image_bytes", + "video_buffer", + "video_bytes", + "buffer", + "bytes", + "data", + } +} + +func inlineMediaKeyIsStrictBase64(key string) bool { + lower := strings.ToLower(key) + return lower == "b64_json" || lower == "base64" || lower == "b64" || strings.Contains(lower, "base64") || strings.Contains(lower, "_b64") +} + +func mediaURLSourceFromItem(item map[string]any) (string, string) { + for _, key := range mediaURLCandidateKeys() { + value := stringFromAny(item[key]) + if value != "" && mediaURLString(value) { + return key, value + } + } + return "", "" +} + +func mediaURLKeys(item map[string]any) []string { + keys := []string{} + for _, key := range mediaURLCandidateKeys() { + value := stringFromAny(item[key]) + if value != "" && mediaURLString(value) { + keys = append(keys, key) + } + } + return uniqueStringList(keys) +} + +func mediaURLCandidateKeys() []string { + return []string{"url", "image_url", "imageUrl", "video_url", "videoUrl", "output_url", "outputUrl", "output_video_url", "outputVideoUrl", "download_url", "downloadUrl", "file_url", "fileUrl"} +} + +func mediaURLString(value string) bool { + raw := strings.TrimSpace(value) + if raw == "" { + return false + } + lower := strings.ToLower(raw) + if strings.HasPrefix(lower, "data:") { + return false + } + return strings.HasPrefix(lower, "http://") || + strings.HasPrefix(lower, "https://") || + strings.HasPrefix(lower, "/") || + strings.Contains(lower, "://") +} + +func mediaContentTypeFromItem(item map[string]any) string { + return firstNonEmptyString( + stringFromAny(item["mime_type"]), + stringFromAny(item["mimeType"]), + stringFromAny(item["content_type"]), + stringFromAny(item["contentType"]), + ) +} + +func mediaKindForAsset(taskKind string, item map[string]any, sourceKey string, contentType string) string { + contentType = strings.ToLower(strings.TrimSpace(contentType)) + if strings.HasPrefix(contentType, "image/") { + return "image" + } + if strings.HasPrefix(contentType, "video/") { + return "video" + } + if strings.HasPrefix(contentType, "audio/") { + return "audio" + } + itemType := strings.ToLower(strings.TrimSpace(stringFromAny(item["type"]))) + if strings.Contains(itemType, "video") { + return "video" + } + if strings.Contains(itemType, "image") { + return "image" + } + key := strings.ToLower(sourceKey) + if strings.Contains(key, "video") { + return "video" + } + if strings.Contains(key, "image") { + return "image" + } + kind := strings.ToLower(strings.TrimSpace(taskKind)) + if strings.Contains(kind, "video") { + return "video" + } + if strings.Contains(kind, "image") { + return "image" + } + return "image" +} + +func defaultContentTypeForGeneratedAsset(kind string) string { + switch strings.ToLower(strings.TrimSpace(kind)) { + case "video": + return "video/mp4" + case "audio": + return "audio/mpeg" + default: + return "image/png" + } +} + +func resolvedGeneratedAssetContentType(declared string, kind string, payload []byte) string { + declared = normalizeGeneratedContentType(declared) + detected := detectGeneratedAssetContentType(payload) + if generatedContentTypeIsMedia(detected) { + return detected + } + if generatedContentTypeIsMedia(declared) { + return declared + } + return defaultContentTypeForGeneratedAsset(kind) +} + +func detectGeneratedAssetContentType(payload []byte) string { + if len(payload) == 0 { + return "" + } + return normalizeGeneratedContentType(http.DetectContentType(payload)) +} + +func normalizeGeneratedContentType(contentType string) string { + return strings.ToLower(strings.TrimSpace(strings.Split(contentType, ";")[0])) +} + +func generatedContentTypeIsMedia(contentType string) bool { + return strings.HasPrefix(contentType, "image/") || + strings.HasPrefix(contentType, "video/") || + strings.HasPrefix(contentType, "audio/") +} + +func generatedAssetKindFromContentType(fallback string, contentType string) string { + contentType = normalizeGeneratedContentType(contentType) + if strings.HasPrefix(contentType, "image/") { + return "image" + } + if strings.HasPrefix(contentType, "video/") { + return "video" + } + if strings.HasPrefix(contentType, "audio/") { + return "audio" + } + fallback = strings.ToLower(strings.TrimSpace(fallback)) + if fallback != "" { + return fallback + } + return "image" +} + +func generatedAssetFileName(taskID string, index int, contentType string, kind string) string { + token := sanitizeGeneratedAssetNamePart(taskID) + if token == "" { + token = fmt.Sprintf("%d", time.Now().UTC().UnixNano()) + } + if len(token) > 32 { + token = token[:32] + } + return fmt.Sprintf("gateway-result-%s-%02d-%s%s", token, index+1, randomHexSuffix(6), fileExtensionForContentType(contentType, kind)) +} + +func sanitizeGeneratedAssetNamePart(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + var builder strings.Builder + for _, item := range value { + if (item >= 'a' && item <= 'z') || (item >= '0' && item <= '9') || item == '-' || item == '_' { + builder.WriteRune(item) + } + } + return strings.Trim(builder.String(), "-_") +} + +func randomHexSuffix(byteCount int) string { + if byteCount <= 0 { + byteCount = 6 + } + payload := make([]byte, byteCount) + if _, err := rand.Read(payload); err == nil { + return hex.EncodeToString(payload) + } + return fmt.Sprintf("%d", time.Now().UTC().UnixNano()) +} + +func fileExtensionForContentType(contentType string, kind string) string { + normalized := strings.ToLower(strings.TrimSpace(strings.Split(contentType, ";")[0])) + switch normalized { + case "image/jpeg", "image/jpg": + return ".jpg" + case "image/webp": + return ".webp" + case "image/gif": + return ".gif" + case "image/avif": + return ".avif" + case "image/bmp": + return ".bmp" + case "image/svg+xml": + return ".svg" + case "video/webm": + return ".webm" + case "video/quicktime": + return ".mov" + case "video/mp4": + return ".mp4" + case "audio/wav", "audio/x-wav": + return ".wav" + case "audio/ogg": + return ".ogg" + case "audio/mpeg", "audio/mp3": + return ".mp3" + case "image/png": + return ".png" + } + if strings.EqualFold(kind, "video") { + return ".mp4" + } + if strings.EqualFold(kind, "audio") { + return ".mp3" + } + return ".png" +} + +func uniqueStringList(values []string) []string { + seen := map[string]bool{} + out := make([]string, 0, len(values)) + for _, value := range values { + value = strings.TrimSpace(value) + if value == "" || seen[value] { + continue + } + seen[value] = true + out = append(out, value) + } + return out +} + +func uploadRetrySchedule(policy map[string]any) (int, []time.Duration) { + if policy == nil { + policy = defaultUploadRetryPolicy() + } + if enabled, ok := policy["enabled"].(bool); ok && !enabled { + return 0, nil + } + maxRetries := intFromPolicy(policy, "maxRetries") + if maxRetries <= 0 { + maxRetries = 3 + } + delays := uploadRetryDelays(policy["backoffSeconds"]) + if len(delays) == 0 { + delays = []time.Duration{60 * time.Second, 120 * time.Second, 180 * time.Second} + } + return maxRetries, delays +} + +func uploadRetryDelays(value any) []time.Duration { + items, ok := value.([]any) + if !ok { + return nil + } + delays := make([]time.Duration, 0, len(items)) + for _, item := range items { + seconds := int(floatFromAny(item)) + if seconds > 0 { + delays = append(delays, time.Duration(seconds)*time.Second) + } + } + return delays +} + +func retryDelayForAttempt(attempt int, delays []time.Duration) time.Duration { + if attempt < len(delays) { + return delays[attempt] + } + return delays[len(delays)-1] +} + +func sleepWithContext(ctx context.Context, delay time.Duration) error { + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func defaultUploadRetryPolicy() map[string]any { + return map[string]any{ + "enabled": true, + "maxRetries": 3, + "backoffSeconds": []any{60, 120, 180}, + "strategy": "exponential", + } +} + +func normalizeUploadResponse(decoded map[string]any, channel store.FileStorageChannel) map[string]any { + if decoded == nil { + decoded = map[string]any{} + } + if stringFromAny(decoded["url"]) == "" { + if urlValue := uploadResponseURL(decoded); urlValue != "" { + decoded["url"] = urlValue + } + } + decoded["storageChannel"] = map[string]any{ + "id": channel.ID, + "channelKey": channel.ChannelKey, + "name": channel.Name, + "provider": channel.Provider, + } + return decoded +} + +func uploadResponseURL(decoded map[string]any) string { + for _, key := range []string{"url", "fileUrl", "file_url"} { + if value := stringFromAny(decoded[key]); value != "" { + return value + } + } + for _, key := range []string{"data", "file", "result"} { + if nested, ok := decoded[key].(map[string]any); ok { + if value := uploadResponseURL(nested); value != "" { + return value + } + } + if items, ok := decoded[key].([]any); ok && len(items) > 0 { + if nested, ok := items[0].(map[string]any); ok { + if value := uploadResponseURL(nested); value != "" { + return value + } + } + } + } + return "" +} diff --git a/apps/api/internal/runner/upload_test.go b/apps/api/internal/runner/upload_test.go new file mode 100644 index 0000000..1354cf5 --- /dev/null +++ b/apps/api/internal/runner/upload_test.go @@ -0,0 +1,279 @@ +package runner + +import ( + "bytes" + "context" + "encoding/base64" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/config" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func TestGeneratedAssetDecisionSkipsURLResultAndStripsInlinePayload(t *testing.T) { + item := map[string]any{ + "b64_json": base64.StdEncoding.EncodeToString([]byte("inline image")), + "url": "https://cdn.example.com/generated.png", + } + + decision, err := generatedAssetDecisionForItem("images.generations", item, defaultGeneratedAssetUploadPolicy()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if decision.Inline != nil { + t.Fatalf("URL media should not be uploaded by the default policy") + } + if !containsString(decision.StripKeys, "b64_json") { + t.Fatalf("inline payload should be stripped when URL is already available: %+v", decision.StripKeys) + } +} + +func TestGeneratedAssetDecisionUploadsInlineImageBase64(t *testing.T) { + item := map[string]any{ + "b64_json": base64.StdEncoding.EncodeToString([]byte("inline image")), + "mime_type": "image/jpeg", + } + + decision, err := generatedAssetDecisionForItem("images.generations", item, defaultGeneratedAssetUploadPolicy()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if decision.Inline == nil { + t.Fatalf("expected inline image to be uploaded") + } + if decision.Inline.Kind != "image" || decision.Inline.ContentType != "image/jpeg" { + t.Fatalf("unexpected inline image metadata: %+v", decision.Inline) + } + if !containsString(decision.StripKeys, "b64_json") { + t.Fatalf("uploaded inline payload should be stripped: %+v", decision.StripKeys) + } +} + +func TestGeneratedAssetDecisionUploadsInlineVideoBuffer(t *testing.T) { + item := map[string]any{ + "type": "video", + "video_buffer": []any{float64(0), float64(1), float64(2), float64(3)}, + } + + decision, err := generatedAssetDecisionForItem("videos.generations", item, defaultGeneratedAssetUploadPolicy()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if decision.Inline == nil { + t.Fatalf("expected inline video buffer to be uploaded") + } + if decision.Inline.Kind != "video" || decision.Inline.ContentType != "video/mp4" { + t.Fatalf("unexpected inline video metadata: %+v", decision.Inline) + } + if !containsString(decision.StripKeys, "video_buffer") { + t.Fatalf("uploaded video buffer should be stripped: %+v", decision.StripKeys) + } +} + +func TestGeneratedAssetDecisionUploadsDataURL(t *testing.T) { + item := map[string]any{ + "url": "data:image/webp;base64," + base64.StdEncoding.EncodeToString([]byte("inline webp")), + } + + decision, err := generatedAssetDecisionForItem("images.generations", item, defaultGeneratedAssetUploadPolicy()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if decision.Inline == nil { + t.Fatalf("expected data URL to be uploaded") + } + if decision.Inline.SourceKey != "url" || decision.Inline.ContentType != "image/webp" { + t.Fatalf("unexpected data URL metadata: %+v", decision.Inline) + } + if !containsString(decision.StripKeys, "url") { + t.Fatalf("uploaded data URL field should be stripped: %+v", decision.StripKeys) + } +} + +func TestGeneratedAssetDecisionUploadsURLWhenPolicyUploadAll(t *testing.T) { + item := map[string]any{ + "type": "video", + "video_url": "https://cdn.example.com/generated.mp4", + } + + decision, err := generatedAssetDecisionForItem("videos.generations", item, generatedAssetUploadPolicy{UploadInlineMedia: true, UploadURLMedia: true}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if decision.URL == nil { + t.Fatalf("expected URL media to be uploaded") + } + if decision.URL.Kind != "video" || decision.URL.SourceKey != "video_url" { + t.Fatalf("unexpected URL media metadata: %+v", decision.URL) + } + if !containsString(decision.StripKeys, "video_url") { + t.Fatalf("uploaded URL field should be stripped: %+v", decision.StripKeys) + } +} + +func TestGeneratedAssetDecisionStoresInlineLocallyWhenPolicyUploadNone(t *testing.T) { + item := map[string]any{ + "b64_json": base64.StdEncoding.EncodeToString([]byte("inline image")), + } + + decision, err := generatedAssetDecisionForItem("images.generations", item, generatedAssetUploadPolicyFromName(store.FileStorageResultUploadPolicyUploadNone)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if decision.Inline == nil || decision.URL != nil { + t.Fatalf("upload_none should still turn inline payloads into static URLs: %+v", decision) + } + if !containsString(decision.StripKeys, "b64_json") { + t.Fatalf("inline payload should be stripped before persistence: %+v", decision.StripKeys) + } +} + +func TestGeneratedAssetUploadPolicyFromName(t *testing.T) { + tests := []struct { + name string + policyName string + want generatedAssetUploadPolicy + }{ + { + name: "default", + policyName: store.FileStorageResultUploadPolicyDefault, + want: generatedAssetUploadPolicy{UploadInlineMedia: true, UploadURLMedia: false, StoreInlineMediaLocally: false}, + }, + { + name: "upload all", + policyName: store.FileStorageResultUploadPolicyUploadAll, + want: generatedAssetUploadPolicy{UploadInlineMedia: true, UploadURLMedia: true, StoreInlineMediaLocally: false}, + }, + { + name: "upload none", + policyName: store.FileStorageResultUploadPolicyUploadNone, + want: generatedAssetUploadPolicy{UploadInlineMedia: true, UploadURLMedia: false, StoreInlineMediaLocally: true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := generatedAssetUploadPolicyFromName(tt.policyName) + if got != tt.want { + t.Fatalf("unexpected policy: got %+v, want %+v", got, tt.want) + } + }) + } +} + +func TestResolvedGeneratedAssetContentTypePrefersDetectedMedia(t *testing.T) { + pngPayload := []byte{0x89, 'P', 'N', 'G', 0x0d, 0x0a, 0x1a, 0x0a, 0, 0, 0, 0} + + contentType := resolvedGeneratedAssetContentType("image/jpeg", "image", pngPayload) + if contentType != "image/png" { + t.Fatalf("expected detected PNG content type, got %s", contentType) + } + if extension := fileExtensionForContentType(contentType, "image"); extension != ".png" { + t.Fatalf("expected PNG extension, got %s", extension) + } +} + +func TestResolvedGeneratedAssetContentTypeKeepsDeclaredMediaWhenDetectionIsGeneric(t *testing.T) { + contentType := resolvedGeneratedAssetContentType("image/webp", "image", []byte("not enough media bytes")) + if contentType != "image/webp" { + t.Fatalf("expected declared webp content type, got %s", contentType) + } +} + +func TestGeneratedAssetFileNameIsUniqueAndTyped(t *testing.T) { + first := generatedAssetFileName("663e19cd4fa9d8078385c7c9", 0, "image/png", "image") + second := generatedAssetFileName("663e19cd4fa9d8078385c7c9", 0, "image/png", "image") + if first == second { + t.Fatalf("expected generated file names to be unique, both were %s", first) + } + if !strings.HasPrefix(first, "gateway-result-663e19cd4fa9d8078385c7c9-01-") || !strings.HasSuffix(first, ".png") { + t.Fatalf("unexpected generated file name: %s", first) + } +} + +func TestUploadGeneratedAssetStoresLocalWhenNoChannels(t *testing.T) { + storageDir := t.TempDir() + service := &Service{cfg: config.Config{LocalGeneratedStorageDir: storageDir}} + payload := []byte{0x89, 'P', 'N', 'G', 0x0d, 0x0a, 0x1a, 0x0a, 0, 0, 0, 0} + asset := &generatedInlineAsset{ + Bytes: payload, + ContentType: "image/jpeg", + Kind: "image", + SourceKey: "b64_json", + } + + upload, contentType, kind, strategy, err := service.uploadGeneratedAsset(context.Background(), "task-123", asset, 0, nil, false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if contentType != "image/png" || kind != "image" || strategy != "local_static_inline_media" { + t.Fatalf("unexpected local upload metadata: contentType=%s kind=%s strategy=%s", contentType, kind, strategy) + } + urlValue := stringFromAny(upload["url"]) + if !strings.HasPrefix(urlValue, "/static/generated/gateway-result-task-123-01-") || !strings.HasSuffix(urlValue, ".png") { + t.Fatalf("unexpected local static URL: %s", urlValue) + } + entries, err := os.ReadDir(storageDir) + if err != nil { + t.Fatalf("failed to read local static dir: %v", err) + } + if len(entries) != 1 || !strings.HasSuffix(entries[0].Name(), ".png") { + t.Fatalf("expected one PNG file in local static dir, got %+v", entries) + } + stored, err := os.ReadFile(filepath.Join(storageDir, entries[0].Name())) + if err != nil { + t.Fatalf("failed to read local static file: %v", err) + } + if !bytes.Equal(stored, payload) { + t.Fatalf("stored payload does not match source payload") + } +} + +func TestUploadFileStoresLocalWhenNoChannels(t *testing.T) { + storageDir := t.TempDir() + service := &Service{cfg: config.Config{ + LocalUploadedStorageDir: storageDir, + ServerMainBaseURL: "http://127.0.0.1:1", + ServerMainInternalToken: "change-me", + }} + payload := []byte("%PDF-1.4") + + upload, err := service.UploadFile(context.Background(), FileUploadPayload{ + Bytes: payload, + ContentType: "application/pdf", + FileName: "用户文件.png", + Source: "playground", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + urlValue := stringFromAny(upload["url"]) + if !strings.HasPrefix(urlValue, "/static/uploaded/") || !strings.HasSuffix(urlValue, ".pdf") { + t.Fatalf("unexpected uploaded local static URL: %s", urlValue) + } + storageChannel, _ := upload["storageChannel"].(map[string]any) + if stringFromAny(storageChannel["provider"]) != "local_static" { + t.Fatalf("expected local static provider metadata, got %+v", upload["storageChannel"]) + } + assetStorage, _ := upload["assetStorage"].(map[string]any) + if stringFromAny(assetStorage["strategy"]) != "local_static_upload" || stringFromAny(assetStorage["scene"]) != store.FileStorageSceneUpload { + t.Fatalf("unexpected upload asset storage metadata: %+v", assetStorage) + } + entries, err := os.ReadDir(storageDir) + if err != nil { + t.Fatalf("failed to read uploaded static dir: %v", err) + } + if len(entries) != 1 || !strings.HasSuffix(entries[0].Name(), ".pdf") { + t.Fatalf("expected one PDF file in uploaded static dir, got %+v", entries) + } + stored, err := os.ReadFile(filepath.Join(storageDir, entries[0].Name())) + if err != nil { + t.Fatalf("failed to read uploaded static file: %v", err) + } + if !bytes.Equal(stored, payload) { + t.Fatalf("stored uploaded payload does not match source payload") + } +} diff --git a/apps/api/internal/store/file_storage_channels.go b/apps/api/internal/store/file_storage_channels.go new file mode 100644 index 0000000..7bcba18 --- /dev/null +++ b/apps/api/internal/store/file_storage_channels.go @@ -0,0 +1,499 @@ +package store + +import ( + "context" + "encoding/json" + "strings" + "time" + + "github.com/jackc/pgx/v5" +) + +const defaultServerMainUploadURL = "http://127.0.0.1:3001/v1/files/upload" + +const ( + FileStorageSceneUpload = "upload" + FileStorageSceneImageResult = "image_result" +) + +const ( + FileStorageResultUploadPolicyDefault = "default" + FileStorageResultUploadPolicyUploadAll = "upload_all" + FileStorageResultUploadPolicyUploadNone = "upload_none" +) + +const SystemSettingFileStorage = "file_storage" + +const fileStorageChannelColumns = ` +id::text, channel_key, name, provider, COALESCE(upload_url, ''), credentials, +config, retry_policy, priority, status, COALESCE(last_error, ''), +COALESCE(last_failed_at::text, ''), COALESCE(last_succeeded_at::text, ''), +created_at, updated_at` + +type FileStorageChannel struct { + ID string `json:"id"` + ChannelKey string `json:"channelKey"` + Name string `json:"name"` + Provider string `json:"provider"` + UploadURL string `json:"uploadUrl,omitempty"` + APIKey string `json:"-"` + CredentialsPreview map[string]any `json:"credentialsPreview,omitempty"` + Scenes []string `json:"scenes,omitempty"` + Config map[string]any `json:"config,omitempty"` + RetryPolicy map[string]any `json:"retryPolicy,omitempty"` + Priority int `json:"priority"` + Status string `json:"status"` + LastError string `json:"lastError,omitempty"` + LastFailedAt string `json:"lastFailedAt,omitempty"` + LastSucceededAt string `json:"lastSucceededAt,omitempty"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +type FileStorageChannelInput struct { + ChannelKey string `json:"channelKey"` + Name string `json:"name"` + Provider string `json:"provider"` + UploadURL string `json:"uploadUrl"` + APIKey *string `json:"apiKey"` + Scenes []string `json:"scenes"` + Config map[string]any `json:"config"` + RetryPolicy map[string]any `json:"retryPolicy"` + Priority int `json:"priority"` + Status string `json:"status"` +} + +type FileStorageSettings struct { + ResultUploadPolicy string `json:"resultUploadPolicy"` +} + +type FileStorageSettingsInput struct { + ResultUploadPolicy string `json:"resultUploadPolicy"` +} + +type fileStorageChannelScanner interface { + Scan(dest ...any) error +} + +func (s *Store) ListFileStorageChannels(ctx context.Context) ([]FileStorageChannel, error) { + rows, err := s.pool.Query(ctx, ` +SELECT `+fileStorageChannelColumns+` +FROM file_storage_channels +WHERE deleted_at IS NULL +ORDER BY priority ASC, created_at ASC`) + if err != nil { + return nil, err + } + defer rows.Close() + items := make([]FileStorageChannel, 0) + for rows.Next() { + item, err := scanFileStorageChannel(rows) + if err != nil { + return nil, err + } + items = append(items, item) + } + return items, rows.Err() +} + +func (s *Store) ListEnabledFileStorageChannels(ctx context.Context) ([]FileStorageChannel, error) { + return s.listEnabledFileStorageChannels(ctx, "") +} + +func (s *Store) ListEnabledFileStorageChannelsForScene(ctx context.Context, scene string) ([]FileStorageChannel, error) { + return s.listEnabledFileStorageChannels(ctx, normalizeFileStorageScene(scene)) +} + +func (s *Store) listEnabledFileStorageChannels(ctx context.Context, scene string) ([]FileStorageChannel, error) { + rows, err := s.pool.Query(ctx, ` +SELECT `+fileStorageChannelColumns+` +FROM file_storage_channels +WHERE deleted_at IS NULL + AND status = 'enabled' + AND ( + $1 = '' + OR NOT (config ? 'scenes') + OR jsonb_typeof(config->'scenes') <> 'array' + OR (config->'scenes') ? $1 + ) +ORDER BY priority ASC, created_at ASC`, scene) + if err != nil { + return nil, err + } + defer rows.Close() + items := make([]FileStorageChannel, 0) + for rows.Next() { + item, err := scanFileStorageChannel(rows) + if err != nil { + return nil, err + } + items = append(items, item) + } + return items, rows.Err() +} + +func (s *Store) GetFileStorageChannel(ctx context.Context, id string) (FileStorageChannel, error) { + return scanFileStorageChannel(s.pool.QueryRow(ctx, ` +SELECT `+fileStorageChannelColumns+` +FROM file_storage_channels +WHERE id = $1::uuid + AND deleted_at IS NULL`, id)) +} + +func (s *Store) CreateFileStorageChannel(ctx context.Context, input FileStorageChannelInput) (FileStorageChannel, error) { + input = normalizeFileStorageChannelInput(input) + credentials, _ := json.Marshal(credentialsFromFileStorageInput(input)) + config, _ := json.Marshal(configFromFileStorageInput(input)) + retryPolicy, _ := json.Marshal(defaultFileStorageRetryPolicyIfEmpty(input.RetryPolicy)) + + return scanFileStorageChannel(s.pool.QueryRow(ctx, ` +INSERT INTO file_storage_channels ( + channel_key, name, provider, upload_url, credentials, config, retry_policy, priority, status +) +VALUES ($1, $2, $3, NULLIF($4, ''), $5, $6, $7, $8, $9) +RETURNING `+fileStorageChannelColumns, + input.ChannelKey, + input.Name, + input.Provider, + input.UploadURL, + credentials, + config, + retryPolicy, + input.Priority, + input.Status, + )) +} + +func (s *Store) UpdateFileStorageChannel(ctx context.Context, id string, input FileStorageChannelInput) (FileStorageChannel, error) { + input = normalizeFileStorageChannelInput(input) + replaceCredentials := input.APIKey != nil + credentials, _ := json.Marshal(credentialsFromFileStorageInput(input)) + config, _ := json.Marshal(configFromFileStorageInput(input)) + retryPolicy, _ := json.Marshal(defaultFileStorageRetryPolicyIfEmpty(input.RetryPolicy)) + + return scanFileStorageChannel(s.pool.QueryRow(ctx, ` +UPDATE file_storage_channels +SET channel_key = $2, + name = $3, + provider = $4, + upload_url = NULLIF($5, ''), + credentials = CASE WHEN $6::boolean THEN $7 ELSE credentials END, + config = $8, + retry_policy = $9, + priority = $10, + status = $11, + updated_at = now() +WHERE id = $1::uuid + AND deleted_at IS NULL +RETURNING `+fileStorageChannelColumns, + id, + input.ChannelKey, + input.Name, + input.Provider, + input.UploadURL, + replaceCredentials, + credentials, + config, + retryPolicy, + input.Priority, + input.Status, + )) +} + +func (s *Store) DeleteFileStorageChannel(ctx context.Context, id string) error { + result, err := s.pool.Exec(ctx, ` +UPDATE file_storage_channels +SET deleted_at = now(), + status = 'disabled', + updated_at = now() +WHERE id = $1::uuid + AND deleted_at IS NULL`, id) + if err != nil { + return err + } + if result.RowsAffected() == 0 { + return pgx.ErrNoRows + } + return nil +} + +func (s *Store) MarkFileStorageChannelFailure(ctx context.Context, id string, message string) error { + if strings.TrimSpace(id) == "" { + return nil + } + _, err := s.pool.Exec(ctx, ` +UPDATE file_storage_channels +SET last_error = NULLIF($2, ''), + last_failed_at = now(), + updated_at = now() +WHERE id = $1::uuid + AND deleted_at IS NULL`, id, strings.TrimSpace(message)) + return err +} + +func (s *Store) MarkFileStorageChannelSuccess(ctx context.Context, id string) error { + if strings.TrimSpace(id) == "" { + return nil + } + _, err := s.pool.Exec(ctx, ` +UPDATE file_storage_channels +SET last_error = NULL, + last_succeeded_at = now(), + updated_at = now() +WHERE id = $1::uuid + AND deleted_at IS NULL`, id) + return err +} + +func scanFileStorageChannel(scanner fileStorageChannelScanner) (FileStorageChannel, error) { + var item FileStorageChannel + var credentials []byte + var config []byte + var retryPolicy []byte + if err := scanner.Scan( + &item.ID, + &item.ChannelKey, + &item.Name, + &item.Provider, + &item.UploadURL, + &credentials, + &config, + &retryPolicy, + &item.Priority, + &item.Status, + &item.LastError, + &item.LastFailedAt, + &item.LastSucceededAt, + &item.CreatedAt, + &item.UpdatedAt, + ); err != nil { + return FileStorageChannel{}, err + } + credentialObject := decodeObject(credentials) + item.APIKey = stringFromObject(credentialObject, "apiKey") + item.CredentialsPreview = maskCredentialsPreview(credentials) + configObject := decodeObject(config) + item.Scenes = fileStorageScenesFromConfig(configObject) + item.Config = fileStorageConfigWithoutManagedFields(configObject) + item.RetryPolicy = decodeObject(retryPolicy) + return item, nil +} + +func normalizeFileStorageChannelInput(input FileStorageChannelInput) FileStorageChannelInput { + input.ChannelKey = strings.TrimSpace(input.ChannelKey) + input.Name = strings.TrimSpace(input.Name) + input.Provider = strings.ToLower(strings.TrimSpace(input.Provider)) + input.UploadURL = strings.TrimSpace(input.UploadURL) + if input.APIKey != nil { + apiKey := strings.TrimSpace(*input.APIKey) + input.APIKey = &apiKey + } + input.Scenes = normalizeFileStorageScenes(input.Scenes) + input.Status = strings.ToLower(strings.TrimSpace(input.Status)) + if input.Provider == "" { + input.Provider = "server_main_openapi" + } + if input.Provider == "server_main_openapi" && input.UploadURL == "" { + input.UploadURL = defaultServerMainUploadURL + } + if input.Status == "" { + input.Status = "disabled" + } + if input.Priority <= 0 { + input.Priority = 100 + } + return input +} + +func credentialsFromFileStorageInput(input FileStorageChannelInput) map[string]any { + apiKey := fileStorageInputAPIKey(input) + if apiKey == "" { + return map[string]any{} + } + return map[string]any{"apiKey": apiKey} +} + +func fileStorageInputAPIKey(input FileStorageChannelInput) string { + if input.APIKey == nil { + return "" + } + return strings.TrimSpace(*input.APIKey) +} + +func configFromFileStorageInput(input FileStorageChannelInput) map[string]any { + config := map[string]any{} + for key, value := range emptyObjectIfNil(input.Config) { + config[key] = value + } + config["scenes"] = normalizeFileStorageScenes(input.Scenes) + return config +} + +func fileStorageConfigWithoutManagedFields(config map[string]any) map[string]any { + out := map[string]any{} + for key, value := range config { + if key == "scenes" || key == "resultUploadPolicy" { + continue + } + out[key] = value + } + if len(out) == 0 { + return nil + } + return out +} + +func DefaultFileStorageSettings() FileStorageSettings { + return FileStorageSettings{ResultUploadPolicy: FileStorageResultUploadPolicyDefault} +} + +func (s *Store) GetFileStorageSettings(ctx context.Context) (FileStorageSettings, error) { + var value []byte + err := s.pool.QueryRow(ctx, ` +SELECT value +FROM system_settings +WHERE setting_key = $1`, SystemSettingFileStorage).Scan(&value) + if err != nil { + if IsNotFound(err) { + return DefaultFileStorageSettings(), nil + } + return FileStorageSettings{}, err + } + return fileStorageSettingsFromValue(decodeObject(value)), nil +} + +func (s *Store) UpdateFileStorageSettings(ctx context.Context, input FileStorageSettingsInput) (FileStorageSettings, error) { + settings := FileStorageSettings{ResultUploadPolicy: NormalizeFileStorageResultUploadPolicy(input.ResultUploadPolicy)} + value, _ := json.Marshal(settings) + var saved []byte + err := s.upsertFileStorageSettings(ctx, value, &saved) + if err != nil && IsUndefinedDatabaseObject(err) { + if ensureErr := s.ensureSystemSettingsTable(ctx); ensureErr != nil { + return FileStorageSettings{}, ensureErr + } + err = s.upsertFileStorageSettings(ctx, value, &saved) + } + if err != nil { + return FileStorageSettings{}, err + } + return fileStorageSettingsFromValue(decodeObject(saved)), nil +} + +func (s *Store) upsertFileStorageSettings(ctx context.Context, value []byte, saved *[]byte) error { + return s.pool.QueryRow(ctx, ` +INSERT INTO system_settings (setting_key, value) +VALUES ($1, $2) +ON CONFLICT (setting_key) +DO UPDATE SET value = EXCLUDED.value, updated_at = now() +RETURNING value`, SystemSettingFileStorage, value).Scan(saved) +} + +func (s *Store) ensureSystemSettingsTable(ctx context.Context) error { + _, err := s.pool.Exec(ctx, ` +CREATE TABLE IF NOT EXISTS system_settings ( + setting_key text PRIMARY KEY, + value jsonb NOT NULL DEFAULT '{}'::jsonb, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now() +)`) + return err +} + +func fileStorageSettingsFromValue(value map[string]any) FileStorageSettings { + settings := DefaultFileStorageSettings() + if value == nil { + return settings + } + settings.ResultUploadPolicy = NormalizeFileStorageResultUploadPolicy(stringFromAny(value["resultUploadPolicy"])) + return settings +} + +func NormalizeFileStorageResultUploadPolicy(policy string) string { + normalized := strings.ToLower(strings.TrimSpace(policy)) + normalized = strings.ReplaceAll(normalized, "-", "_") + switch normalized { + case "", "default", "non_link_only", "inline_only", "nonlink_only", "non_link": + return FileStorageResultUploadPolicyDefault + case "upload_all", "all", "always", "all_upload": + return FileStorageResultUploadPolicyUploadAll + case "upload_none", "none", "never", "disabled", "no_upload", "skip", "skip_all": + return FileStorageResultUploadPolicyUploadNone + default: + return FileStorageResultUploadPolicyDefault + } +} + +func fileStorageScenesFromConfig(config map[string]any) []string { + if config == nil { + return defaultFileStorageScenes() + } + raw, ok := config["scenes"] + if !ok { + return defaultFileStorageScenes() + } + items, ok := raw.([]any) + if !ok { + return defaultFileStorageScenes() + } + scenes := make([]string, 0, len(items)) + for _, item := range items { + if value, ok := item.(string); ok { + scenes = append(scenes, value) + } + } + return normalizeFileStorageScenes(scenes) +} + +func normalizeFileStorageScenes(scenes []string) []string { + seen := map[string]bool{} + out := make([]string, 0, len(scenes)) + for _, item := range scenes { + scene := normalizeFileStorageScene(item) + if scene == "" || seen[scene] { + continue + } + seen[scene] = true + out = append(out, scene) + } + if len(out) == 0 { + return defaultFileStorageScenes() + } + return out +} + +func normalizeFileStorageScene(scene string) string { + return strings.ToLower(strings.TrimSpace(scene)) +} + +func defaultFileStorageScenes() []string { + return []string{FileStorageSceneUpload, FileStorageSceneImageResult} +} + +func defaultFileStorageRetryPolicyIfEmpty(policy map[string]any) map[string]any { + if len(policy) > 0 { + return policy + } + return map[string]any{ + "enabled": true, + "maxRetries": 3, + "backoffSeconds": []any{60, 120, 180}, + "strategy": "exponential", + } +} + +func stringFromObject(value map[string]any, key string) string { + if value == nil { + return "" + } + raw, _ := value[key].(string) + return strings.TrimSpace(raw) +} + +func stringFromAny(value any) string { + switch typed := value.(type) { + case string: + return strings.TrimSpace(typed) + default: + return "" + } +} diff --git a/apps/api/internal/store/rate_limits.go b/apps/api/internal/store/rate_limits.go index e564e61..af61f2b 100644 --- a/apps/api/internal/store/rate_limits.go +++ b/apps/api/internal/store/rate_limits.go @@ -31,9 +31,18 @@ func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID } if reservation.Metric == "" || reservation.Amount > reservation.Limit { return RateLimitResult{}, &RateLimitExceededError{ - Metric: reservation.Metric, - Message: fmt.Sprintf("rate limit exceeded: %s request amount %.0f is greater than limit %.0f", reservation.Metric, reservation.Amount, reservation.Limit), - Retryable: false, + ScopeType: reservation.ScopeType, + ScopeKey: reservation.ScopeKey, + ScopeName: reservation.ScopeName, + ScopeMetadata: reservation.ScopeMetadata, + Metric: reservation.Metric, + Limit: reservation.Limit, + Amount: reservation.Amount, + Projected: reservation.Amount, + WindowSeconds: reservation.WindowSeconds, + Policy: reservation.Policy, + Message: fmt.Sprintf("rate limit exceeded: %s request amount %.0f is greater than limit %.0f", reservation.Metric, reservation.Amount, reservation.Limit), + Retryable: false, } } if reservation.WindowSeconds <= 0 { @@ -78,10 +87,22 @@ WHERE scope_type = $1 } if active+reservation.Amount > reservation.Limit { return "", &RateLimitExceededError{ - Metric: reservation.Metric, - Message: fmt.Sprintf("rate limit exceeded: concurrent active %.0f plus request %.0f is greater than limit %.0f", active, reservation.Amount, reservation.Limit), - RetryAfter: concurrencyRetryAfter(nextAvailableAt), - Retryable: true, + ScopeType: reservation.ScopeType, + ScopeKey: reservation.ScopeKey, + ScopeName: reservation.ScopeName, + ScopeMetadata: reservation.ScopeMetadata, + Metric: reservation.Metric, + Limit: reservation.Limit, + Amount: reservation.Amount, + Current: active, + Used: active, + Projected: active + reservation.Amount, + WindowSeconds: reservation.WindowSeconds, + ResetAt: nextAvailableAt, + Policy: reservation.Policy, + Message: fmt.Sprintf("rate limit exceeded: concurrent active %.0f plus request %.0f is greater than limit %.0f", active, reservation.Amount, reservation.Limit), + RetryAfter: concurrencyRetryAfter(nextAvailableAt), + Retryable: true, } } var leaseID string @@ -135,11 +156,13 @@ RETURNING window_start`, if err != nil { if errors.Is(err, pgx.ErrNoRows) { resetAt := time.Now().Add(time.Duration(reservation.WindowSeconds) * time.Second) + currentUsed := 0.0 + currentReserved := 0.0 _ = tx.QueryRow(ctx, ` WITH bounds AS ( SELECT to_timestamp(floor(extract(epoch FROM now()) / $4::int) * $4::int) AS window_start ) -SELECT counters.reset_at +SELECT counters.used_value::float8, counters.reserved_value::float8, counters.reset_at FROM gateway_rate_limit_counters counters JOIN bounds ON counters.window_start = bounds.window_start WHERE scope_type = $1 @@ -149,12 +172,26 @@ WHERE scope_type = $1 reservation.ScopeKey, reservation.Metric, reservation.WindowSeconds, - ).Scan(&resetAt) + ).Scan(¤tUsed, ¤tReserved, &resetAt) + current := currentUsed + currentReserved return RateLimitReservation{}, &RateLimitExceededError{ - Metric: reservation.Metric, - Message: fmt.Sprintf("rate limit exceeded: %s window has no remaining capacity", reservation.Metric), - RetryAfter: retryAfterUntil(resetAt), - Retryable: true, + ScopeType: reservation.ScopeType, + ScopeKey: reservation.ScopeKey, + ScopeName: reservation.ScopeName, + ScopeMetadata: reservation.ScopeMetadata, + Metric: reservation.Metric, + Limit: reservation.Limit, + Amount: reservation.Amount, + Current: current, + Used: currentUsed, + Reserved: currentReserved, + Projected: current + reservation.Amount, + WindowSeconds: reservation.WindowSeconds, + ResetAt: resetAt, + Policy: reservation.Policy, + Message: fmt.Sprintf("rate limit exceeded: %s window has no remaining capacity", reservation.Metric), + RetryAfter: retryAfterUntil(resetAt), + Retryable: true, } } return RateLimitReservation{}, err diff --git a/apps/api/internal/store/runtime_types.go b/apps/api/internal/store/runtime_types.go index d2784a7..601be8e 100644 --- a/apps/api/internal/store/runtime_types.go +++ b/apps/api/internal/store/runtime_types.go @@ -33,10 +33,23 @@ func ModelCandidateErrorCode(err error) string { } type RateLimitExceededError struct { - Metric string - Message string - RetryAfter time.Duration - Retryable bool + ScopeType string + ScopeKey string + ScopeName string + ScopeMetadata map[string]any + Metric string + Limit float64 + Amount float64 + Current float64 + Used float64 + Reserved float64 + Projected float64 + WindowSeconds int + ResetAt time.Time + Policy map[string]any + Message string + RetryAfter time.Duration + Retryable bool } func (e *RateLimitExceededError) Error() string { @@ -166,12 +179,15 @@ type RateLimitReservation struct { ReservationID string ScopeType string ScopeKey string + ScopeName string + ScopeMetadata map[string]any Metric string Limit float64 Amount float64 WindowSeconds int LeaseTTLSeconds int WindowStart time.Time + Policy map[string]any } type RateLimitResult struct { diff --git a/apps/api/internal/store/user_group_policy.go b/apps/api/internal/store/user_group_policy.go index b3f019a..efd347b 100644 --- a/apps/api/internal/store/user_group_policy.go +++ b/apps/api/internal/store/user_group_policy.go @@ -10,6 +10,7 @@ import ( type UserGroupPolicy struct { ID string GroupKey string + Name string RateLimitPolicy map[string]any BillingDiscountPolicy map[string]any } @@ -23,12 +24,12 @@ func (s *Store) ResolveUserGroupPolicy(ctx context.Context, user *auth.User) (Us var rateLimit []byte var billing []byte err := s.pool.QueryRow(ctx, ` -SELECT id::text, group_key, rate_limit_policy, billing_discount_policy +SELECT id::text, group_key, name, rate_limit_policy, billing_discount_policy FROM gateway_user_groups WHERE status = 'active' AND (($1 <> '' AND id = NULLIF($1, '')::uuid) OR ($1 = '' AND group_key = 'default')) ORDER BY CASE WHEN id::text = $1 THEN 0 ELSE 1 END, priority ASC -LIMIT 1`, userGroupID).Scan(&item.ID, &item.GroupKey, &rateLimit, &billing) +LIMIT 1`, userGroupID).Scan(&item.ID, &item.GroupKey, &item.Name, &rateLimit, &billing) if err != nil { if err == pgx.ErrNoRows { return UserGroupPolicy{}, nil diff --git a/apps/api/migrations/0036_file_storage_channels.sql b/apps/api/migrations/0036_file_storage_channels.sql new file mode 100644 index 0000000..85c7466 --- /dev/null +++ b/apps/api/migrations/0036_file_storage_channels.sql @@ -0,0 +1,85 @@ +CREATE TABLE IF NOT EXISTS system_settings ( + setting_key text PRIMARY KEY, + value jsonb NOT NULL DEFAULT '{}'::jsonb, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now() +); + +INSERT INTO system_settings (setting_key, value) +VALUES ( + 'file_storage', + '{"resultUploadPolicy": "default"}'::jsonb +) +ON CONFLICT (setting_key) DO NOTHING; + +CREATE TABLE IF NOT EXISTS file_storage_channels ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + channel_key text NOT NULL UNIQUE, + name text NOT NULL, + provider text NOT NULL DEFAULT 'server_main_openapi', + upload_url text, + credentials jsonb NOT NULL DEFAULT '{}'::jsonb, + config jsonb NOT NULL DEFAULT '{}'::jsonb, + retry_policy jsonb NOT NULL DEFAULT '{ + "enabled": true, + "maxRetries": 3, + "backoffSeconds": [60, 120, 180], + "strategy": "exponential" + }'::jsonb, + priority integer NOT NULL DEFAULT 100, + status text NOT NULL DEFAULT 'disabled', + last_error text, + last_failed_at timestamptz, + last_succeeded_at timestamptz, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + deleted_at timestamptz +); + +ALTER TABLE IF EXISTS file_storage_channels + ADD COLUMN IF NOT EXISTS upload_url text, + ADD COLUMN IF NOT EXISTS config jsonb NOT NULL DEFAULT '{}'::jsonb, + ADD COLUMN IF NOT EXISTS retry_policy jsonb NOT NULL DEFAULT '{ + "enabled": true, + "maxRetries": 3, + "backoffSeconds": [60, 120, 180], + "strategy": "exponential" + }'::jsonb, + ADD COLUMN IF NOT EXISTS priority integer NOT NULL DEFAULT 100, + ADD COLUMN IF NOT EXISTS last_error text, + ADD COLUMN IF NOT EXISTS last_failed_at timestamptz, + ADD COLUMN IF NOT EXISTS last_succeeded_at timestamptz, + ADD COLUMN IF NOT EXISTS deleted_at timestamptz; + +CREATE INDEX IF NOT EXISTS idx_file_storage_channels_active + ON file_storage_channels (status, priority, created_at) + WHERE deleted_at IS NULL; + +INSERT INTO file_storage_channels ( + channel_key, + name, + provider, + upload_url, + credentials, + config, + retry_policy, + priority, + status +) +VALUES ( + 'server-main-openapi', + 'server-main OpenAPI', + 'server_main_openapi', + 'http://127.0.0.1:3001/v1/files/upload', + '{}'::jsonb, + '{"scenes": ["upload", "image_result"]}'::jsonb, + '{ + "enabled": true, + "maxRetries": 3, + "backoffSeconds": [60, 120, 180], + "strategy": "exponential" + }'::jsonb, + 100, + 'disabled' +) +ON CONFLICT (channel_key) DO NOTHING; diff --git a/apps/api/migrations/0037_file_storage_settings.sql b/apps/api/migrations/0037_file_storage_settings.sql new file mode 100644 index 0000000..2d2e92a --- /dev/null +++ b/apps/api/migrations/0037_file_storage_settings.sql @@ -0,0 +1,13 @@ +CREATE TABLE IF NOT EXISTS system_settings ( + setting_key text PRIMARY KEY, + value jsonb NOT NULL DEFAULT '{}'::jsonb, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now() +); + +INSERT INTO system_settings (setting_key, value) +VALUES ( + 'file_storage', + '{"resultUploadPolicy": "default"}'::jsonb +) +ON CONFLICT (setting_key) DO NOTHING; diff --git a/apps/web/src/App.tsx b/apps/web/src/App.tsx index 2d508af..508f0c3 100644 --- a/apps/web/src/App.tsx +++ b/apps/web/src/App.tsx @@ -2,6 +2,10 @@ import { useEffect, useMemo, useRef, useState, type FormEvent } from 'react'; import type { BaseModelCatalogItem, CatalogProvider, + FileStorageChannel, + FileStorageSettings, + FileStorageSettingsUpdateRequest, + FileStorageChannelUpsertRequest, GatewayAccessRuleBatchRequest, GatewayAccessRule, GatewayAccessRuleUpsertRequest, @@ -34,17 +38,22 @@ import { batchApiKeyAccessRules, createAccessRule, createApiKey, + createFileStorageChannel, createGatewayUser, createPlatform, createTenant, createUserGroup, deleteAccessRule, deleteApiKey, + deleteFileStorageChannel, deleteGatewayUser, deletePlatform, deleteTenant, deleteUserGroup, + GatewayApiError, getHealth, + listFileStorageChannels, + getFileStorageSettings, getNetworkProxyConfig, getRunnerPolicy, getWalletSummary, @@ -78,6 +87,8 @@ import { setUserWalletBalance, type HealthResponse, updateAccessRule, + updateFileStorageChannel, + updateFileStorageSettings, updateGatewayUser, updatePlatform, updatePlatformDynamicPriority, @@ -135,6 +146,8 @@ type DataKey = | 'playgroundModels' | 'modelCatalog' | 'networkProxyConfig' + | 'fileStorageChannels' + | 'fileStorageSettings' | 'platforms' | 'models' | 'providers' @@ -179,6 +192,8 @@ export function App() { }); const [playgroundModels, setPlaygroundModels] = useState([]); const [networkProxyConfig, setNetworkProxyConfig] = useState(null); + const [fileStorageChannels, setFileStorageChannels] = useState([]); + const [fileStorageSettings, setFileStorageSettings] = useState(null); const [providers, setProviders] = useState([]); const [baseModels, setBaseModels] = useState([]); const [pricingRules, setPricingRules] = useState([]); @@ -257,7 +272,8 @@ export function App() { loadedDataKeysRef.current.add('modelRateLimits'); loadedDataKeysRef.current.add('platforms'); }) - .catch(() => { + .catch((err) => { + if (handleAuthExpired(err, token)) return; loadedDataKeysRef.current.delete('modelRateLimits'); loadedDataKeysRef.current.delete('platforms'); }); @@ -296,6 +312,8 @@ export function App() { auditLogs, apiKeys, baseModels, + fileStorageChannels, + fileStorageSettings, modelCatalog, models, networkProxyConfig, @@ -315,7 +333,7 @@ export function App() { users, walletAccounts, walletTransactions, - }), [accessRules, apiKeys, auditLogs, baseModels, modelCatalog, modelRateLimits, modelRateLimitsUpdatedAt, models, networkProxyConfig, platforms, pricingRuleSets, pricingRules, providers, rateLimitWindows, runnerPolicy, runtimePolicySets, taskResult, tasks, tenants, userGroups, users, walletAccounts, walletTransactions]); + }), [accessRules, apiKeys, auditLogs, baseModels, fileStorageChannels, fileStorageSettings, modelCatalog, modelRateLimits, modelRateLimitsUpdatedAt, models, networkProxyConfig, platforms, pricingRuleSets, pricingRules, providers, rateLimitWindows, runnerPolicy, runtimePolicySets, taskResult, tasks, tenants, userGroups, users, walletAccounts, walletTransactions]); async function refresh(nextToken = token) { await ensureRouteData(nextToken, true); @@ -354,6 +372,7 @@ export function App() { requestKeys.forEach((key) => loadedDataKeysRef.current.add(key)); setState('ready'); } catch (err) { + if (handleAuthExpired(err, nextToken)) return; setState('error'); setError(err instanceof Error ? err.message : '加载失败'); } finally { @@ -388,6 +407,12 @@ export function App() { case 'networkProxyConfig': setNetworkProxyConfig(await getNetworkProxyConfig(nextToken)); return; + case 'fileStorageChannels': + setFileStorageChannels((await listFileStorageChannels(nextToken)).items); + return; + case 'fileStorageSettings': + setFileStorageSettings(await getFileStorageSettings(nextToken)); + return; case 'playgroundModels': setPlaygroundModels((await listPlayableModels(nextToken)).items); return; @@ -818,6 +843,53 @@ export function App() { } } + async function saveFileStorageChannel(input: FileStorageChannelUpsertRequest, channelId?: string) { + setCoreState('loading'); + setCoreMessage(''); + try { + const item = channelId + ? await updateFileStorageChannel(token, channelId, input) + : await createFileStorageChannel(token, input); + setFileStorageChannels((current) => [item, ...current.filter((channel) => channel.id !== item.id)]); + setCoreState('ready'); + setCoreMessage(channelId ? '文件存储渠道已更新。' : '文件存储渠道已新增。'); + } catch (err) { + setCoreState('error'); + setCoreMessage(err instanceof Error ? err.message : channelId ? '更新文件存储渠道失败' : '新增文件存储渠道失败'); + throw err; + } + } + + async function saveFileStorageSettings(input: FileStorageSettingsUpdateRequest) { + setCoreState('loading'); + setCoreMessage(''); + try { + const settings = await updateFileStorageSettings(token, input); + setFileStorageSettings(settings); + setCoreState('ready'); + setCoreMessage('文件存储全局策略已更新。'); + } catch (err) { + setCoreState('error'); + setCoreMessage(err instanceof Error ? err.message : '更新文件存储全局策略失败'); + throw err; + } + } + + async function removeFileStorageChannel(channelId: string) { + setCoreState('loading'); + setCoreMessage(''); + try { + await deleteFileStorageChannel(token, channelId); + setFileStorageChannels((current) => current.filter((channel) => channel.id !== channelId)); + setCoreState('ready'); + setCoreMessage('文件存储渠道已删除。'); + } catch (err) { + setCoreState('error'); + setCoreMessage(err instanceof Error ? err.message : '删除文件存储渠道失败'); + throw err; + } + } + async function batchSaveAPIKeyAccessRules(input: GatewayAccessRuleBatchRequest) { setCoreState('loading'); setCoreMessage(''); @@ -856,7 +928,7 @@ export function App() { } } - function signOut() { + function resetAuthenticatedSession() { persistAccessToken(''); setToken(''); loadedDataKeysRef.current = new Set(health ? ['health'] : []); @@ -867,6 +939,7 @@ export function App() { setModelCatalog({ items: [], filters: { capabilities: [], providers: [] }, summary: { modelCount: 0, sourceCount: 0 } }); setPlaygroundModels([]); setNetworkProxyConfig(null); + setFileStorageChannels([]); setProviders([]); setBaseModels([]); setPricingRules([]); @@ -892,6 +965,19 @@ export function App() { setWalletTransactionTotal(0); setWorkspaceTransactionQuery(defaultWorkspaceTransactionQuery()); setCoreMessage(''); + } + + function handleAuthExpired(err: unknown, failedToken: string) { + if (!failedToken || !(err instanceof GatewayApiError) || err.details.status !== 401) return false; + resetAuthenticatedSession(); + setError(''); + setAuthMode('login'); + navigatePath(pathForWorkspaceSection('overview')); + return true; + } + + function signOut() { + resetAuthenticatedSession(); navigatePath('/'); } @@ -1039,6 +1125,7 @@ export function App() { onDeletePricingRuleSet={removePricingRuleSet} onDeleteRuntimePolicySet={removeRuntimePolicySet} onDeleteAccessRule={removeAccessRule} + onDeleteFileStorageChannel={removeFileStorageChannel} onDeleteTenant={removeTenant} onDeleteUser={removeUser} onDeleteUserGroup={removeUserGroup} @@ -1054,6 +1141,8 @@ export function App() { onSaveRuntimePolicySet={saveRuntimePolicySet} onBatchAccessRules={batchSaveAccessRules} onSaveAccessRule={saveAccessRule} + onSaveFileStorageChannel={saveFileStorageChannel} + onSaveFileStorageSettings={saveFileStorageSettings} onSaveTenant={saveTenant} onSaveUser={saveUser} onSetUserWalletBalance={saveUserWalletBalance} @@ -1267,6 +1356,8 @@ function dataKeysForRoute( return ['auditLogs']; case 'accessRules': return ['accessRules', 'userGroups', 'platforms', 'models']; + case 'systemSettings': + return ['fileStorageSettings', 'fileStorageChannels']; default: return []; } diff --git a/apps/web/src/api.ts b/apps/web/src/api.ts index eeb407c..22abd77 100644 --- a/apps/web/src/api.ts +++ b/apps/web/src/api.ts @@ -5,6 +5,10 @@ import type { CatalogProvider, CatalogProviderUpsertRequest, CreatedGatewayApiKey, + FileStorageChannel, + FileStorageSettings, + FileStorageSettingsUpdateRequest, + FileStorageChannelUpsertRequest, GatewayAccessRuleBatchRequest, GatewayAccessRule, GatewayAccessRuleUpsertRequest, @@ -15,6 +19,7 @@ import type { GatewayTenant, GatewayTenantUpsertRequest, GatewayNetworkProxyConfig, + GatewayPricingEstimate, GatewayTask, GatewayTaskParamPreprocessingLog, GatewayUser, @@ -601,10 +606,17 @@ export async function createImageGenerationTask( model: string; prompt: string; aspect_ratio?: string; + content?: Array>; count?: number; height?: number; + image?: string | string[]; + image_url?: string | string[]; + image_urls?: string[]; + images?: string[]; n?: number; quality?: string; + referenceImage?: string | string[]; + reference_image?: string | string[]; resolution?: string; runMode?: string; simulation?: boolean; @@ -622,7 +634,26 @@ export async function createImageGenerationTask( export async function createImageEditTask( token: string, - input: { model: string; prompt: string; image?: string; mask?: string; runMode?: string; simulation?: boolean }, + input: { + model: string; + prompt: string; + aspect_ratio?: string; + content?: Array>; + count?: number; + height?: number; + image?: string | string[]; + image_url?: string | string[]; + image_urls?: string[]; + images?: string[]; + mask?: string; + n?: number; + quality?: string; + resolution?: string; + runMode?: string; + simulation?: boolean; + size?: string; + width?: number; + }, ): Promise<{ task: GatewayTask; next: Record }> { return request<{ task: GatewayTask; next: Record }>('/api/v1/images/edits', { body: input, @@ -632,25 +663,83 @@ export async function createImageEditTask( }); } +export type VideoGenerationContentRole = + | 'first_frame' + | 'last_frame' + | 'reference_image' + | 'reference_video' + | 'reference_audio' + | 'digital_human_frame' + | 'reference' + | 'element' + | 'video_feature' + | 'video_base' + | 'shot_prompt'; + +export interface VideoGenerationContent { + type: 'text' | 'image_url' | 'audio_url' | 'video_url' | 'element'; + text?: string; + image_url?: { + url: string; + }; + video_url?: { + url: string; + refer_type?: 'feature' | 'base'; + keep_original_sound?: 'yes' | 'no'; + }; + audio_url?: { + url: string; + }; + role?: VideoGenerationContentRole; + shot_index?: number; + duration?: number; + name?: string; + element?: { + system_element_id?: string; + inline_element?: { + name: string; + description?: string; + frontal_image_url: string; + refer_images: Array<{ url: string; slot_key?: string }>; + tags?: string[]; + }; + }; +} + +export interface VideoGenerationParams { + content: VideoGenerationContent[]; + model: string; + aspect_ratio?: string; + resolution?: string; + duration?: number; + audio_list?: Array<{ + url?: string; + audio_url?: string; + name?: string; + }>; + audio?: boolean; + framespersecond?: number; + watermark?: boolean; + seed?: number; + camerafixed?: boolean; + camera_control?: string; + camera_control_strength?: number; + prompt_extend?: boolean; + size?: string; + task_id?: string; + conversation_id?: string; + histories?: string; + callback_url?: string; + prompt_optimizer?: boolean; + fast_pretreatment?: boolean; + mode?: 'std' | 'pro'; + negative_prompt?: string; + cfg_scale?: number; +} + export async function createVideoGenerationTask( token: string, - input: { - audio?: boolean; - model: string; - prompt: string; - aspect_ratio?: string; - count?: number; - duration?: number; - duration_seconds?: number; - height?: number; - n?: number; - output_audio?: boolean; - resolution?: string; - runMode?: string; - simulation?: boolean; - size?: string; - width?: number; - }, + input: VideoGenerationParams, ): Promise<{ task: GatewayTask; next: Record }> { return request<{ task: GatewayTask; next: Record }>('/api/v1/videos/generations', { body: input, @@ -660,11 +749,46 @@ export async function createVideoGenerationTask( }); } +export interface GatewayFileUploadResponse extends Record { + fileUrl?: string; + file_url?: string; + url?: string; +} + +export async function uploadFileToStorage( + token: string, + file: File, + source = 'ai-gateway-playground', +): Promise { + const form = new FormData(); + form.append('file', file); + form.append('source', source); + + const response = await fetch(`${API_BASE}/v1/files/upload`, { + body: form, + headers: { + Authorization: `Bearer ${token}`, + }, + method: 'POST', + }); + const body = await response.text(); + if (!response.ok) { + throw new GatewayApiError(parseErrorDetails(body, response.status, `Request failed: ${response.status}`)); + } + if (!body) return {}; + try { + const parsed = JSON.parse(body) as unknown; + return recordFromUnknown(parsed) ? (parsed as GatewayFileUploadResponse) : {}; + } catch { + return { url: body }; + } +} + export async function estimatePricing( token: string, input: Record, -): Promise<{ items: unknown[]; resolver: string }> { - return request<{ items: unknown[]; resolver: string }>('/api/v1/pricing/estimate', { +): Promise { + return request('/api/v1/pricing/estimate', { body: input, method: 'POST', token, @@ -758,6 +882,55 @@ export async function getNetworkProxyConfig(token: string): Promise('/api/admin/config/network-proxy', { token }); } +export async function listFileStorageChannels(token: string): Promise> { + return request>('/api/admin/system/file-storage/channels', { token }); +} + +export async function getFileStorageSettings(token: string): Promise { + return request('/api/admin/system/file-storage/settings', { token }); +} + +export async function updateFileStorageSettings( + token: string, + input: FileStorageSettingsUpdateRequest, +): Promise { + return request('/api/admin/system/file-storage/settings', { + body: input, + method: 'PATCH', + token, + }); +} + +export async function createFileStorageChannel( + token: string, + input: FileStorageChannelUpsertRequest, +): Promise { + return request('/api/admin/system/file-storage/channels', { + body: input, + method: 'POST', + token, + }); +} + +export async function updateFileStorageChannel( + token: string, + channelId: string, + input: FileStorageChannelUpsertRequest, +): Promise { + return request(`/api/admin/system/file-storage/channels/${channelId}`, { + body: input, + method: 'PATCH', + token, + }); +} + +export async function deleteFileStorageChannel(token: string, channelId: string): Promise { + await request(`/api/admin/system/file-storage/channels/${channelId}`, { + method: 'DELETE', + token, + }); +} + async function request( path: string, options: { token?: string; auth?: boolean; method?: string; body?: unknown; headers?: Record } = {}, diff --git a/apps/web/src/app-state.ts b/apps/web/src/app-state.ts index 83a567f..6b5e8bb 100644 --- a/apps/web/src/app-state.ts +++ b/apps/web/src/app-state.ts @@ -1,6 +1,8 @@ import type { BaseModelCatalogItem, CatalogProvider, + FileStorageChannel, + FileStorageSettings, GatewayAccessRule, GatewayApiKey, GatewayAuditLog, @@ -27,6 +29,8 @@ export interface ConsoleData { auditLogs: GatewayAuditLog[]; apiKeys: GatewayApiKey[]; baseModels: BaseModelCatalogItem[]; + fileStorageChannels: FileStorageChannel[]; + fileStorageSettings: FileStorageSettings | null; modelCatalog: ModelCatalogResponse; models: PlatformModel[]; networkProxyConfig: GatewayNetworkProxyConfig | null; diff --git a/apps/web/src/pages/AdminPage.tsx b/apps/web/src/pages/AdminPage.tsx index 0866033..3df5d1b 100644 --- a/apps/web/src/pages/AdminPage.tsx +++ b/apps/web/src/pages/AdminPage.tsx @@ -1,8 +1,10 @@ import type { ReactNode } from 'react'; -import { Boxes, Building2, Gauge, History, KeyRound, Route, ServerCog, ShieldCheck, UsersRound, Workflow } from 'lucide-react'; +import { Boxes, Building2, Gauge, History, KeyRound, Route, ServerCog, Settings, ShieldCheck, UsersRound, Workflow } from 'lucide-react'; import type { BaseModelUpsertRequest, CatalogProviderUpsertRequest, + FileStorageChannelUpsertRequest, + FileStorageSettingsUpdateRequest, GatewayAccessRuleBatchRequest, GatewayAccessRuleUpsertRequest, GatewayTenantUpsertRequest, @@ -29,6 +31,7 @@ import { PricingRulesPanel } from './admin/PricingRulesPanel'; import { ProviderManagementPanel } from './admin/ProviderManagementPanel'; import { RealtimeLoadPanel } from './admin/RealtimeLoadPanel'; import { RuntimePoliciesPanel } from './admin/RuntimePoliciesPanel'; +import { SystemSettingsPanel } from './admin/SystemSettingsPanel'; const tabs = [ { value: 'overview', label: '总览', icon: }, @@ -42,6 +45,7 @@ const tabs = [ { value: 'users', label: '用户', icon: }, { value: 'userGroups', label: '用户组', icon: }, { value: 'accessRules', label: '模型权限', icon: }, + { value: 'systemSettings', label: '系统设置', icon: }, { value: 'auditLogs', label: '审计日志', icon: }, ] satisfies Array<{ value: AdminSection; label: string; icon: ReactNode }>; @@ -57,6 +61,7 @@ export function AdminPage(props: { onDeletePricingRuleSet: (ruleSetId: string) => Promise; onDeleteRuntimePolicySet: (policySetId: string) => Promise; onDeleteAccessRule: (ruleId: string) => Promise; + onDeleteFileStorageChannel: (channelId: string) => Promise; onDeleteTenant: (tenantId: string) => Promise; onDeleteUser: (userId: string) => Promise; onDeleteUserGroup: (groupId: string) => Promise; @@ -72,6 +77,8 @@ export function AdminPage(props: { onSaveRunnerPolicy: (input: GatewayRunnerPolicyUpsertRequest) => Promise; onSaveRuntimePolicySet: (input: RuntimePolicySetUpsertRequest, policySetId?: string) => Promise; onSaveAccessRule: (input: GatewayAccessRuleUpsertRequest, ruleId?: string) => Promise; + onSaveFileStorageChannel: (input: FileStorageChannelUpsertRequest, channelId?: string) => Promise; + onSaveFileStorageSettings: (input: FileStorageSettingsUpdateRequest) => Promise; onSaveTenant: (input: GatewayTenantUpsertRequest, tenantId?: string) => Promise; onSaveUser: (input: GatewayUserUpsertRequest, userId?: string) => Promise; onSetUserWalletBalance: (userId: string, input: WalletBalanceAdjustmentRequest) => Promise; @@ -172,6 +179,17 @@ export function AdminPage(props: { {props.section === 'users' && } {props.section === 'userGroups' && } {props.section === 'auditLogs' && } + {props.section === 'systemSettings' && ( + + )} diff --git a/apps/web/src/pages/ApiDocsPage.tsx b/apps/web/src/pages/ApiDocsPage.tsx index 9fb423b..ed99f39 100644 --- a/apps/web/src/pages/ApiDocsPage.tsx +++ b/apps/web/src/pages/ApiDocsPage.tsx @@ -34,6 +34,7 @@ export function ApiDocsPage(props: { onTaskFormChange: (value: TaskForm) => void; }) { const current = docs.find((item) => item.key === props.activeDocSection) ?? docs[0]; + const isFileDoc = current.key === 'files'; const bodyExample = useMemo(() => requestBodyExample(props.taskForm), [props.taskForm]); function handleSubmit(event: FormEvent) { @@ -87,7 +88,7 @@ export function ApiDocsPage(props: {

Header 参数

- + @@ -97,10 +98,19 @@ export function ApiDocsPage(props: {

Body 参数

application/json - - - - + {isFileDoc ? ( + <> + + + + ) : ( + <> + + + + + + )} diff --git a/apps/web/src/pages/PlaygroundPage.tsx b/apps/web/src/pages/PlaygroundPage.tsx index 7fd456a..5a2b7e8 100644 --- a/apps/web/src/pages/PlaygroundPage.tsx +++ b/apps/web/src/pages/PlaygroundPage.tsx @@ -1,28 +1,15 @@ -import { useEffect, useMemo, useRef, useState, type ReactNode } from 'react'; -import { - AssistantRuntimeProvider, - ComposerPrimitive, - ErrorPrimitive, - MessagePrimitive, - ThreadPrimitive, - useMessage, - useMessagePartText, - useLocalRuntime, - useThread, - type ChatModelAdapter, - type ThreadMessage, - type ThreadMessageLike, -} from '@assistant-ui/react'; -import { StreamdownTextPrimitive } from '@assistant-ui/react-streamdown'; -import { cjk } from '@streamdown/cjk'; -import { code } from '@streamdown/code'; -import { math } from '@streamdown/math'; -import { mermaid } from '@streamdown/mermaid'; -import type { GatewayApiKey, GatewayTask, PlatformModel } from '@easyai-ai-gateway/contracts'; -import { Bot, ChevronDown, Image as ImageIcon, MessageSquarePlus, Paperclip, Send, Sparkles, Video } from 'lucide-react'; -import { Badge, Button, Select, Textarea } from '../components/ui'; -import { GatewayApiError, createImageGenerationTask, createVideoGenerationTask, pollTaskUntilSettled, streamChatCompletionText, taskIsPending } from '../api'; +import { useEffect, useMemo, useRef, useState } from 'react'; +import type { GatewayApiKey, GatewayPricingEstimate, GatewayTask, PlatformModel } from '@easyai-ai-gateway/contracts'; +import { ArrowUp, ChevronDown, MessageSquarePlus, Settings2, Sparkles } from 'lucide-react'; +import { Badge, Button, FormDialog, Select, Textarea } from '../components/ui'; +import { createImageEditTask, createImageGenerationTask, createVideoGenerationTask, estimatePricing, pollTaskUntilSettled, resolveApiAssetUrl, taskIsPending } from '../api'; import type { PlaygroundMode } from '../types'; +import { + PlaygroundPromptMentionInput, + buildPlaygroundResourceToken, + removeInvalidPlaygroundResourceTokens, + replacePlaygroundResourceTokens, +} from './playground-prompt-mention'; import { defaultMediaGenerationSettings, deriveMediaModelCapabilities, @@ -35,51 +22,47 @@ import { type MediaGenerationSettings, type MediaModelCapabilities, } from './playground-media'; - -type VideoCreateMode = 'text_to_video' | 'first_last_frame' | 'omni_reference'; +import { + ComposerUploadButton as SharedComposerUploadButton, + mediaUploadAccept as sharedMediaUploadAccept, + mediaUploadAcceptForMode as sharedMediaUploadAcceptForMode, + mediaUploadRequestPayload as sharedMediaUploadRequestPayload, + mediaUploadSummaryMessage as sharedMediaUploadSummaryMessage, + mergeMediaUploadsForMode as sharedMergeMediaUploadsForMode, + normalizeFirstLastFrameUploads as sharedNormalizeFirstLastFrameUploads, + PlaygroundReferencePicker, + swapFirstLastFrameUploads as sharedSwapFirstLastFrameUploads, + uploadPlaygroundFiles as sharedUploadPlaygroundFiles, + UploadAttachmentList as SharedUploadAttachmentList, + videoGenerationContentFromPromptAndUploads as sharedVideoGenerationContentFromPromptAndUploads, + allowedMediaUploadKinds as sharedAllowedMediaUploadKinds, + type PlaygroundUpload, + type PlaygroundUploadRole, +} from './playground-upload'; +import { AssistantChatPlayground, clearStoredChatMessages } from './playground-chat'; +import { + ApiKeySelect, + ModeSwitch, + PlaygroundGreeting, + apiKeyNoticeText, + modeOptions, + modelOptionLabel, + placeholderByMode, + resolveSelectedApiKeyId, + videoModeOptions, + type ModelOption, + type VideoCreateMode, +} from './playground-shared'; const MEDIA_RUNS_STORAGE_KEY = 'easyai:playground:media-runs:v1'; const MEDIA_RUNS_STORAGE_LIMIT = 50; -const CHAT_MESSAGES_STORAGE_KEY = 'easyai:playground:chat-messages:v1'; -const CHAT_MESSAGES_STORAGE_LIMIT = 100; -interface StoredChatMessage { - content: string; - createdAt: string; - id: string; - role: 'assistant' | 'user'; -} - -interface ModelOption { - count: number; - label: string; - models: PlatformModel[]; - provider: string; - value: string; -} - -const modeOptions: Array<{ description: string; icon: ReactNode; label: string; value: PlaygroundMode }> = [ - { value: 'chat', label: '大模型对话', description: '对话、推理、结构化输出', icon: }, - { value: 'image', label: '图像生成', description: '文生图、图像编辑参数预览', icon: }, - { value: 'video', label: '视频生成', description: '图生视频、文生视频任务测试', icon: