diff --git a/apps/api/docs/swagger.json b/apps/api/docs/swagger.json index bb534b9..adf0469 100644 --- a/apps/api/docs/swagger.json +++ b/apps/api/docs/swagger.json @@ -4455,7 +4455,8 @@ ], "description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。", "consumes": [ - "application/json" + "application/json", + "multipart/form-data" ], "produces": [ "application/json" @@ -6361,7 +6362,8 @@ ], "description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。", "consumes": [ - "application/json" + "application/json", + "multipart/form-data" ], "produces": [ "application/json" @@ -7399,7 +7401,8 @@ ], "description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。", "consumes": [ - "application/json" + "application/json", + "multipart/form-data" ], "produces": [ "application/json" @@ -11557,4 +11560,4 @@ "in": "header" } } -} \ No newline at end of file +} diff --git a/apps/api/docs/swagger.yaml b/apps/api/docs/swagger.yaml index af8bb84..709ee8a 100644 --- a/apps/api/docs/swagger.yaml +++ b/apps/api/docs/swagger.yaml @@ -5232,6 +5232,7 @@ paths: post: consumes: - application/json + - multipart/form-data description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。 parameters: @@ -6464,6 +6465,7 @@ paths: post: consumes: - application/json + - multipart/form-data description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。 parameters: @@ -7145,6 +7147,7 @@ paths: post: consumes: - application/json + - multipart/form-data description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果,OpenAI-compatible 路径同步返回兼容响应或 SSE 流。 parameters: diff --git a/apps/api/internal/clients/clients_test.go b/apps/api/internal/clients/clients_test.go index 623f0dc..5750624 100644 --- a/apps/api/internal/clients/clients_test.go +++ b/apps/api/internal/clients/clients_test.go @@ -992,7 +992,8 @@ func TestVolcesClientImageEditUsesGenerationEndpoint(t *testing.T) { var gotAuth string var gotModel string var gotImage string - var gotSequential string + var gotSequential any + var gotSequentialPresent bool server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotPath = r.URL.Path gotAuth = r.Header.Get("Authorization") @@ -1002,7 +1003,7 @@ func TestVolcesClientImageEditUsesGenerationEndpoint(t *testing.T) { } gotModel, _ = body["model"].(string) gotImage, _ = body["image"].(string) - gotSequential, _ = body["sequential_image_generation"].(string) + gotSequential, gotSequentialPresent = body["sequential_image_generation"] _ = json.NewEncoder(w).Encode(map[string]any{ "id": "img-volces-edit", "created": 123, @@ -1036,7 +1037,7 @@ func TestVolcesClientImageEditUsesGenerationEndpoint(t *testing.T) { if gotPath != "/images/generations" || gotAuth != "Bearer volces-key" { t.Fatalf("unexpected request path=%s auth=%s", gotPath, gotAuth) } - if gotModel != "doubao-seedream-4-0-250828" || gotImage != "https://example.com/source.png" || gotSequential != "auto" { + if gotModel != "doubao-seedream-4-0-250828" || gotImage != "https://example.com/source.png" || gotSequentialPresent { t.Fatalf("unexpected body model=%s image=%s sequential=%s", gotModel, gotImage, gotSequential) } if response.Result["id"] != "img-volces-edit" { @@ -1044,6 +1045,105 @@ func TestVolcesClientImageEditUsesGenerationEndpoint(t *testing.T) { } } +func TestVolcesClientImageEditEnablesSequentialForRequestedMultipleImages(t *testing.T) { + var gotSequential string + var gotMaxImages float64 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode request: %v", err) + } + gotSequential, _ = body["sequential_image_generation"].(string) + options, _ := body["sequential_image_generation_options"].(map[string]any) + gotMaxImages = numericValue(options["max_images"], 0) + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "img-volces-edit", + "data": []any{map[string]any{"url": "https://example.com/out.png"}}, + }) + })) + defer server.Close() + + _, err := (VolcesClient{HTTPClient: server.Client()}).Run(context.Background(), Request{ + Kind: "images.edits", + ModelType: "image_edit", + Body: map[string]any{ + "prompt": "make variants", + "image": "https://example.com/source.png", + "n": 3, + }, + Candidate: store.RuntimeModelCandidate{ + BaseURL: server.URL, + ProviderModelName: "doubao-seedream-4-0-250828", + Credentials: map[string]any{"apiKey": "volces-key"}, + Capabilities: map[string]any{ + "image_edit": map[string]any{ + "output_multiple_images": true, + "output_max_images_count": 4, + }, + }, + }, + }) + if err != nil { + t.Fatalf("run volces image edit: %v", err) + } + if gotSequential != "auto" || gotMaxImages != 3 { + t.Fatalf("unexpected sequential settings mode=%s max_images=%v", gotSequential, gotMaxImages) + } +} + +func TestVolcesClientImageEditPreservesExplicitSequentialDisabledAndClampsMaxImages(t *testing.T) { + var gotSequential string + var gotMaxImages float64 + var gotOtherOption string + options := map[string]any{"max_images": 99, "other": "keep"} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Fatalf("decode request: %v", err) + } + gotSequential, _ = body["sequential_image_generation"].(string) + gotOptions, _ := body["sequential_image_generation_options"].(map[string]any) + gotMaxImages = numericValue(gotOptions["max_images"], 0) + gotOtherOption, _ = gotOptions["other"].(string) + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "img-volces-edit", + "data": []any{map[string]any{"url": "https://example.com/out.png"}}, + }) + })) + defer server.Close() + + _, err := (VolcesClient{HTTPClient: server.Client()}).Run(context.Background(), Request{ + Kind: "images.edits", + ModelType: "image_edit", + Body: map[string]any{ + "prompt": "make variants", + "image": "https://example.com/source.png", + "sequential_image_generation": "disabled", + "sequential_image_generation_options": options, + }, + Candidate: store.RuntimeModelCandidate{ + BaseURL: server.URL, + ProviderModelName: "doubao-seedream-4-0-250828", + Credentials: map[string]any{"apiKey": "volces-key"}, + Capabilities: map[string]any{ + "image_edit": map[string]any{ + "output_multiple_images": true, + "output_max_images_count": 4, + }, + }, + }, + }) + if err != nil { + t.Fatalf("run volces image edit: %v", err) + } + if gotSequential != "disabled" || gotMaxImages != 4 || gotOtherOption != "keep" { + t.Fatalf("unexpected sequential settings mode=%s max_images=%v other=%s", gotSequential, gotMaxImages, gotOtherOption) + } + if numericValue(options["max_images"], 0) != 99 { + t.Fatalf("request options should not be mutated: %+v", options) + } +} + func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) { var submitPath string var pollPath string diff --git a/apps/api/internal/clients/volces.go b/apps/api/internal/clients/volces.go index b03c4d3..a739f2d 100644 --- a/apps/api/internal/clients/volces.go +++ b/apps/api/internal/clients/volces.go @@ -208,9 +208,7 @@ func volcesImageBody(request Request) map[string]any { if size := widthHeightSize(body); size != "" { body["size"] = size } - if supportsMultipleOutputs(request, request.ModelType) && body["sequential_image_generation"] == nil { - body["sequential_image_generation"] = "auto" - } + normalizeVolcesSequentialImageGeneration(body, request) return body } @@ -772,6 +770,119 @@ func supportsMultipleOutputs(request Request, capabilityName string) bool { return false } +func normalizeVolcesSequentialImageGeneration(body map[string]any, request Request) { + options, hasOptions := volcesSequentialImageGenerationOptions(body) + normalizeVolcesSequentialMaxImages(options, request) + if _, explicit := body["sequential_image_generation"]; explicit { + return + } + if !supportsMultipleOutputs(request, request.ModelType) { + return + } + count := requestedVolcesSequentialImageCount(body, options) + if count <= 1 { + return + } + body["sequential_image_generation"] = "auto" + if !hasOptions { + options = map[string]any{} + body["sequential_image_generation_options"] = options + } + if _, ok := options["max_images"]; !ok { + options["max_images"] = count + normalizeVolcesSequentialMaxImages(options, request) + } +} + +func volcesSequentialImageGenerationOptions(body map[string]any) (map[string]any, bool) { + options, ok := body["sequential_image_generation_options"].(map[string]any) + if !ok { + return nil, false + } + copied := map[string]any{} + for key, value := range options { + copied[key] = value + } + body["sequential_image_generation_options"] = copied + return copied, true +} + +func normalizeVolcesSequentialMaxImages(options map[string]any, request Request) { + if options == nil { + return + } + raw, exists := options["max_images"] + if !exists { + return + } + current := int(math.Round(numericValue(raw, 0))) + minCount, maxCount := volcesSequentialMaxImagesRange(request) + adjusted := current + if minCount > 0 && adjusted < minCount { + adjusted = minCount + } + if maxCount > 0 && adjusted > maxCount { + adjusted = maxCount + } + if adjusted != current { + options["max_images"] = adjusted + } +} + +func requestedVolcesSequentialImageCount(body map[string]any, options map[string]any) int { + for _, value := range []any{ + valueFromMap(options, "max_images"), + body["count"], + body["n"], + body["batch_size"], + } { + count := int(math.Round(numericValue(value, 0))) + if count > 0 { + return count + } + } + return 1 +} + +func valueFromMap(values map[string]any, key string) any { + if values == nil { + return nil + } + return values[key] +} + +func volcesSequentialMaxImagesRange(request Request) (int, int) { + minCount := 1 + maxCount := 0 + for _, key := range []string{request.ModelType, "image_generate", "image_edit"} { + if key == "" { + continue + } + capability, _ := request.Candidate.Capabilities[key].(map[string]any) + if capability == nil { + continue + } + if value := firstPositiveInt(capability, "output_min_images_count", "outputMinImagesCount", "min_output_images", "minOutputImages", "min_images", "minImages"); value > 0 { + minCount = value + } + if value := firstPositiveInt(capability, "output_max_images_count", "outputMaxImagesCount", "max_output_images", "maxOutputImages", "max_images", "maxImages"); value > 0 { + maxCount = value + } + return minCount, maxCount + } + return minCount, maxCount +} + +func firstPositiveInt(values map[string]any, keys ...string) int { + for _, key := range keys { + value := int(math.Round(numericValue(values[key], 0))) + if value > 0 { + return value + } + } + return 0 +} + func widthHeightSize(body map[string]any) string { width := numericValue(body["width"], 0) height := numericValue(body["height"], 0) diff --git a/apps/api/internal/httpapi/handlers.go b/apps/api/internal/httpapi/handlers.go index f19d21c..7a844d0 100644 --- a/apps/api/internal/httpapi/handlers.go +++ b/apps/api/internal/httpapi/handlers.go @@ -988,9 +988,13 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler { return } - var body map[string]any - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - writeError(w, http.StatusBadRequest, "invalid json body") + body, err := s.decodeTaskRequestBody(r.Context(), w, r, kind) + if err != nil { + status := http.StatusBadRequest + if code := clients.ErrorCode(err); strings.HasPrefix(code, "upload_") || code == "request_asset_upload_failed" || code == "request_asset_public_url_required" { + status = http.StatusBadGateway + } + writeError(w, status, err.Error(), clients.ErrorCode(err)) return } model, _ := body["model"].(string) diff --git a/apps/api/internal/httpapi/openapi_models.go b/apps/api/internal/httpapi/openapi_models.go index fb73380..f82ac91 100644 --- a/apps/api/internal/httpapi/openapi_models.go +++ b/apps/api/internal/httpapi/openapi_models.go @@ -237,15 +237,16 @@ type ImageGenerationRequest struct { } type ImageEditRequest struct { - Model string `json:"model" example:"gpt-image-1"` - Prompt string `json:"prompt" example:"Add a sunset background"` - Image string `json:"image,omitempty" example:"https://example.com/image.png"` - Mask string `json:"mask,omitempty" example:"https://example.com/mask.png"` - N int `json:"n,omitempty" example:"1"` - Size string `json:"size,omitempty" example:"1024x1024"` - Quality string `json:"quality,omitempty" example:"auto"` - ResponseFormat string `json:"response_format,omitempty" example:"url"` - RunMode string `json:"runMode,omitempty" example:"simulation"` + Model string `json:"model" example:"gpt-image-1"` + Prompt string `json:"prompt" example:"Add a sunset background"` + Image string `json:"image,omitempty" example:"https://example.com/image.png"` + Images []string `json:"images,omitempty" example:"https://example.com/image-a.png,https://example.com/image-b.png"` + Mask string `json:"mask,omitempty" example:"https://example.com/mask.png"` + N int `json:"n,omitempty" example:"1"` + Size string `json:"size,omitempty" example:"1024x1024"` + Quality string `json:"quality,omitempty" example:"auto"` + ResponseFormat string `json:"response_format,omitempty" example:"url"` + RunMode string `json:"runMode,omitempty" example:"simulation"` } type VideoGenerationRequest struct { diff --git a/apps/api/internal/httpapi/request_preparation.go b/apps/api/internal/httpapi/request_preparation.go index 22ddd06..ba01fe1 100644 --- a/apps/api/internal/httpapi/request_preparation.go +++ b/apps/api/internal/httpapi/request_preparation.go @@ -35,6 +35,12 @@ type decodedRequestAsset struct { ContentType string } +type requestAssetOptions struct { + RequirePublicURL bool + UploadScene string + Source string +} + func (s *Server) prepareTaskRequest(ctx context.Context, r *http.Request, user *auth.User, body map[string]any) (preparedTaskRequest, error) { preparedBody, err := s.prepareRequestAssetRefs(ctx, body) if err != nil { @@ -185,6 +191,21 @@ func requestAssetFromValue(key string, path []string, value any, siblings map[st } func (s *Server) ensureRequestAsset(ctx context.Context, decoded decodedRequestAsset) (map[string]any, error) { + return s.ensureRequestAssetWithOptions(ctx, decoded, requestAssetOptions{ + UploadScene: store.FileStorageSceneRequestAsset, + Source: "ai-gateway-request", + }) +} + +func (s *Server) ensurePublicRequestAsset(ctx context.Context, decoded decodedRequestAsset) (map[string]any, error) { + return s.ensureRequestAssetWithOptions(ctx, decoded, requestAssetOptions{ + RequirePublicURL: true, + UploadScene: store.FileStorageSceneUpload, + Source: "ai-gateway-form-data", + }) +} + +func (s *Server) ensureRequestAssetWithOptions(ctx context.Context, decoded decodedRequestAsset, options requestAssetOptions) (map[string]any, error) { sum := sha256.Sum256(decoded.Bytes) sha := hex.EncodeToString(sum[:]) contentType := strings.TrimSpace(decoded.ContentType) @@ -195,18 +216,29 @@ func (s *Server) ensureRequestAsset(ctx context.Context, decoded decodedRequestA if existing, ok, err := s.store.FindRequestAsset(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) { return nil, err } else if ok && requestAssetStillUsable(existing, now) { - if err := s.store.IncrementRequestAssetRefCount(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) { - return nil, err + ref := requestAssetRef(existing) + if !options.RequirePublicURL || requestAssetRefHasPublicURL(ref) { + if err := s.store.IncrementRequestAssetRefCount(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) { + return nil, err + } + return ref, nil } - return requestAssetRef(existing), nil } + uploadScene := strings.TrimSpace(options.UploadScene) + if uploadScene == "" { + uploadScene = store.FileStorageSceneRequestAsset + } + source := strings.TrimSpace(options.Source) + if source == "" { + source = "ai-gateway-request" + } upload, err := s.runner.UploadFile(ctx, runner.FileUploadPayload{ Bytes: decoded.Bytes, ContentType: contentType, FileName: requestAssetFileName(sha, contentType), - Scene: store.FileStorageSceneRequestAsset, - Source: "ai-gateway-request", + Scene: uploadScene, + Source: source, }) if err != nil { return nil, err @@ -216,6 +248,9 @@ func (s *Server) ensureRequestAsset(ctx context.Context, decoded decodedRequestA if url == "" { return nil, &clients.ClientError{Code: "request_asset_upload_failed", Message: "file storage response did not include url", Retryable: false} } + if options.RequirePublicURL && !requestAssetURLIsPublic(storageProvider, url) { + return nil, &clients.ClientError{Code: "request_asset_public_url_required", Message: "multipart image assets require a public file storage URL; enable a non-local file storage channel for uploads", Retryable: false} + } var expiresAt *time.Time localPath := "" if storageProvider == "local_static" { @@ -340,6 +375,18 @@ func requestAssetRef(asset store.RequestAsset) map[string]any { } } +func requestAssetRefHasPublicURL(ref map[string]any) bool { + return requestAssetURLIsPublic(stringFromRequestAny(ref["storageProvider"]), stringFromRequestAny(ref["url"])) +} + +func requestAssetURLIsPublic(storageProvider string, url string) bool { + if strings.EqualFold(strings.TrimSpace(storageProvider), "local_static") { + return false + } + lower := strings.ToLower(strings.TrimSpace(url)) + return strings.HasPrefix(lower, "http://") || strings.HasPrefix(lower, "https://") +} + func requestAssetStillUsable(asset store.RequestAsset, now time.Time) bool { if asset.ExpiredAt != nil { return false diff --git a/apps/api/internal/httpapi/request_preparation_test.go b/apps/api/internal/httpapi/request_preparation_test.go index dd8bda3..fbb5249 100644 --- a/apps/api/internal/httpapi/request_preparation_test.go +++ b/apps/api/internal/httpapi/request_preparation_test.go @@ -1,10 +1,12 @@ package httpapi import ( + "bytes" "context" "encoding/base64" "io" "log/slog" + "mime/multipart" "net/http" "net/http/httptest" "os" @@ -75,6 +77,87 @@ func TestCanonicalConversationMessageHashUsesTextAndAssetRefs(t *testing.T) { } } +func TestImageEditMultipartFormBodyMapsFilesAndFields(t *testing.T) { + var raw bytes.Buffer + writer := multipart.NewWriter(&raw) + if err := writer.WriteField("model", "doubao-5.0图像编辑"); err != nil { + t.Fatalf("write model field: %v", err) + } + if err := writer.WriteField("prompt", "换个姿势"); err != nil { + t.Fatalf("write prompt field: %v", err) + } + if err := writer.WriteField("n", "2"); err != nil { + t.Fatalf("write n field: %v", err) + } + if err := writer.WriteField("sequential_image_generation_options", `{"max_images":2}`); err != nil { + t.Fatalf("write sequential options field: %v", err) + } + writeMultipartFixtureFile(t, writer, "image", "single.png") + writeMultipartFixtureFile(t, writer, "images", "ref-a.png") + writeMultipartFixtureFile(t, writer, "images[]", "ref-b.png") + writeMultipartFixtureFile(t, writer, "mask", "mask.png") + if err := writer.Close(); err != nil { + t.Fatalf("close multipart writer: %v", err) + } + request := httptest.NewRequest(http.MethodPost, "/api/v1/images/edits", &raw) + request.Header.Set("Content-Type", writer.FormDataContentType()) + if err := request.ParseMultipartForm(multipartTaskMemoryBytes); err != nil { + t.Fatalf("parse multipart form: %v", err) + } + defer request.MultipartForm.RemoveAll() + + body, err := imageEditMultipartFormBody(context.Background(), request.MultipartForm, func(_ context.Context, field string, header *multipart.FileHeader) (map[string]any, error) { + ref := map[string]any{ + "sha256": field + "-" + header.Filename, + "url": "https://cdn.example/" + header.Filename, + "contentType": header.Header.Get("Content-Type"), + "storageProvider": "server_main_openapi", + } + return requestAssetWrapper(ref), nil + }) + if err != nil { + t.Fatalf("build multipart image edit body: %v", err) + } + if body["model"] != "doubao-5.0图像编辑" || body["prompt"] != "换个姿势" { + t.Fatalf("unexpected scalar fields: %+v", body) + } + if body["n"] != float64(2) { + t.Fatalf("n should be parsed as number, got %#v", body["n"]) + } + options, _ := body["sequential_image_generation_options"].(map[string]any) + if options["max_images"] != float64(2) { + t.Fatalf("sequential options should parse JSON object, got %+v", options) + } + image, _ := body["image"].(map[string]any) + if image["url"] != "https://cdn.example/single.png" { + t.Fatalf("single image should map to image URL wrapper, got %+v", image) + } + images, _ := body["images"].([]any) + if len(images) != 2 { + t.Fatalf("multi image fields should map to images array, got %+v", body["images"]) + } + firstMulti, _ := images[0].(map[string]any) + secondMulti, _ := images[1].(map[string]any) + if firstMulti["url"] != "https://cdn.example/ref-a.png" || secondMulti["url"] != "https://cdn.example/ref-b.png" { + t.Fatalf("unexpected images array: %+v", images) + } + mask, _ := body["mask"].(map[string]any) + if mask["url"] != "https://cdn.example/mask.png" { + t.Fatalf("mask should map to mask URL wrapper, got %+v", mask) + } +} + +func writeMultipartFixtureFile(t *testing.T, writer *multipart.Writer, field string, filename string) { + t.Helper() + part, err := writer.CreateFormFile(field, filename) + if err != nil { + t.Fatalf("create multipart file %s/%s: %v", field, filename, err) + } + if _, err := part.Write([]byte{0x89, 'P', 'N', 'G', '\r', '\n', 0x1a, '\n'}); err != nil { + t.Fatalf("write multipart file %s/%s: %v", field, filename, err) + } +} + func TestCleanupExpiredLocalTempAssetsDeletesExpiredStaticFiles(t *testing.T) { uploadedDir := t.TempDir() generatedDir := t.TempDir() diff --git a/apps/api/internal/httpapi/task_multipart.go b/apps/api/internal/httpapi/task_multipart.go new file mode 100644 index 0000000..d4e1949 --- /dev/null +++ b/apps/api/internal/httpapi/task_multipart.go @@ -0,0 +1,290 @@ +package httpapi + +import ( + "context" + "encoding/json" + "io" + "mime" + "mime/multipart" + "net/http" + "strconv" + "strings" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" +) + +const multipartTaskMemoryBytes = 32 << 20 + +type imageEditMultipartAssetUploader func(context.Context, string, *multipart.FileHeader) (map[string]any, error) + +func (s *Server) decodeTaskRequestBody(ctx context.Context, w http.ResponseWriter, r *http.Request, kind string) (map[string]any, error) { + if requestIsMultipartForm(r) { + if kind != "images.edits" { + return nil, &clients.ClientError{Code: "unsupported_multipart_body", Message: "multipart/form-data is only supported for image edit tasks", Retryable: false} + } + return s.decodeImageEditMultipartBody(ctx, w, r) + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + return nil, &clients.ClientError{Code: "invalid_json_body", Message: "invalid json body", Retryable: false} + } + if body == nil { + body = map[string]any{} + } + return body, nil +} + +func requestIsMultipartForm(r *http.Request) bool { + contentType := strings.TrimSpace(r.Header.Get("Content-Type")) + if contentType == "" { + return false + } + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return strings.HasPrefix(strings.ToLower(contentType), "multipart/form-data") + } + return strings.EqualFold(mediaType, "multipart/form-data") +} + +func (s *Server) decodeImageEditMultipartBody(ctx context.Context, w http.ResponseWriter, r *http.Request) (map[string]any, error) { + r.Body = http.MaxBytesReader(w, r.Body, maxGatewayUploadBytes) + if err := r.ParseMultipartForm(multipartTaskMemoryBytes); err != nil { + return nil, &clients.ClientError{Code: "invalid_multipart_body", Message: "invalid multipart form-data body", Retryable: false} + } + if r.MultipartForm == nil { + return map[string]any{}, nil + } + defer r.MultipartForm.RemoveAll() + return imageEditMultipartFormBody(ctx, r.MultipartForm, s.uploadImageEditMultipartAsset) +} + +func imageEditMultipartFormBody(ctx context.Context, form *multipart.Form, upload imageEditMultipartAssetUploader) (map[string]any, error) { + body := map[string]any{} + if form == nil { + return body, nil + } + for key, values := range form.Value { + addImageEditMultipartFieldValues(body, key, values) + } + if upload == nil { + return body, nil + } + if err := addImageEditMultipartFiles(ctx, body, form.File, upload); err != nil { + return nil, err + } + return body, nil +} + +func addImageEditMultipartFieldValues(body map[string]any, rawKey string, values []string) { + key := normalizeImageEditMultipartFieldName(rawKey) + parsed := make([]any, 0, len(values)) + for _, value := range values { + if strings.TrimSpace(value) == "" { + continue + } + parsed = append(parsed, parseImageEditMultipartFieldValue(key, value)) + } + if len(parsed) == 0 { + return + } + switch key { + case "image": + if len(parsed) == 1 { + body["image"] = parsed[0] + return + } + appendImageEditMultipartList(body, "images", parsed...) + case "images": + appendImageEditMultipartList(body, "images", flattenImageEditMultipartValues(parsed)...) + case "mask": + body["mask"] = parsed[0] + default: + if len(parsed) == 1 { + body[key] = parsed[0] + } else { + body[key] = parsed + } + } +} + +func normalizeImageEditMultipartFieldName(key string) string { + switch strings.TrimSpace(key) { + case "Image": + return "image" + case "images", "images[]", "image[]", "files": + return "images" + default: + return strings.TrimSpace(key) + } +} + +func parseImageEditMultipartFieldValue(key string, value string) any { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "" + } + if parsed, ok := parseImageEditMultipartJSONValue(trimmed); ok { + return parsed + } + if isImageEditMultipartBooleanField(key) { + if parsed, err := strconv.ParseBool(trimmed); err == nil { + return parsed + } + } + if isImageEditMultipartNumberField(key) { + if parsed, err := strconv.ParseFloat(trimmed, 64); err == nil { + return parsed + } + } + return trimmed +} + +func parseImageEditMultipartJSONValue(value string) (any, bool) { + if value == "" { + return nil, false + } + switch value[0] { + case '{', '[', '"': + default: + return nil, false + } + var parsed any + if err := json.Unmarshal([]byte(value), &parsed); err != nil { + return nil, false + } + return parsed, true +} + +func isImageEditMultipartBooleanField(key string) bool { + switch key { + case "stream", "simulation", "testMode", "test_mode", "watermark", "sync": + return true + default: + return false + } +} + +func isImageEditMultipartNumberField(key string) bool { + switch key { + case "n", "count", "width", "height", "seed", "batch_size", "batchSize", "simulationDurationMs", "simulation_duration_ms", "duration": + return true + default: + return false + } +} + +func addImageEditMultipartFiles(ctx context.Context, body map[string]any, files map[string][]*multipart.FileHeader, upload imageEditMultipartAssetUploader) error { + imageFiles := collectImageEditMultipartFiles(files, "image", "Image") + if len(imageFiles) == 1 { + value, err := upload(ctx, "image", imageFiles[0]) + if err != nil { + return err + } + body["image"] = value + } else if len(imageFiles) > 1 { + values, err := uploadImageEditMultipartFiles(ctx, "images", imageFiles, upload) + if err != nil { + return err + } + appendImageEditMultipartList(body, "images", values...) + } + multiImageFiles := collectImageEditMultipartFiles(files, "images", "images[]", "image[]", "files") + if len(multiImageFiles) > 0 { + values, err := uploadImageEditMultipartFiles(ctx, "images", multiImageFiles, upload) + if err != nil { + return err + } + appendImageEditMultipartList(body, "images", values...) + } + maskFiles := collectImageEditMultipartFiles(files, "mask") + if len(maskFiles) > 0 { + value, err := upload(ctx, "mask", maskFiles[0]) + if err != nil { + return err + } + body["mask"] = value + } + return nil +} + +func collectImageEditMultipartFiles(files map[string][]*multipart.FileHeader, keys ...string) []*multipart.FileHeader { + out := make([]*multipart.FileHeader, 0) + for _, key := range keys { + out = append(out, files[key]...) + } + return out +} + +func uploadImageEditMultipartFiles(ctx context.Context, field string, headers []*multipart.FileHeader, upload imageEditMultipartAssetUploader) ([]any, error) { + values := make([]any, 0, len(headers)) + for _, header := range headers { + value, err := upload(ctx, field, header) + if err != nil { + return nil, err + } + values = append(values, value) + } + return values, nil +} + +func (s *Server) uploadImageEditMultipartAsset(ctx context.Context, field string, header *multipart.FileHeader) (map[string]any, error) { + file, err := header.Open() + if err != nil { + return nil, &clients.ClientError{Code: "invalid_multipart_file", Message: err.Error(), Retryable: false} + } + defer file.Close() + payload, err := io.ReadAll(file) + if err != nil { + return nil, &clients.ClientError{Code: "invalid_multipart_file", Message: err.Error(), Retryable: false} + } + contentType := strings.TrimSpace(header.Header.Get("Content-Type")) + detectedContentType := "" + if len(payload) > 0 { + detectedContentType = http.DetectContentType(payload) + } + if contentType != "" && !strings.HasPrefix(strings.ToLower(contentType), "image/") && !strings.HasPrefix(strings.ToLower(detectedContentType), "image/") { + return nil, &clients.ClientError{Code: "invalid_multipart_image", Message: "image edit multipart files must be images", Retryable: false} + } + contentType = requestAssetContentType(contentType, payload, field, []string{field}, nil) + if !strings.HasPrefix(strings.ToLower(contentType), "image/") { + return nil, &clients.ClientError{Code: "invalid_multipart_image", Message: "image edit multipart files must be images", Retryable: false} + } + ref, err := s.ensurePublicRequestAsset(ctx, decodedRequestAsset{ + Bytes: payload, + ContentType: contentType, + }) + if err != nil { + return nil, err + } + return requestAssetWrapper(ref), nil +} + +func appendImageEditMultipartList(body map[string]any, key string, values ...any) { + list := flattenImageEditMultipartValues([]any{body[key]}) + list = append(list, flattenImageEditMultipartValues(values)...) + if len(list) == 0 { + return + } + body[key] = list +} + +func flattenImageEditMultipartValues(values []any) []any { + out := make([]any, 0, len(values)) + for _, value := range values { + switch typed := value.(type) { + case nil: + continue + case []any: + out = append(out, flattenImageEditMultipartValues(typed)...) + case []string: + for _, item := range typed { + if text := strings.TrimSpace(item); text != "" { + out = append(out, text) + } + } + default: + out = append(out, value) + } + } + return out +}