Add multipart image edit support

This commit is contained in:
wangbo 2026-06-08 01:17:42 +08:00
parent b7500d81d1
commit 679bfeb9c9
9 changed files with 669 additions and 27 deletions

View File

@ -4455,7 +4455,8 @@
], ],
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果OpenAI-compatible 路径同步返回兼容响应或 SSE 流。", "description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
"consumes": [ "consumes": [
"application/json" "application/json",
"multipart/form-data"
], ],
"produces": [ "produces": [
"application/json" "application/json"
@ -6361,7 +6362,8 @@
], ],
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果OpenAI-compatible 路径同步返回兼容响应或 SSE 流。", "description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
"consumes": [ "consumes": [
"application/json" "application/json",
"multipart/form-data"
], ],
"produces": [ "produces": [
"application/json" "application/json"
@ -7399,7 +7401,8 @@
], ],
"description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果OpenAI-compatible 路径同步返回兼容响应或 SSE 流。", "description": "网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果OpenAI-compatible 路径同步返回兼容响应或 SSE 流。",
"consumes": [ "consumes": [
"application/json" "application/json",
"multipart/form-data"
], ],
"produces": [ "produces": [
"application/json" "application/json"
@ -11557,4 +11560,4 @@
"in": "header" "in": "header"
} }
} }
} }

View File

@ -5232,6 +5232,7 @@ paths:
post: post:
consumes: consumes:
- application/json - application/json
- multipart/form-data
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果OpenAI-compatible description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果OpenAI-compatible
路径同步返回兼容响应或 SSE 流。 路径同步返回兼容响应或 SSE 流。
parameters: parameters:
@ -6464,6 +6465,7 @@ paths:
post: post:
consumes: consumes:
- application/json - application/json
- multipart/form-data
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果OpenAI-compatible description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果OpenAI-compatible
路径同步返回兼容响应或 SSE 流。 路径同步返回兼容响应或 SSE 流。
parameters: parameters:
@ -7145,6 +7147,7 @@ paths:
post: post:
consumes: consumes:
- application/json - application/json
- multipart/form-data
description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果OpenAI-compatible description: 网关任务接口按 model 选择平台模型;除 /api/v1/chat/completions 以外的 /api/v1 任务路径返回任务受理结果OpenAI-compatible
路径同步返回兼容响应或 SSE 流。 路径同步返回兼容响应或 SSE 流。
parameters: parameters:

View File

