easyai-ai-gateway/apps/api/internal/httpapi/task_multipart.go

291 lines
8.5 KiB
Go

package httpapi
import (
	"context"
	"encoding/json"
	"io"
	"mime"
	"mime/multipart"
	"net/http"
	"sort"
	"strconv"
	"strings"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
)
// multipartTaskMemoryBytes is the maxMemory budget (32 MiB) passed to
// (*http.Request).ParseMultipartForm when decoding image edit uploads.
const multipartTaskMemoryBytes = 32 * 1024 * 1024

// imageEditMultipartAssetUploader turns one uploaded multipart file, received
// under the given form field, into the value that is stored in the decoded
// request body in place of the file.
type imageEditMultipartAssetUploader func(ctx context.Context, field string, header *multipart.FileHeader) (map[string]any, error)
// decodeTaskRequestBody decodes the task request payload into a generic map.
//
// multipart/form-data bodies are accepted only when kind is "images.edits".
// They are delegated to decodeImageEditMultipartBody; any other kind gets an
// "unsupported_multipart_body" client error. Every other body is decoded as a
// JSON object, and any decode failure is reported as "invalid_json_body". A
// JSON null body yields an empty, non-nil map.
func (s *Server) decodeTaskRequestBody(ctx context.Context, w http.ResponseWriter, r *http.Request, kind string) (map[string]any, error) {
	if !requestIsMultipartForm(r) {
		var payload map[string]any
		if decodeErr := json.NewDecoder(r.Body).Decode(&payload); decodeErr != nil {
			return nil, &clients.ClientError{Code: "invalid_json_body", Message: "invalid json body", Retryable: false}
		}
		if payload == nil {
			// A literal `null` decodes without error but leaves the map nil.
			return map[string]any{}, nil
		}
		return payload, nil
	}
	if kind == "images.edits" {
		return s.decodeImageEditMultipartBody(ctx, w, r)
	}
	return nil, &clients.ClientError{Code: "unsupported_multipart_body", Message: "multipart/form-data is only supported for image edit tasks", Retryable: false}
}
// requestIsMultipartForm reports whether r declares a multipart/form-data
// Content-Type, comparing the media type case-insensitively.
//
// If the header cannot be parsed as a media type (for example because of a
// malformed parameter), it falls back to a case-insensitive prefix match. Such
// requests therefore still take the multipart path.
func requestIsMultipartForm(r *http.Request) bool {
	header := strings.TrimSpace(r.Header.Get("Content-Type"))
	if len(header) == 0 {
		return false
	}
	const formData = "multipart/form-data"
	if mediaType, _, parseErr := mime.ParseMediaType(header); parseErr == nil {
		return strings.EqualFold(mediaType, formData)
	}
	return strings.HasPrefix(strings.ToLower(header), formData)
}
// decodeImageEditMultipartBody parses an image edit multipart/form-data body
// into a request map.
//
// The body is capped at maxGatewayUploadBytes via http.MaxBytesReader, and
// ParseMultipartForm keeps up to multipartTaskMemoryBytes of file data in
// memory. Any parse failure, including exceeding the cap, is reported as
// "invalid_multipart_body". Temporary files created by the parser are removed
// when this function returns.
func (s *Server) decodeImageEditMultipartBody(ctx context.Context, w http.ResponseWriter, r *http.Request) (map[string]any, error) {
	r.Body = http.MaxBytesReader(w, r.Body, maxGatewayUploadBytes)
	parseErr := r.ParseMultipartForm(multipartTaskMemoryBytes)
	if parseErr != nil {
		return nil, &clients.ClientError{Code: "invalid_multipart_body", Message: "invalid multipart form-data body", Retryable: false}
	}
	form := r.MultipartForm
	if form == nil {
		return map[string]any{}, nil
	}
	// Cleanup is best-effort; the RemoveAll error is intentionally ignored.
	defer form.RemoveAll()
	return imageEditMultipartFormBody(ctx, form, s.uploadImageEditMultipartAsset)
}
// imageEditMultipartFormBody converts a parsed multipart form into the generic
// request map used for image edit tasks.
//
// Text fields are merged first with addImageEditMultipartFieldValues. File
// parts are then handed to upload through addImageEditMultipartFiles. Uploads
// may overwrite "image"/"mask" or extend "images". A nil form yields an empty
// map, and a nil upload skips file parts entirely. Upload errors are returned
// unchanged with a nil map.
//
// Text field names are processed in byte-wise sorted order. Several raw names
// normalize to the same key: "Image"/"image" → "image", and
// "images"/"images[]"/"image[]"/"files" (plus a multi-valued "image") →
// "images". Ranging over form.Value directly would make the resulting
// "images" order, and which "image" value wins, depend on Go's randomized map
// iteration.
func imageEditMultipartFormBody(ctx context.Context, form *multipart.Form, upload imageEditMultipartAssetUploader) (map[string]any, error) {
	body := map[string]any{}
	if form == nil {
		return body, nil
	}
	keys := make([]string, 0, len(form.Value))
	for key := range form.Value {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	for _, key := range keys {
		addImageEditMultipartFieldValues(body, key, form.Value[key])
	}
	if upload == nil {
		return body, nil
	}
	if err := addImageEditMultipartFiles(ctx, body, form.File, upload); err != nil {
		return nil, err
	}
	return body, nil
}
// addImageEditMultipartFieldValues merges the text values of one multipart
// field into body.
//
// Blank values are dropped. The remaining values are converted with
// parseImageEditMultipartFieldValue using the normalized field name. Routing
// depends on that name:
//   - "image": a single value sets body["image"]; several values are appended
//     to body["images"] instead.
//   - "images" (after alias normalization): values are flattened and appended
//     to body["images"].
//   - "mask": only the first value is kept.
//   - anything else: a single value is stored directly, several as a []any.
func addImageEditMultipartFieldValues(body map[string]any, rawKey string, values []string) {
	field := normalizeImageEditMultipartFieldName(rawKey)
	converted := make([]any, 0, len(values))
	for i := range values {
		if strings.TrimSpace(values[i]) != "" {
			converted = append(converted, parseImageEditMultipartFieldValue(field, values[i]))
		}
	}
	count := len(converted)
	if count == 0 {
		return
	}
	switch {
	case field == "images":
		appendImageEditMultipartList(body, "images", flattenImageEditMultipartValues(converted)...)
	case field == "mask", field == "image" && count == 1:
		body[field] = converted[0]
	case field == "image":
		appendImageEditMultipartList(body, "images", converted...)
	case count == 1:
		body[field] = converted[0]
	default:
		body[field] = converted
	}
}
// normalizeImageEditMultipartFieldName trims surrounding whitespace from a
// multipart field name and maps the accepted aliases onto canonical keys.
// "Image" becomes "image", and "images", "images[]", "image[]" and "files"
// all become "images". Every other name is returned trimmed but otherwise
// unchanged; matching is case-sensitive.
func normalizeImageEditMultipartFieldName(key string) string {
	name := strings.TrimSpace(key)
	if name == "Image" {
		return "image"
	}
	if name == "images" || name == "images[]" || name == "image[]" || name == "files" {
		return "images"
	}
	return name
}
// parseImageEditMultipartFieldValue converts one multipart text value into the
// kind of value the JSON request path would have produced.
//
// The value is trimmed first. A value that starts like a JSON object, array or
// string and decodes cleanly is returned decoded. Next, known boolean fields
// are parsed with strconv.ParseBool. Known numeric fields are parsed as
// float64, which is also what encoding/json yields for numbers in a
// map[string]any. Anything else, including failed conversions, is returned as
// the trimmed string.
//
// NOTE(review): JSON decoding is attempted for every field, including
// free-text ones. A prompt such as `"quoted"` loses its quotes, and a prompt
// that is a valid JSON object becomes a map. Confirm this is intended for
// prompt-like fields.
func parseImageEditMultipartFieldValue(key string, value string) any {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return ""
	}
	if parsed, ok := parseImageEditMultipartJSONValue(trimmed); ok {
		return parsed
	}
	if isImageEditMultipartBooleanField(key) {
		if parsed, err := strconv.ParseBool(trimmed); err == nil {
			return parsed
		}
	}
	if isImageEditMultipartNumberField(key) {
		// strconv.ParseFloat accepts "NaN", "Inf" and "Infinity" without error.
		// encoding/json refuses to marshal non-finite floats, and a JSON request
		// body can never contain them, so such input is kept as a string.
		// f*0 == 0 holds for every finite f; it is NaN for NaN and ±Inf.
		if parsed, err := strconv.ParseFloat(trimmed, 64); err == nil && parsed*0 == 0 {
			return parsed
		}
	}
	return trimmed
}
// parseImageEditMultipartJSONValue decodes value as JSON, but only when its
// first byte is '{', '[' or '"' (an object, array or string literal). It
// reports false for empty input, any other leading byte, or invalid JSON. Bare
// numbers, booleans and null are therefore never decoded here.
func parseImageEditMultipartJSONValue(value string) (any, bool) {
	looksLikeJSON := false
	if value != "" {
		switch value[0] {
		case '{', '[', '"':
			looksLikeJSON = true
		}
	}
	if !looksLikeJSON {
		return nil, false
	}
	var decoded any
	if json.Unmarshal([]byte(value), &decoded) != nil {
		return nil, false
	}
	return decoded, true
}
// isImageEditMultipartBooleanField reports whether key names a multipart field
// whose value should be tried with strconv.ParseBool. Matching is exact and
// case-sensitive; camelCase and snake_case spellings are listed separately.
func isImageEditMultipartBooleanField(key string) bool {
	for _, name := range [...]string{"stream", "simulation", "testMode", "test_mode", "watermark", "sync"} {
		if name == key {
			return true
		}
	}
	return false
}
// isImageEditMultipartNumberField reports whether key names a multipart field
// whose value should be tried with strconv.ParseFloat. Matching is exact and
// case-sensitive; camelCase and snake_case spellings are listed separately.
func isImageEditMultipartNumberField(key string) bool {
	numeric := [...]string{
		"n", "count", "width", "height", "seed",
		"batch_size", "batchSize",
		"simulationDurationMs", "simulation_duration_ms",
		"duration",
	}
	for _, name := range numeric {
		if name == key {
			return true
		}
	}
	return false
}
// addImageEditMultipartFiles uploads the file parts of an image edit form and
// records the returned values in body.
//
// Parts are handled in this order:
//   - "image"/"Image" files: exactly one is uploaded under "image" and replaces
//     body["image"]; two or more are uploaded under "images" and appended to
//     body["images"].
//   - "images", "images[]", "image[]" and "files" files are uploaded under
//     "images" and appended to body["images"].
//   - "mask": only the first file is uploaded, and it replaces body["mask"].
//
// Uploads run sequentially. The first upload error is returned as-is, and
// changes already made to body are kept.
func addImageEditMultipartFiles(ctx context.Context, body map[string]any, files map[string][]*multipart.FileHeader, upload imageEditMultipartAssetUploader) error {
	primary := collectImageEditMultipartFiles(files, "image", "Image")
	var listGroups [][]*multipart.FileHeader
	switch n := len(primary); {
	case n == 1:
		uploaded, err := upload(ctx, "image", primary[0])
		if err != nil {
			return err
		}
		body["image"] = uploaded
	case n > 1:
		listGroups = append(listGroups, primary)
	}
	if extra := collectImageEditMultipartFiles(files, "images", "images[]", "image[]", "files"); len(extra) > 0 {
		listGroups = append(listGroups, extra)
	}
	// Only non-empty groups are queued above: appending an empty batch is not
	// a no-op, since it would re-normalize an existing body["images"].
	for _, group := range listGroups {
		uploaded, err := uploadImageEditMultipartFiles(ctx, "images", group, upload)
		if err != nil {
			return err
		}
		appendImageEditMultipartList(body, "images", uploaded...)
	}
	if masks := collectImageEditMultipartFiles(files, "mask"); len(masks) != 0 {
		uploaded, err := upload(ctx, "mask", masks[0])
		if err != nil {
			return err
		}
		body["mask"] = uploaded
	}
	return nil
}
// collectImageEditMultipartFiles concatenates the file headers stored under
// each of keys, in the order the keys are given. Missing keys (or a nil map)
// contribute nothing. The result is always a fresh, non-nil slice.
func collectImageEditMultipartFiles(files map[string][]*multipart.FileHeader, keys ...string) []*multipart.FileHeader {
	total := 0
	for _, key := range keys {
		total += len(files[key])
	}
	headers := make([]*multipart.FileHeader, 0, total)
	for _, key := range keys {
		headers = append(headers, files[key]...)
	}
	return headers
}
// uploadImageEditMultipartFiles uploads headers one at a time under the given
// field label and returns the uploaded values in the same order. It stops at
// the first failure and returns that error with a nil slice. An empty headers
// slice yields an empty, non-nil result.
func uploadImageEditMultipartFiles(ctx context.Context, field string, headers []*multipart.FileHeader, upload imageEditMultipartAssetUploader) ([]any, error) {
	uploaded := make([]any, len(headers))
	for i, header := range headers {
		value, err := upload(ctx, field, header)
		if err != nil {
			return nil, err
		}
		uploaded[i] = value
	}
	return uploaded, nil
}
// uploadImageEditMultipartAsset reads one uploaded file, checks that it is an
// image, and passes it to ensurePublicRequestAsset. It returns
// requestAssetWrapper applied to the resulting reference.
//
// Open and read failures are reported as "invalid_multipart_file". The file is
// rejected with "invalid_multipart_image" in two cases: it declares a
// non-image Content-Type and http.DetectContentType does not sniff an image
// either, or the type chosen by requestAssetContentType is not image/*.
//
// NOTE(review): a declared image/* type is accepted even when sniffing
// disagrees, and image/svg+xml passes the prefix checks here. Confirm that
// publishing such files through ensurePublicRequestAsset is acceptable.
func (s *Server) uploadImageEditMultipartAsset(ctx context.Context, field string, header *multipart.FileHeader) (map[string]any, error) {
	file, openErr := header.Open()
	if openErr != nil {
		return nil, &clients.ClientError{Code: "invalid_multipart_file", Message: openErr.Error(), Retryable: false}
	}
	defer file.Close()
	data, readErr := io.ReadAll(file)
	if readErr != nil {
		return nil, &clients.ClientError{Code: "invalid_multipart_file", Message: readErr.Error(), Retryable: false}
	}

	isImage := func(contentType string) bool {
		return strings.HasPrefix(strings.ToLower(contentType), "image/")
	}
	declared := strings.TrimSpace(header.Header.Get("Content-Type"))
	sniffed := ""
	if len(data) != 0 {
		sniffed = http.DetectContentType(data)
	}
	// An empty declared type skips this early check and defers to the
	// resolved type below.
	if declared != "" && !isImage(declared) && !isImage(sniffed) {
		return nil, &clients.ClientError{Code: "invalid_multipart_image", Message: "image edit multipart files must be images", Retryable: false}
	}
	resolved := requestAssetContentType(declared, data, field, []string{field}, nil)
	if !isImage(resolved) {
		return nil, &clients.ClientError{Code: "invalid_multipart_image", Message: "image edit multipart files must be images", Retryable: false}
	}

	ref, err := s.ensurePublicRequestAsset(ctx, decodedRequestAsset{Bytes: data, ContentType: resolved})
	if err != nil {
		return nil, err
	}
	return requestAssetWrapper(ref), nil
}
// appendImageEditMultipartList sets body[key] to a flat []any. The list holds
// the existing value(s) of body[key] followed by values, both flattened with
// flattenImageEditMultipartValues.
//
// If the combined list is empty, body is left untouched. A non-empty existing
// value is always replaced by a freshly flattened []any, even when values is
// empty.
func appendImageEditMultipartList(body map[string]any, key string, values ...any) {
	merged := append(
		flattenImageEditMultipartValues([]any{body[key]}),
		flattenImageEditMultipartValues(values)...,
	)
	if len(merged) > 0 {
		body[key] = merged
	}
}
// flattenImageEditMultipartValues returns the leaves of values, in order:
//   - nil entries are dropped;
//   - []any entries are expanded recursively;
//   - []string entries contribute their whitespace-trimmed, non-empty items;
//   - every other value is kept as-is (plain strings are not trimmed).
//
// The result is always a fresh, non-nil slice.
func flattenImageEditMultipartValues(values []any) []any {
	flat := make([]any, 0, len(values))
	var visit func(item any)
	visit = func(item any) {
		switch v := item.(type) {
		case nil:
			// dropped
		case []any:
			for _, nested := range v {
				visit(nested)
			}
		case []string:
			for _, s := range v {
				if trimmed := strings.TrimSpace(s); trimmed != "" {
					flat = append(flat, trimmed)
				}
			}
		default:
			flat = append(flat, item)
		}
	}
	for _, item := range values {
		visit(item)
	}
	return flat
}