@ -992,7 +992,8 @@ func TestVolcesClientImageEditUsesGenerationEndpoint(t *testing.T) {
var gotAuth string var gotAuth string
var gotModel string var gotModel string
var gotImage string var gotImage string
var gotSequential string var gotSequential any
var gotSequentialPresent bool
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path gotPath = r.URL.Path
gotAuth = r.Header.Get("Authorization") gotAuth = r.Header.Get("Authorization")
@ -1002,7 +1003,7 @@ func TestVolcesClientImageEditUsesGenerationEndpoint(t *testing.T) {
} }
gotModel, _ = body["model"].(string) gotModel, _ = body["model"].(string)
gotImage, _ = body["image"].(string) gotImage, _ = body["image"].(string)
gotSequential, _ = body["sequential_image_generation"].(string) gotSequential, gotSequentialPresent = body["sequential_image_generation"]
_ = json.NewEncoder(w).Encode(map[string]any{ _ = json.NewEncoder(w).Encode(map[string]any{
"id": "img-volces-edit", "id": "img-volces-edit",
"created": 123, "created": 123,
@ -1036,7 +1037,7 @@ func TestVolcesClientImageEditUsesGenerationEndpoint(t *testing.T) {
if gotPath != "/images/generations" || gotAuth != "Bearer volces-key" { if gotPath != "/images/generations" || gotAuth != "Bearer volces-key" {
t.Fatalf("unexpected request path=%s auth=%s", gotPath, gotAuth) t.Fatalf("unexpected request path=%s auth=%s", gotPath, gotAuth)
} }
if gotModel != "doubao-seedream-4-0-250828" || gotImage != "https://example.com/source.png" || gotSequential != "auto" { if gotModel != "doubao-seedream-4-0-250828" || gotImage != "https://example.com/source.png" || gotSequentialPresent {
t.Fatalf("unexpected body model=%s image=%s sequential=%s", gotModel, gotImage, gotSequential) t.Fatalf("unexpected body model=%s image=%s sequential=%s", gotModel, gotImage, gotSequential)
} }
if response.Result["id"] != "img-volces-edit" { if response.Result["id"] != "img-volces-edit" {
@ -1044,6 +1045,105 @@ func TestVolcesClientImageEditUsesGenerationEndpoint(t *testing.T) {
} }
} }
// TestVolcesClientImageEditEnablesSequentialForRequestedMultipleImages checks
// that an image edit request carrying n=3, sent to a candidate whose image_edit
// capability declares output_multiple_images with a maximum of 4, reaches the
// upstream endpoint with sequential_image_generation="auto" and
// sequential_image_generation_options.max_images=3.
func TestVolcesClientImageEditEnablesSequentialForRequestedMultipleImages(t *testing.T) {
	var gotSequential string
	var gotMaxImages float64
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			// This handler runs on the test server's goroutine, where
			// t.Fatalf must not be called. Record the failure and answer
			// with an error status so the client call fails visibly.
			t.Errorf("decode request: %v", err)
			http.Error(w, "invalid request body", http.StatusBadRequest)
			return
		}
		gotSequential, _ = body["sequential_image_generation"].(string)
		options, _ := body["sequential_image_generation_options"].(map[string]any)
		gotMaxImages = numericValue(options["max_images"], 0)
		_ = json.NewEncoder(w).Encode(map[string]any{
			"id":   "img-volces-edit",
			"data": []any{map[string]any{"url": "https://example.com/out.png"}},
		})
	}))
	defer server.Close()

	_, err := (VolcesClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
		Kind:      "images.edits",
		ModelType: "image_edit",
		Body: map[string]any{
			"prompt": "make variants",
			"image":  "https://example.com/source.png",
			"n":      3,
		},
		Candidate: store.RuntimeModelCandidate{
			BaseURL:           server.URL,
			ProviderModelName: "doubao-seedream-4-0-250828",
			Credentials:       map[string]any{"apiKey": "volces-key"},
			Capabilities: map[string]any{
				"image_edit": map[string]any{
					"output_multiple_images":  true,
					"output_max_images_count": 4,
				},
			},
		},
	})
	if err != nil {
		t.Fatalf("run volces image edit: %v", err)
	}
	if gotSequential != "auto" || gotMaxImages != 3 {
		t.Fatalf("unexpected sequential settings mode=%s max_images=%v", gotSequential, gotMaxImages)
	}
}
// TestVolcesClientImageEditPreservesExplicitSequentialDisabledAndClampsMaxImages
// checks three things for a request that sets sequential_image_generation
// itself: the explicit "disabled" mode is forwarded untouched, max_images=99 is
// reduced to the capability's output_max_images_count of 4 while unrelated
// option keys survive, and the caller's own options map is left unmodified.
func TestVolcesClientImageEditPreservesExplicitSequentialDisabledAndClampsMaxImages(t *testing.T) {
	var gotSequential string
	var gotMaxImages float64
	var gotOtherOption string
	options := map[string]any{"max_images": 99, "other": "keep"}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			// This handler runs on the test server's goroutine, where
			// t.Fatalf must not be called. Record the failure and answer
			// with an error status so the client call fails visibly.
			t.Errorf("decode request: %v", err)
			http.Error(w, "invalid request body", http.StatusBadRequest)
			return
		}
		gotSequential, _ = body["sequential_image_generation"].(string)
		gotOptions, _ := body["sequential_image_generation_options"].(map[string]any)
		gotMaxImages = numericValue(gotOptions["max_images"], 0)
		gotOtherOption, _ = gotOptions["other"].(string)
		_ = json.NewEncoder(w).Encode(map[string]any{
			"id":   "img-volces-edit",
			"data": []any{map[string]any{"url": "https://example.com/out.png"}},
		})
	}))
	defer server.Close()

	_, err := (VolcesClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
		Kind:      "images.edits",
		ModelType: "image_edit",
		Body: map[string]any{
			"prompt":                              "make variants",
			"image":                               "https://example.com/source.png",
			"sequential_image_generation":         "disabled",
			"sequential_image_generation_options": options,
		},
		Candidate: store.RuntimeModelCandidate{
			BaseURL:           server.URL,
			ProviderModelName: "doubao-seedream-4-0-250828",
			Credentials:       map[string]any{"apiKey": "volces-key"},
			Capabilities: map[string]any{
				"image_edit": map[string]any{
					"output_multiple_images":  true,
					"output_max_images_count": 4,
				},
			},
		},
	})
	if err != nil {
		t.Fatalf("run volces image edit: %v", err)
	}
	if gotSequential != "disabled" || gotMaxImages != 4 || gotOtherOption != "keep" {
		t.Fatalf("unexpected sequential settings mode=%s max_images=%v other=%s", gotSequential, gotMaxImages, gotOtherOption)
	}
	if numericValue(options["max_images"], 0) != 99 {
		t.Fatalf("request options should not be mutated: %+v", options)
	}
}
func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) { func TestVolcesClientVideoSubmitsAndPollsTask(t *testing.T) {
var submitPath string var submitPath string
var pollPath string var pollPath string

View File

@ -208,9 +208,7 @@ func volcesImageBody(request Request) map[string]any {
if size := widthHeightSize(body); size != "" { if size := widthHeightSize(body); size != "" {
body["size"] = size body["size"] = size
} }
if supportsMultipleOutputs(request, request.ModelType) && body["sequential_image_generation"] == nil { normalizeVolcesSequentialImageGeneration(body, request)
body["sequential_image_generation"] = "auto"
}
return body return body
} }
@ -772,6 +770,119 @@ func supportsMultipleOutputs(request Request, capabilityName string) bool {
return false return false
} }
// normalizeVolcesSequentialImageGeneration adjusts the sequential image
// generation fields of body in place.
//
// Any max_images already present in sequential_image_generation_options is
// always passed through normalizeVolcesSequentialMaxImages. The mode itself is
// only filled in when the body has no sequential_image_generation key,
// supportsMultipleOutputs reports true for the request's model type, and more
// than one image was requested; in that case the mode becomes "auto" and
// max_images is set from the requested count unless the caller supplied one.
func normalizeVolcesSequentialImageGeneration(body map[string]any, request Request) {
	opts, optsPresent := volcesSequentialImageGenerationOptions(body)
	normalizeVolcesSequentialMaxImages(opts, request)

	// A caller-provided mode (whatever its value) is never overridden.
	_, callerChoseMode := body["sequential_image_generation"]
	if callerChoseMode || !supportsMultipleOutputs(request, request.ModelType) {
		return
	}
	wanted := requestedVolcesSequentialImageCount(body, opts)
	if wanted <= 1 {
		return
	}

	body["sequential_image_generation"] = "auto"
	if !optsPresent {
		opts = map[string]any{}
		body["sequential_image_generation_options"] = opts
	}
	if _, alreadySet := opts["max_images"]; alreadySet {
		return
	}
	opts["max_images"] = wanted
	normalizeVolcesSequentialMaxImages(opts, request)
}
// volcesSequentialImageGenerationOptions returns a shallow copy of
// body["sequential_image_generation_options"] and stores that copy back into
// body, so later edits do not touch the map the caller passed in.
//
// The boolean is false (and the map nil) when the key is absent or its value is
// not a map[string]any; body is left untouched in that case.
func volcesSequentialImageGenerationOptions(body map[string]any) (map[string]any, bool) {
	const optionsKey = "sequential_image_generation_options"

	original, isMap := body[optionsKey].(map[string]any)
	if !isMap {
		return nil, false
	}
	clone := make(map[string]any, len(original))
	for name := range original {
		clone[name] = original[name]
	}
	body[optionsKey] = clone
	return clone, true
}
// normalizeVolcesSequentialMaxImages clamps options["max_images"] into the
// range reported by volcesSequentialMaxImagesRange for the request.
//
// Nothing happens when options is nil or has no max_images key. The stored
// value is read through numericValue and rounded to the nearest integer; it is
// only rewritten when clamping actually changes that rounded value. A bound of
// zero or less is treated as "no bound".
func normalizeVolcesSequentialMaxImages(options map[string]any, request Request) {
	// Indexing a nil map is safe and simply reports the key as missing.
	raw, present := options["max_images"]
	if !present {
		return
	}
	requested := int(math.Round(numericValue(raw, 0)))
	lower, upper := volcesSequentialMaxImagesRange(request)

	// Apply the lower bound first and the upper bound second, so the upper
	// bound wins if the two ever conflict.
	clamped := requested
	if lower > 0 && requested < lower {
		clamped = lower
	}
	if upper > 0 && clamped > upper {
		clamped = upper
	}
	if clamped == requested {
		return
	}
	options["max_images"] = clamped
}
// requestedVolcesSequentialImageCount returns how many images the request asks
// for. It checks, in order, options["max_images"], body["count"], body["n"]
// and body["batch_size"], and returns the first one that rounds to a positive
// integer. When none does, it returns 1.
func requestedVolcesSequentialImageCount(body map[string]any, options map[string]any) int {
	candidates := [...]any{
		valueFromMap(options, "max_images"),
		body["count"],
		body["n"],
		body["batch_size"],
	}
	for i := range candidates {
		if requested := int(math.Round(numericValue(candidates[i], 0))); requested > 0 {
			return requested
		}
	}
	return 1
}
// valueFromMap returns values[key], or nil when values is nil or has no such
// key.
func valueFromMap(values map[string]any, key string) any {
	// Indexing a nil map yields the zero value (nil for any), so no explicit
	// nil guard is required.
	return values[key]
}
// volcesSequentialMaxImagesRange returns the (minimum, maximum) number of
// output images allowed for the request's candidate.
//
// Capabilities are looked up under request.ModelType, then "image_generate",
// then "image_edit"; only the first entry that is a non-nil map is consulted.
// The defaults are a minimum of 1 and a maximum of 0, where 0 means no upper
// bound was found.
func volcesSequentialMaxImagesRange(request Request) (int, int) {
	lower, upper := 1, 0
	lookupOrder := [...]string{request.ModelType, "image_generate", "image_edit"}
	for _, name := range lookupOrder {
		if name == "" {
			continue
		}
		capability, isMap := request.Candidate.Capabilities[name].(map[string]any)
		if !isMap || capability == nil {
			continue
		}
		if found := firstPositiveInt(capability, "output_min_images_count", "outputMinImagesCount", "min_output_images", "minOutputImages", "min_images", "minImages"); found > 0 {
			lower = found
		}
		if found := firstPositiveInt(capability, "output_max_images_count", "outputMaxImagesCount", "max_output_images", "maxOutputImages", "max_images", "maxImages"); found > 0 {
			upper = found
		}
		// The first usable capability entry decides the range.
		break
	}
	return lower, upper
}
// firstPositiveInt scans keys in order and returns the first values[key] that
// rounds (via numericValue and math.Round) to a positive integer. It returns 0
// when no key qualifies.
func firstPositiveInt(values map[string]any, keys ...string) int {
	for i := 0; i < len(keys); i++ {
		if candidate := int(math.Round(numericValue(values[keys[i]], 0))); candidate > 0 {
			return candidate
		}
	}
	return 0
}
func widthHeightSize(body map[string]any) string { func widthHeightSize(body map[string]any) string {
width := numericValue(body["width"], 0) width := numericValue(body["width"], 0)
height := numericValue(body["height"], 0) height := numericValue(body["height"], 0)

View File

@ -988,9 +988,13 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
return return
} }
var body map[string]any body, err := s.decodeTaskRequestBody(r.Context(), w, r, kind)
if err := json.NewDecoder(r.Body).Decode(&body); err != nil { if err != nil {
writeError(w, http.StatusBadRequest, "invalid json body") status := http.StatusBadRequest
if code := clients.ErrorCode(err); strings.HasPrefix(code, "upload_") || code == "request_asset_upload_failed" || code == "request_asset_public_url_required" {
status = http.StatusBadGateway
}
writeError(w, status, err.Error(), clients.ErrorCode(err))
return return
} }
model, _ := body["model"].(string) model, _ := body["model"].(string)

View File

@ -237,15 +237,16 @@ type ImageGenerationRequest struct {
} }
type ImageEditRequest struct { type ImageEditRequest struct {
Model string `json:"model" example:"gpt-image-1"` Model string `json:"model" example:"gpt-image-1"`
Prompt string `json:"prompt" example:"Add a sunset background"` Prompt string `json:"prompt" example:"Add a sunset background"`
Image string `json:"image,omitempty" example:"https://example.com/image.png"` Image string `json:"image,omitempty" example:"https://example.com/image.png"`
Mask string `json:"mask,omitempty" example:"https://example.com/mask.png"` Images []string `json:"images,omitempty" example:"https://example.com/image-a.png,https://example.com/image-b.png"`
N int `json:"n,omitempty" example:"1"` Mask string `json:"mask,omitempty" example:"https://example.com/mask.png"`
Size string `json:"size,omitempty" example:"1024x1024"` N int `json:"n,omitempty" example:"1"`
Quality string `json:"quality,omitempty" example:"auto"` Size string `json:"size,omitempty" example:"1024x1024"`
ResponseFormat string `json:"response_format,omitempty" example:"url"` Quality string `json:"quality,omitempty" example:"auto"`
RunMode string `json:"runMode,omitempty" example:"simulation"` ResponseFormat string `json:"response_format,omitempty" example:"url"`
RunMode string `json:"runMode,omitempty" example:"simulation"`
} }
type VideoGenerationRequest struct { type VideoGenerationRequest struct {

View File

@ -35,6 +35,12 @@ type decodedRequestAsset struct {
ContentType string ContentType string
} }
// requestAssetOptions tunes how ensureRequestAssetWithOptions stores a decoded
// request asset.
type requestAssetOptions struct {
// RequirePublicURL, when true, makes the asset usable only if
// requestAssetURLIsPublic accepts its storage provider and URL; otherwise
// a "request_asset_public_url_required" error is returned after upload.
RequirePublicURL bool
// UploadScene is sent as the Scene of the file upload. A blank value falls
// back to store.FileStorageSceneRequestAsset.
UploadScene string
// Source is sent as the Source of the file upload. A blank value falls back
// to "ai-gateway-request".
Source string
}
func (s *Server) prepareTaskRequest(ctx context.Context, r *http.Request, user *auth.User, body map[string]any) (preparedTaskRequest, error) { func (s *Server) prepareTaskRequest(ctx context.Context, r *http.Request, user *auth.User, body map[string]any) (preparedTaskRequest, error) {
preparedBody, err := s.prepareRequestAssetRefs(ctx, body) preparedBody, err := s.prepareRequestAssetRefs(ctx, body)
if err != nil { if err != nil {
@ -185,6 +191,21 @@ func requestAssetFromValue(key string, path []string, value any, siblings map[st
} }
func (s *Server) ensureRequestAsset(ctx context.Context, decoded decodedRequestAsset) (map[string]any, error) { func (s *Server) ensureRequestAsset(ctx context.Context, decoded decodedRequestAsset) (map[string]any, error) {
return s.ensureRequestAssetWithOptions(ctx, decoded, requestAssetOptions{
UploadScene: store.FileStorageSceneRequestAsset,
Source: "ai-gateway-request",
})
}
// ensurePublicRequestAsset stores decoded through
// ensureRequestAssetWithOptions with a public URL required, using the
// store.FileStorageSceneUpload scene and the "ai-gateway-form-data" source.
func (s *Server) ensurePublicRequestAsset(ctx context.Context, decoded decodedRequestAsset) (map[string]any, error) {
	publicUploadOptions := requestAssetOptions{
		Source:           "ai-gateway-form-data",
		UploadScene:      store.FileStorageSceneUpload,
		RequirePublicURL: true,
	}
	return s.ensureRequestAssetWithOptions(ctx, decoded, publicUploadOptions)
}
func (s *Server) ensureRequestAssetWithOptions(ctx context.Context, decoded decodedRequestAsset, options requestAssetOptions) (map[string]any, error) {
sum := sha256.Sum256(decoded.Bytes) sum := sha256.Sum256(decoded.Bytes)
sha := hex.EncodeToString(sum[:]) sha := hex.EncodeToString(sum[:])
contentType := strings.TrimSpace(decoded.ContentType) contentType := strings.TrimSpace(decoded.ContentType)
@ -195,18 +216,29 @@ func (s *Server) ensureRequestAsset(ctx context.Context, decoded decodedRequestA
if existing, ok, err := s.store.FindRequestAsset(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) { if existing, ok, err := s.store.FindRequestAsset(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) {
return nil, err return nil, err
} else if ok && requestAssetStillUsable(existing, now) { } else if ok && requestAssetStillUsable(existing, now) {
if err := s.store.IncrementRequestAssetRefCount(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) { ref := requestAssetRef(existing)
return nil, err if !options.RequirePublicURL || requestAssetRefHasPublicURL(ref) {
if err := s.store.IncrementRequestAssetRefCount(ctx, sha, contentType); err != nil && !store.IsUndefinedDatabaseObject(err) {
return nil, err
}
return ref, nil
} }
return requestAssetRef(existing), nil
} }
uploadScene := strings.TrimSpace(options.UploadScene)
if uploadScene == "" {
uploadScene = store.FileStorageSceneRequestAsset
}
source := strings.TrimSpace(options.Source)
if source == "" {
source = "ai-gateway-request"
}
upload, err := s.runner.UploadFile(ctx, runner.FileUploadPayload{ upload, err := s.runner.UploadFile(ctx, runner.FileUploadPayload{
Bytes: decoded.Bytes, Bytes: decoded.Bytes,
ContentType: contentType, ContentType: contentType,
FileName: requestAssetFileName(sha, contentType), FileName: requestAssetFileName(sha, contentType),
Scene: store.FileStorageSceneRequestAsset, Scene: uploadScene,
Source: "ai-gateway-request", Source: source,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -216,6 +248,9 @@ func (s *Server) ensureRequestAsset(ctx context.Context, decoded decodedRequestA
if url == "" { if url == "" {
return nil, &clients.ClientError{Code: "request_asset_upload_failed", Message: "file storage response did not include url", Retryable: false} return nil, &clients.ClientError{Code: "request_asset_upload_failed", Message: "file storage response did not include url", Retryable: false}
} }
if options.RequirePublicURL && !requestAssetURLIsPublic(storageProvider, url) {
return nil, &clients.ClientError{Code: "request_asset_public_url_required", Message: "multipart image assets require a public file storage URL; enable a non-local file storage channel for uploads", Retryable: false}
}
var expiresAt *time.Time var expiresAt *time.Time
localPath := "" localPath := ""
if storageProvider == "local_static" { if storageProvider == "local_static" {
@ -340,6 +375,18 @@ func requestAssetRef(asset store.RequestAsset) map[string]any {
} }
} }
func requestAssetRefHasPublicURL(ref map[string]any) bool {
return requestAssetURLIsPublic(stringFromRequestAny(ref["storageProvider"]), stringFromRequestAny(ref["url"]))
}
// requestAssetURLIsPublic reports whether url can be treated as a public asset
// location. It is false for the "local_static" storage provider (compared
// case-insensitively after trimming) and otherwise true only when the trimmed
// URL starts with "http://" or "https://" in any letter case.
func requestAssetURLIsPublic(storageProvider string, url string) bool {
	provider := strings.TrimSpace(storageProvider)
	if strings.EqualFold(provider, "local_static") {
		return false
	}
	normalized := strings.ToLower(strings.TrimSpace(url))
	for _, scheme := range [...]string{"http://", "https://"} {
		if strings.HasPrefix(normalized, scheme) {
			return true
		}
	}
	return false
}
func requestAssetStillUsable(asset store.RequestAsset, now time.Time) bool { func requestAssetStillUsable(asset store.RequestAsset, now time.Time) bool {
if asset.ExpiredAt != nil { if asset.ExpiredAt != nil {
return false return false

View File

@ -1,10 +1,12 @@
package httpapi package httpapi
import ( import (
"bytes"
"context" "context"
"encoding/base64" "encoding/base64"
"io" "io"
"log/slog" "log/slog"
"mime/multipart"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@ -75,6 +77,87 @@ func TestCanonicalConversationMessageHashUsesTextAndAssetRefs(t *testing.T) {
} }
} }
// TestImageEditMultipartFormBodyMapsFilesAndFields builds a multipart form with
// scalar fields, a JSON-valued field and four file parts, parses it, and checks
// how imageEditMultipartFormBody maps them into the task body. The uploader is
// a stub, so no file storage is involved.
func TestImageEditMultipartFormBodyMapsFilesAndFields(t *testing.T) {
var raw bytes.Buffer
writer := multipart.NewWriter(&raw)
// Text fields: two plain strings, one numeric field and one JSON object.
if err := writer.WriteField("model", "doubao-5.0图像编辑"); err != nil {
t.Fatalf("write model field: %v", err)
}
if err := writer.WriteField("prompt", "换个姿势"); err != nil {
t.Fatalf("write prompt field: %v", err)
}
if err := writer.WriteField("n", "2"); err != nil {
t.Fatalf("write n field: %v", err)
}
if err := writer.WriteField("sequential_image_generation_options", `{"max_images":2}`); err != nil {
t.Fatalf("write sequential options field: %v", err)
}
// File parts: one single image, two multi-image aliases and a mask.
writeMultipartFixtureFile(t, writer, "image", "single.png")
writeMultipartFixtureFile(t, writer, "images", "ref-a.png")
writeMultipartFixtureFile(t, writer, "images[]", "ref-b.png")
writeMultipartFixtureFile(t, writer, "mask", "mask.png")
if err := writer.Close(); err != nil {
t.Fatalf("close multipart writer: %v", err)
}
request := httptest.NewRequest(http.MethodPost, "/api/v1/images/edits", &raw)
request.Header.Set("Content-Type", writer.FormDataContentType())
if err := request.ParseMultipartForm(multipartTaskMemoryBytes); err != nil {
t.Fatalf("parse multipart form: %v", err)
}
defer request.MultipartForm.RemoveAll()
// Stub uploader: derives a fake CDN URL from the uploaded file name.
body, err := imageEditMultipartFormBody(context.Background(), request.MultipartForm, func(_ context.Context, field string, header *multipart.FileHeader) (map[string]any, error) {
ref := map[string]any{
"sha256": field + "-" + header.Filename,
"url": "https://cdn.example/" + header.Filename,
"contentType": header.Header.Get("Content-Type"),
"storageProvider": "server_main_openapi",
}
return requestAssetWrapper(ref), nil
})
if err != nil {
t.Fatalf("build multipart image edit body: %v", err)
}
if body["model"] != "doubao-5.0图像编辑" || body["prompt"] != "换个姿势" {
t.Fatalf("unexpected scalar fields: %+v", body)
}
// "n" is expected as a float64 rather than the raw form string.
if body["n"] != float64(2) {
t.Fatalf("n should be parsed as number, got %#v", body["n"])
}
options, _ := body["sequential_image_generation_options"].(map[string]any)
if options["max_images"] != float64(2) {
t.Fatalf("sequential options should parse JSON object, got %+v", options)
}
image, _ := body["image"].(map[string]any)
if image["url"] != "https://cdn.example/single.png" {
t.Fatalf("single image should map to image URL wrapper, got %+v", image)
}
// Files sent as "images" and "images[]" are expected in one "images" list,
// in that order.
images, _ := body["images"].([]any)
if len(images) != 2 {
t.Fatalf("multi image fields should map to images array, got %+v", body["images"])
}
firstMulti, _ := images[0].(map[string]any)
secondMulti, _ := images[1].(map[string]any)
if firstMulti["url"] != "https://cdn.example/ref-a.png" || secondMulti["url"] != "https://cdn.example/ref-b.png" {
t.Fatalf("unexpected images array: %+v", images)
}
mask, _ := body["mask"].(map[string]any)
if mask["url"] != "https://cdn.example/mask.png" {
t.Fatalf("mask should map to mask URL wrapper, got %+v", mask)
}
}
// writeMultipartFixtureFile adds a file part named filename under field to
// writer. The part content is the 8-byte PNG signature. Any failure aborts the
// test via t.Fatalf.
func writeMultipartFixtureFile(t *testing.T, writer *multipart.Writer, field string, filename string) {
	t.Helper()
	pngSignature := []byte{0x89, 'P', 'N', 'G', '\r', '\n', 0x1a, '\n'}

	part, err := writer.CreateFormFile(field, filename)
	if err != nil {
		t.Fatalf("create multipart file %s/%s: %v", field, filename, err)
	}
	_, err = part.Write(pngSignature)
	if err != nil {
		t.Fatalf("write multipart file %s/%s: %v", field, filename, err)
	}
}
func TestCleanupExpiredLocalTempAssetsDeletesExpiredStaticFiles(t *testing.T) { func TestCleanupExpiredLocalTempAssetsDeletesExpiredStaticFiles(t *testing.T) {
uploadedDir := t.TempDir() uploadedDir := t.TempDir()
generatedDir := t.TempDir() generatedDir := t.TempDir()

View File

@ -0,0 +1,290 @@
package httpapi
import (
	"context"
	"encoding/json"
	"io"
	"math"
	"mime"
	"mime/multipart"
	"net/http"
	"sort"
	"strconv"
	"strings"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
)
// multipartTaskMemoryBytes (32 MiB) is the maxMemory value handed to
// ParseMultipartForm when decoding multipart task requests.
const multipartTaskMemoryBytes = 32 << 20
// imageEditMultipartAssetUploader stores one multipart file and returns the
// value that is placed into the task body for it. The string argument is the
// target body field; callers in this file pass "image", "images" or "mask".
type imageEditMultipartAssetUploader func(context.Context, string, *multipart.FileHeader) (map[string]any, error)
// decodeTaskRequestBody reads the task request body into a generic map.
//
// Non-multipart requests are decoded as JSON; a decode failure yields an
// "invalid_json_body" client error and a JSON null becomes an empty map.
// multipart/form-data is accepted only when kind is "images.edits" and is
// delegated to decodeImageEditMultipartBody; any other kind yields an
// "unsupported_multipart_body" client error.
func (s *Server) decodeTaskRequestBody(ctx context.Context, w http.ResponseWriter, r *http.Request, kind string) (map[string]any, error) {
	switch {
	case !requestIsMultipartForm(r):
		var decoded map[string]any
		if err := json.NewDecoder(r.Body).Decode(&decoded); err != nil {
			return nil, &clients.ClientError{Code: "invalid_json_body", Message: "invalid json body", Retryable: false}
		}
		if decoded == nil {
			decoded = map[string]any{}
		}
		return decoded, nil
	case kind == "images.edits":
		return s.decodeImageEditMultipartBody(ctx, w, r)
	default:
		return nil, &clients.ClientError{Code: "unsupported_multipart_body", Message: "multipart/form-data is only supported for image edit tasks", Retryable: false}
	}
}
// requestIsMultipartForm reports whether the request's Content-Type header
// denotes multipart/form-data. When the header cannot be parsed as a media
// type, it falls back to a case-insensitive prefix check on the raw header.
func requestIsMultipartForm(r *http.Request) bool {
	const formData = "multipart/form-data"

	header := strings.TrimSpace(r.Header.Get("Content-Type"))
	if header == "" {
		return false
	}
	if mediaType, _, err := mime.ParseMediaType(header); err == nil {
		return strings.EqualFold(mediaType, formData)
	}
	return strings.HasPrefix(strings.ToLower(header), formData)
}
// decodeImageEditMultipartBody parses a multipart/form-data image edit request
// and converts it into a task body via imageEditMultipartFormBody, uploading
// files with s.uploadImageEditMultipartAsset.
//
// The request body is capped at maxGatewayUploadBytes. A parse failure yields
// an "invalid_multipart_body" client error. The parsed form's temporary files
// are removed before returning.
func (s *Server) decodeImageEditMultipartBody(ctx context.Context, w http.ResponseWriter, r *http.Request) (map[string]any, error) {
	r.Body = http.MaxBytesReader(w, r.Body, maxGatewayUploadBytes)

	parseErr := r.ParseMultipartForm(multipartTaskMemoryBytes)
	if parseErr != nil {
		return nil, &clients.ClientError{Code: "invalid_multipart_body", Message: "invalid multipart form-data body", Retryable: false}
	}
	form := r.MultipartForm
	if form == nil {
		return map[string]any{}, nil
	}
	defer form.RemoveAll()
	return imageEditMultipartFormBody(ctx, form, s.uploadImageEditMultipartAsset)
}
// imageEditMultipartFormBody converts a parsed multipart form into an image
// edit task body.
//
// Text fields are added first through addImageEditMultipartFieldValues, then
// file parts are uploaded and added through addImageEditMultipartFiles. A nil
// form yields an empty body; a nil upload function skips the file parts. The
// first upload error is returned with a nil body.
func imageEditMultipartFormBody(ctx context.Context, form *multipart.Form, upload imageEditMultipartAssetUploader) (map[string]any, error) {
	body := map[string]any{}
	if form == nil {
		return body, nil
	}

	// Several raw field names normalise to the same body key (for example
	// "images", "images[]" and "files"). Go randomises map iteration, so visit
	// the fields in sorted order to keep the resulting list order, and the
	// winner among colliding scalar fields, stable between requests.
	fieldNames := make([]string, 0, len(form.Value))
	for name := range form.Value {
		fieldNames = append(fieldNames, name)
	}
	sort.Strings(fieldNames)
	for _, name := range fieldNames {
		addImageEditMultipartFieldValues(body, name, form.Value[name])
	}

	if upload == nil {
		return body, nil
	}
	if err := addImageEditMultipartFiles(ctx, body, form.File, upload); err != nil {
		return nil, err
	}
	return body, nil
}
// addImageEditMultipartFieldValues stores the text values of one multipart
// field in body.
//
// The field name is normalised and blank values are dropped; if nothing
// remains, body is unchanged. Mapping rules:
//   - "images": all values are flattened and appended to body["images"].
//   - "image": a single value becomes body["image"]; several values are
//     appended to body["images"].
//   - "mask": the first value becomes body["mask"].
//   - anything else: a single value is stored directly, several as a slice.
func addImageEditMultipartFieldValues(body map[string]any, rawKey string, values []string) {
	field := normalizeImageEditMultipartFieldName(rawKey)

	parsedValues := make([]any, 0, len(values))
	for _, raw := range values {
		if strings.TrimSpace(raw) != "" {
			parsedValues = append(parsedValues, parseImageEditMultipartFieldValue(field, raw))
		}
	}
	if len(parsedValues) == 0 {
		return
	}

	single := len(parsedValues) == 1
	switch {
	case field == "images":
		appendImageEditMultipartList(body, "images", flattenImageEditMultipartValues(parsedValues)...)
	case field == "image" && !single:
		appendImageEditMultipartList(body, "images", parsedValues...)
	case field == "mask", single:
		// Covers a lone "image", any "mask", and every other one-value field.
		body[field] = parsedValues[0]
	default:
		body[field] = parsedValues
	}
}
// normalizeImageEditMultipartFieldName trims key and maps known aliases onto
// canonical body keys: "Image" becomes "image", and "images[]", "image[]" and
// "files" become "images". Every other name is returned trimmed but otherwise
// unchanged.
func normalizeImageEditMultipartFieldName(key string) string {
	switch name := strings.TrimSpace(key); name {
	case "images[]", "image[]", "files":
		return "images"
	case "Image":
		return "image"
	default:
		return name
	}
}
// parseImageEditMultipartFieldValue converts one multipart text value into the
// value stored in the task body.
//
// The value is trimmed first; a blank value yields "". Values accepted by
// parseImageEditMultipartJSONValue are returned decoded. Otherwise, fields
// listed by isImageEditMultipartBooleanField are parsed as booleans and fields
// listed by isImageEditMultipartNumberField as float64. Anything that fails to
// parse is returned as the trimmed string.
func parseImageEditMultipartFieldValue(key string, value string) any {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return ""
	}
	if parsed, ok := parseImageEditMultipartJSONValue(trimmed); ok {
		return parsed
	}
	if isImageEditMultipartBooleanField(key) {
		if parsed, err := strconv.ParseBool(trimmed); err == nil {
			return parsed
		}
	}
	if isImageEditMultipartNumberField(key) {
		// ParseFloat accepts "NaN" and "Inf" without error, but encoding/json
		// cannot marshal those values. Keep such input as text so the body
		// stays JSON-encodable.
		if parsed, err := strconv.ParseFloat(trimmed, 64); err == nil && !math.IsNaN(parsed) && !math.IsInf(parsed, 0) {
			return parsed
		}
	}
	return trimmed
}
// parseImageEditMultipartJSONValue decodes value as JSON when it starts with
// '{', '[' or '"'. It returns (nil, false) for an empty value, for any other
// leading byte, and when decoding fails.
func parseImageEditMultipartJSONValue(value string) (any, bool) {
	if len(value) == 0 {
		return nil, false
	}
	// Only objects, arrays and quoted strings are treated as JSON; bare
	// numbers, booleans and null are left to the caller.
	if lead := value[0]; lead != '{' && lead != '[' && lead != '"' {
		return nil, false
	}
	var decoded any
	if json.Unmarshal([]byte(value), &decoded) != nil {
		return nil, false
	}
	return decoded, true
}
// isImageEditMultipartBooleanField reports whether key names a multipart field
// whose text value should be parsed as a boolean. Matching is exact and
// case-sensitive.
func isImageEditMultipartBooleanField(key string) bool {
	switch key {
	case "stream", "simulation", "testMode", "test_mode", "watermark", "sync":
		return true
	}
	return false
}
// isImageEditMultipartNumberField reports whether key names a multipart field
// whose text value should be parsed as a number. Matching is exact and
// case-sensitive.
func isImageEditMultipartNumberField(key string) bool {
	switch key {
	case "n", "count", "width", "height", "seed",
		"batch_size", "batchSize",
		"simulationDurationMs", "simulation_duration_ms", "duration":
		return true
	}
	return false
}
// addImageEditMultipartFiles uploads the file parts of an image edit form and
// stores the resulting references in body.
//
// Files under "image"/"Image": exactly one becomes body["image"]; more than
// one are appended to body["images"]. Files under "images", "images[]",
// "image[]" and "files" are appended to body["images"]. Only the first "mask"
// file is uploaded and becomes body["mask"]. The first upload error stops
// processing and is returned.
func addImageEditMultipartFiles(ctx context.Context, body map[string]any, files map[string][]*multipart.FileHeader, upload imageEditMultipartAssetUploader) error {
	singles := collectImageEditMultipartFiles(files, "image", "Image")
	switch len(singles) {
	case 0:
		// No single-image part was sent.
	case 1:
		ref, err := upload(ctx, "image", singles[0])
		if err != nil {
			return err
		}
		body["image"] = ref
	default:
		refs, err := uploadImageEditMultipartFiles(ctx, "images", singles, upload)
		if err != nil {
			return err
		}
		appendImageEditMultipartList(body, "images", refs...)
	}

	if batch := collectImageEditMultipartFiles(files, "images", "images[]", "image[]", "files"); len(batch) > 0 {
		refs, err := uploadImageEditMultipartFiles(ctx, "images", batch, upload)
		if err != nil {
			return err
		}
		appendImageEditMultipartList(body, "images", refs...)
	}

	if masks := collectImageEditMultipartFiles(files, "mask"); len(masks) > 0 {
		ref, err := upload(ctx, "mask", masks[0])
		if err != nil {
			return err
		}
		body["mask"] = ref
	}
	return nil
}
// collectImageEditMultipartFiles concatenates the file headers stored under
// each of keys, in the order the keys are given. Keys without files contribute
// nothing; the result is an empty, non-nil slice when no key matches.
func collectImageEditMultipartFiles(files map[string][]*multipart.FileHeader, keys ...string) []*multipart.FileHeader {
	total := 0
	for _, key := range keys {
		total += len(files[key])
	}
	collected := make([]*multipart.FileHeader, 0, total)
	for _, key := range keys {
		collected = append(collected, files[key]...)
	}
	return collected
}
// uploadImageEditMultipartFiles uploads each header in order under the given
// field name and returns the upload results in the same order. It stops at the
// first error and returns it with a nil slice.
func uploadImageEditMultipartFiles(ctx context.Context, field string, headers []*multipart.FileHeader, upload imageEditMultipartAssetUploader) ([]any, error) {
	uploaded := make([]any, len(headers))
	for i, header := range headers {
		ref, err := upload(ctx, field, header)
		if err != nil {
			return nil, err
		}
		uploaded[i] = ref
	}
	return uploaded, nil
}
// uploadImageEditMultipartAsset reads one multipart file fully into memory,
// verifies that it is an image, stores it through ensurePublicRequestAsset and
// returns the reference wrapped by requestAssetWrapper.
//
// Open and read failures yield an "invalid_multipart_file" client error. A
// file is rejected with "invalid_multipart_image" when it declares a
// non-image Content-Type and its sniffed type is not an image either, or when
// the content type resolved by requestAssetContentType is not an image.
func (s *Server) uploadImageEditMultipartAsset(ctx context.Context, field string, header *multipart.FileHeader) (map[string]any, error) {
	const notImageMessage = "image edit multipart files must be images"
	isImage := func(mediaType string) bool {
		return strings.HasPrefix(strings.ToLower(mediaType), "image/")
	}

	file, err := header.Open()
	if err != nil {
		return nil, &clients.ClientError{Code: "invalid_multipart_file", Message: err.Error(), Retryable: false}
	}
	defer file.Close()

	payload, err := io.ReadAll(file)
	if err != nil {
		return nil, &clients.ClientError{Code: "invalid_multipart_file", Message: err.Error(), Retryable: false}
	}

	declared := strings.TrimSpace(header.Header.Get("Content-Type"))
	var sniffed string
	if len(payload) > 0 {
		sniffed = http.DetectContentType(payload)
	}
	// A declared non-image type is tolerated as long as the bytes themselves
	// look like an image.
	if declared != "" && !isImage(declared) && !isImage(sniffed) {
		return nil, &clients.ClientError{Code: "invalid_multipart_image", Message: notImageMessage, Retryable: false}
	}

	resolved := requestAssetContentType(declared, payload, field, []string{field}, nil)
	if !isImage(resolved) {
		return nil, &clients.ClientError{Code: "invalid_multipart_image", Message: notImageMessage, Retryable: false}
	}

	ref, err := s.ensurePublicRequestAsset(ctx, decodedRequestAsset{
		Bytes:       payload,
		ContentType: resolved,
	})
	if err != nil {
		return nil, err
	}
	return requestAssetWrapper(ref), nil
}
// appendImageEditMultipartList appends values to the list stored at body[key].
// Both the existing entry and the new values are flattened through
// flattenImageEditMultipartValues. When the combined list is empty, body is
// left unchanged.
func appendImageEditMultipartList(body map[string]any, key string, values ...any) {
	existing := flattenImageEditMultipartValues([]any{body[key]})
	merged := append(existing, flattenImageEditMultipartValues(values)...)
	if len(merged) > 0 {
		body[key] = merged
	}
}
// flattenImageEditMultipartValues flattens values into a single list.
//
// Nil entries are dropped, nested []any values are flattened recursively, and
// []string entries contribute their trimmed, non-blank items. Every other
// value is kept as is. The result is never nil.
func flattenImageEditMultipartValues(values []any) []any {
	flat := make([]any, 0, len(values))
	for i := range values {
		switch item := values[i].(type) {
		case nil:
			// Skip missing entries.
		case []string:
			for _, text := range item {
				text = strings.TrimSpace(text)
				if text == "" {
					continue
				}
				flat = append(flat, text)
			}
		case []any:
			flat = append(flat, flattenImageEditMultipartValues(item)...)
		default:
			flat = append(flat, item)
		}
	}
	return flat
}