486 lines
14 KiB
Go
486 lines
14 KiB
Go
package httpapi
|
|
|
|
import (
	"context"
	"encoding/json"
	"io"
	"mime"
	"mime/multipart"
	"net/http"
	"path/filepath"
	"sort"
	"strconv"
	"strings"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
)
|
|
|
|
// multipartTaskMemoryBytes (32 MiB) is the maxMemory value handed to
// http.Request.ParseMultipartForm when decoding multipart task bodies.
const multipartTaskMemoryBytes = 32 << 20
|
|
|
|
// imageEditMultipartAssetUploader is the callback that turns one uploaded
// image-edit file part into a body value. It receives the request context, a
// field label ("image", "images" or "mask" in this file) and the part's file
// header, and returns the map that is stored in the task body.
type imageEditMultipartAssetUploader func(context.Context, string, *multipart.FileHeader) (map[string]any, error)
|
|
// voiceCloneMultipartAssetUploader is the callback that turns one uploaded
// voice-clone file part into a body value. It receives the request context, a
// field label ("audio" or "prompt_audio" in this file) and the part's file
// header, and returns the map that is stored in the task body.
type voiceCloneMultipartAssetUploader func(context.Context, string, *multipart.FileHeader) (map[string]any, error)
|
|
|
|
func (s *Server) decodeTaskRequestBody(ctx context.Context, w http.ResponseWriter, r *http.Request, kind string) (map[string]any, error) {
|
|
if requestIsMultipartForm(r) {
|
|
switch kind {
|
|
case "images.edits":
|
|
return s.decodeImageEditMultipartBody(ctx, w, r)
|
|
case "voice.clone":
|
|
return s.decodeVoiceCloneMultipartBody(ctx, w, r)
|
|
default:
|
|
return nil, &clients.ClientError{Code: "unsupported_multipart_body", Message: "multipart/form-data is only supported for image edit and voice clone tasks", Retryable: false}
|
|
}
|
|
}
|
|
var body map[string]any
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
return nil, &clients.ClientError{Code: "invalid_json_body", Message: "invalid json body", Retryable: false}
|
|
}
|
|
if body == nil {
|
|
body = map[string]any{}
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
// requestIsMultipartForm reports whether the request's Content-Type header
// names multipart/form-data (case-insensitively). When the header cannot be
// parsed as a media type, a lower-cased prefix match is used instead. A missing
// or blank header is never multipart.
func requestIsMultipartForm(r *http.Request) bool {
	const formType = "multipart/form-data"

	header := strings.TrimSpace(r.Header.Get("Content-Type"))
	if len(header) == 0 {
		return false
	}
	if parsed, _, err := mime.ParseMediaType(header); err == nil {
		return strings.EqualFold(parsed, formType)
	}
	// Malformed header: fall back to a plain prefix comparison.
	return strings.HasPrefix(strings.ToLower(header), formType)
}
|
|
|
|
func (s *Server) decodeImageEditMultipartBody(ctx context.Context, w http.ResponseWriter, r *http.Request) (map[string]any, error) {
|
|
r.Body = http.MaxBytesReader(w, r.Body, maxGatewayUploadBytes)
|
|
if err := r.ParseMultipartForm(multipartTaskMemoryBytes); err != nil {
|
|
return nil, &clients.ClientError{Code: "invalid_multipart_body", Message: "invalid multipart form-data body", Retryable: false}
|
|
}
|
|
if r.MultipartForm == nil {
|
|
return map[string]any{}, nil
|
|
}
|
|
defer r.MultipartForm.RemoveAll()
|
|
return imageEditMultipartFormBody(ctx, r.MultipartForm, s.uploadImageEditMultipartAsset)
|
|
}
|
|
|
|
func imageEditMultipartFormBody(ctx context.Context, form *multipart.Form, upload imageEditMultipartAssetUploader) (map[string]any, error) {
|
|
body := map[string]any{}
|
|
if form == nil {
|
|
return body, nil
|
|
}
|
|
for key, values := range form.Value {
|
|
addImageEditMultipartFieldValues(body, key, values)
|
|
}
|
|
if upload == nil {
|
|
return body, nil
|
|
}
|
|
if err := addImageEditMultipartFiles(ctx, body, form.File, upload); err != nil {
|
|
return nil, err
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func addImageEditMultipartFieldValues(body map[string]any, rawKey string, values []string) {
|
|
key := normalizeImageEditMultipartFieldName(rawKey)
|
|
parsed := make([]any, 0, len(values))
|
|
for _, value := range values {
|
|
if strings.TrimSpace(value) == "" {
|
|
continue
|
|
}
|
|
parsed = append(parsed, parseImageEditMultipartFieldValue(key, value))
|
|
}
|
|
if len(parsed) == 0 {
|
|
return
|
|
}
|
|
switch key {
|
|
case "image":
|
|
if len(parsed) == 1 {
|
|
body["image"] = parsed[0]
|
|
return
|
|
}
|
|
appendImageEditMultipartList(body, "images", parsed...)
|
|
case "images":
|
|
appendImageEditMultipartList(body, "images", flattenImageEditMultipartValues(parsed)...)
|
|
case "mask":
|
|
body["mask"] = parsed[0]
|
|
default:
|
|
if len(parsed) == 1 {
|
|
body[key] = parsed[0]
|
|
} else {
|
|
body[key] = parsed
|
|
}
|
|
}
|
|
}
|
|
|
|
// normalizeImageEditMultipartFieldName maps the accepted spellings of image
// field names onto their canonical form: "Image" becomes "image", and
// "images", "images[]", "image[]" and "files" become "images". Any other name
// is returned with surrounding whitespace trimmed.
func normalizeImageEditMultipartFieldName(key string) string {
	name := strings.TrimSpace(key)
	if name == "Image" {
		return "image"
	}
	for _, alias := range [...]string{"images", "images[]", "image[]", "files"} {
		if name == alias {
			return "images"
		}
	}
	return name
}
|
|
|
|
func parseImageEditMultipartFieldValue(key string, value string) any {
|
|
trimmed := strings.TrimSpace(value)
|
|
if trimmed == "" {
|
|
return ""
|
|
}
|
|
if parsed, ok := parseImageEditMultipartJSONValue(trimmed); ok {
|
|
return parsed
|
|
}
|
|
if isImageEditMultipartBooleanField(key) {
|
|
if parsed, err := strconv.ParseBool(trimmed); err == nil {
|
|
return parsed
|
|
}
|
|
}
|
|
if isImageEditMultipartNumberField(key) {
|
|
if parsed, err := strconv.ParseFloat(trimmed, 64); err == nil {
|
|
return parsed
|
|
}
|
|
}
|
|
return trimmed
|
|
}
|
|
|
|
// parseImageEditMultipartJSONValue decodes value as JSON when it starts with
// '{', '[' or '"'. It returns the decoded value and true on success, and
// (nil, false) for empty input, any other leading byte, or invalid JSON.
func parseImageEditMultipartJSONValue(value string) (any, bool) {
	if len(value) == 0 {
		return nil, false
	}
	// Only objects, arrays and quoted strings are treated as JSON; bare
	// numbers, booleans and null are left for the caller to interpret.
	if first := value[0]; first != '{' && first != '[' && first != '"' {
		return nil, false
	}

	var decoded any
	err := json.Unmarshal([]byte(value), &decoded)
	if err != nil {
		return nil, false
	}
	return decoded, true
}
|
|
|
|
// isImageEditMultipartBooleanField reports whether key names a field whose
// text value should be parsed as a boolean. The match is exact and
// case-sensitive.
func isImageEditMultipartBooleanField(key string) bool {
	for _, name := range [...]string{"stream", "simulation", "testMode", "test_mode", "watermark", "sync"} {
		if key == name {
			return true
		}
	}
	return false
}
|
|
|
|
// isImageEditMultipartNumberField reports whether key names a field whose text
// value should be parsed as a number. The match is exact and case-sensitive.
func isImageEditMultipartNumberField(key string) bool {
	numeric := [...]string{
		"n", "count", "width", "height", "seed",
		"batch_size", "batchSize",
		"simulationDurationMs", "simulation_duration_ms",
		"duration",
	}
	for _, name := range numeric {
		if key == name {
			return true
		}
	}
	return false
}
|
|
|
|
func addImageEditMultipartFiles(ctx context.Context, body map[string]any, files map[string][]*multipart.FileHeader, upload imageEditMultipartAssetUploader) error {
|
|
imageFiles := collectImageEditMultipartFiles(files, "image", "Image")
|
|
if len(imageFiles) == 1 {
|
|
value, err := upload(ctx, "image", imageFiles[0])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
body["image"] = value
|
|
} else if len(imageFiles) > 1 {
|
|
values, err := uploadImageEditMultipartFiles(ctx, "images", imageFiles, upload)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
appendImageEditMultipartList(body, "images", values...)
|
|
}
|
|
multiImageFiles := collectImageEditMultipartFiles(files, "images", "images[]", "image[]", "files")
|
|
if len(multiImageFiles) > 0 {
|
|
values, err := uploadImageEditMultipartFiles(ctx, "images", multiImageFiles, upload)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
appendImageEditMultipartList(body, "images", values...)
|
|
}
|
|
maskFiles := collectImageEditMultipartFiles(files, "mask")
|
|
if len(maskFiles) > 0 {
|
|
value, err := upload(ctx, "mask", maskFiles[0])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
body["mask"] = value
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// collectImageEditMultipartFiles gathers the file headers stored under each of
// keys, in the order the keys are given. Unknown keys contribute nothing and
// the result is never nil.
func collectImageEditMultipartFiles(files map[string][]*multipart.FileHeader, keys ...string) []*multipart.FileHeader {
	collected := []*multipart.FileHeader{}
	for i := 0; i < len(keys); i++ {
		collected = append(collected, files[keys[i]]...)
	}
	return collected
}
|
|
|
|
func uploadImageEditMultipartFiles(ctx context.Context, field string, headers []*multipart.FileHeader, upload imageEditMultipartAssetUploader) ([]any, error) {
|
|
values := make([]any, 0, len(headers))
|
|
for _, header := range headers {
|
|
value, err := upload(ctx, field, header)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
values = append(values, value)
|
|
}
|
|
return values, nil
|
|
}
|
|
|
|
func (s *Server) uploadImageEditMultipartAsset(ctx context.Context, field string, header *multipart.FileHeader) (map[string]any, error) {
|
|
file, err := header.Open()
|
|
if err != nil {
|
|
return nil, &clients.ClientError{Code: "invalid_multipart_file", Message: err.Error(), Retryable: false}
|
|
}
|
|
defer file.Close()
|
|
payload, err := io.ReadAll(file)
|
|
if err != nil {
|
|
return nil, &clients.ClientError{Code: "invalid_multipart_file", Message: err.Error(), Retryable: false}
|
|
}
|
|
contentType := strings.TrimSpace(header.Header.Get("Content-Type"))
|
|
detectedContentType := ""
|
|
if len(payload) > 0 {
|
|
detectedContentType = http.DetectContentType(payload)
|
|
}
|
|
if contentType != "" && !strings.HasPrefix(strings.ToLower(contentType), "image/") && !strings.HasPrefix(strings.ToLower(detectedContentType), "image/") {
|
|
return nil, &clients.ClientError{Code: "invalid_multipart_image", Message: "image edit multipart files must be images", Retryable: false}
|
|
}
|
|
contentType = requestAssetContentType(contentType, payload, field, []string{field}, nil)
|
|
if !strings.HasPrefix(strings.ToLower(contentType), "image/") {
|
|
return nil, &clients.ClientError{Code: "invalid_multipart_image", Message: "image edit multipart files must be images", Retryable: false}
|
|
}
|
|
ref, err := s.ensurePublicRequestAsset(ctx, decodedRequestAsset{
|
|
Bytes: payload,
|
|
ContentType: contentType,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return requestAssetWrapper(ref), nil
|
|
}
|
|
|
|
func (s *Server) decodeVoiceCloneMultipartBody(ctx context.Context, w http.ResponseWriter, r *http.Request) (map[string]any, error) {
|
|
r.Body = http.MaxBytesReader(w, r.Body, maxGatewayUploadBytes)
|
|
if err := r.ParseMultipartForm(multipartTaskMemoryBytes); err != nil {
|
|
return nil, &clients.ClientError{Code: "invalid_multipart_body", Message: "invalid multipart form-data body", Retryable: false}
|
|
}
|
|
if r.MultipartForm == nil {
|
|
return map[string]any{}, nil
|
|
}
|
|
defer r.MultipartForm.RemoveAll()
|
|
return voiceCloneMultipartFormBody(ctx, r.MultipartForm, s.uploadVoiceCloneMultipartAsset)
|
|
}
|
|
|
|
func voiceCloneMultipartFormBody(ctx context.Context, form *multipart.Form, upload voiceCloneMultipartAssetUploader) (map[string]any, error) {
|
|
body := map[string]any{}
|
|
if form == nil {
|
|
return body, nil
|
|
}
|
|
for key, values := range form.Value {
|
|
addVoiceCloneMultipartFieldValues(body, key, values)
|
|
}
|
|
if upload == nil {
|
|
return body, nil
|
|
}
|
|
if err := addVoiceCloneMultipartFiles(ctx, body, form.File, upload); err != nil {
|
|
return nil, err
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
func addVoiceCloneMultipartFieldValues(body map[string]any, rawKey string, values []string) {
|
|
key := normalizeVoiceCloneMultipartFieldName(rawKey)
|
|
parsed := make([]any, 0, len(values))
|
|
for _, value := range values {
|
|
if strings.TrimSpace(value) == "" {
|
|
continue
|
|
}
|
|
parsed = append(parsed, parseVoiceCloneMultipartFieldValue(key, value))
|
|
}
|
|
if len(parsed) == 0 {
|
|
return
|
|
}
|
|
if len(parsed) == 1 {
|
|
body[key] = parsed[0]
|
|
return
|
|
}
|
|
body[key] = parsed
|
|
}
|
|
|
|
// normalizeVoiceCloneMultipartFieldName maps the recognized camelCase field
// names onto their snake_case equivalents. Any other name is returned with
// surrounding whitespace trimmed. Matching is exact and case-sensitive.
func normalizeVoiceCloneMultipartFieldName(key string) string {
	name := strings.TrimSpace(key)

	aliases := [...][2]string{
		{"voiceId", "voice_id"},
		{"audioUrl", "audio_url"},
		{"promptAudioUrl", "prompt_audio_url"},
		{"promptText", "prompt_text"},
		{"previewModel", "preview_model"},
		{"textValidation", "text_validation"},
		{"languageBoost", "language_boost"},
		{"needNoiseReduction", "need_noise_reduction"},
		{"needVolumeNormalization", "need_volume_normalization"},
		{"aigcWatermark", "aigc_watermark"},
		{"fileId", "file_id"},
		{"promptFileId", "prompt_file_id"},
		{"displayName", "display_name"},
	}
	for _, alias := range aliases {
		if alias[0] == name {
			return alias[1]
		}
	}
	return name
}
|
|
|
|
func parseVoiceCloneMultipartFieldValue(key string, value string) any {
|
|
trimmed := strings.TrimSpace(value)
|
|
if trimmed == "" {
|
|
return ""
|
|
}
|
|
if parsed, ok := parseImageEditMultipartJSONValue(trimmed); ok {
|
|
return parsed
|
|
}
|
|
switch key {
|
|
case "need_noise_reduction", "need_volume_normalization", "aigc_watermark":
|
|
if parsed, err := strconv.ParseBool(trimmed); err == nil {
|
|
return parsed
|
|
}
|
|
case "file_id", "prompt_file_id":
|
|
if parsed, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
|
|
return parsed
|
|
}
|
|
case "accuracy":
|
|
if parsed, err := strconv.ParseFloat(trimmed, 64); err == nil {
|
|
return parsed
|
|
}
|
|
}
|
|
return trimmed
|
|
}
|
|
|
|
func addVoiceCloneMultipartFiles(ctx context.Context, body map[string]any, files map[string][]*multipart.FileHeader, upload voiceCloneMultipartAssetUploader) error {
|
|
sourceFiles := collectVoiceCloneMultipartFiles(files, "file", "audio", "source_audio", "sourceAudio")
|
|
if len(sourceFiles) > 0 {
|
|
value, err := upload(ctx, "audio", sourceFiles[0])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
body["audio"] = value
|
|
}
|
|
promptFiles := collectVoiceCloneMultipartFiles(files, "prompt_audio", "promptAudio")
|
|
if len(promptFiles) > 0 {
|
|
value, err := upload(ctx, "prompt_audio", promptFiles[0])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
body["prompt_audio"] = value
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// collectVoiceCloneMultipartFiles gathers the file headers stored under each
// of keys, in the order the keys are given. Unknown keys contribute nothing
// and the result is never nil.
func collectVoiceCloneMultipartFiles(files map[string][]*multipart.FileHeader, keys ...string) []*multipart.FileHeader {
	collected := []*multipart.FileHeader{}
	for i := 0; i < len(keys); i++ {
		collected = append(collected, files[keys[i]]...)
	}
	return collected
}
|
|
|
|
func (s *Server) uploadVoiceCloneMultipartAsset(ctx context.Context, field string, header *multipart.FileHeader) (map[string]any, error) {
|
|
file, err := header.Open()
|
|
if err != nil {
|
|
return nil, &clients.ClientError{Code: "invalid_multipart_file", Message: err.Error(), Retryable: false}
|
|
}
|
|
defer file.Close()
|
|
payload, err := io.ReadAll(file)
|
|
if err != nil {
|
|
return nil, &clients.ClientError{Code: "invalid_multipart_file", Message: err.Error(), Retryable: false}
|
|
}
|
|
contentType := strings.TrimSpace(header.Header.Get("Content-Type"))
|
|
detectedContentType := ""
|
|
if len(payload) > 0 {
|
|
detectedContentType = http.DetectContentType(payload)
|
|
}
|
|
if !voiceCloneMultipartAudioAllowed(contentType, detectedContentType, header.Filename) {
|
|
return nil, &clients.ClientError{Code: "invalid_multipart_audio", Message: "voice clone multipart files must be mp3, m4a, or wav audio", Retryable: false}
|
|
}
|
|
contentType = requestAssetContentType(contentType, payload, field, []string{field}, nil)
|
|
if !voiceCloneMultipartAudioAllowed(contentType, detectedContentType, header.Filename) {
|
|
contentType = voiceCloneContentTypeFromExtension(header.Filename)
|
|
}
|
|
ref, err := s.ensureRequestAsset(ctx, decodedRequestAsset{
|
|
Bytes: payload,
|
|
ContentType: contentType,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return requestAssetWrapper(ref), nil
|
|
}
|
|
|
|
func voiceCloneMultipartAudioAllowed(contentType string, detectedContentType string, filename string) bool {
|
|
for _, value := range []string{contentType, detectedContentType} {
|
|
normalized := strings.ToLower(strings.TrimSpace(value))
|
|
if strings.HasPrefix(normalized, "audio/") {
|
|
return true
|
|
}
|
|
}
|
|
return voiceCloneContentTypeFromExtension(filename) != ""
|
|
}
|
|
|
|
// voiceCloneContentTypeFromExtension returns the audio content type implied by
// the filename's extension (case-insensitive): ".mp3" gives "audio/mpeg",
// ".m4a" gives "audio/mp4" and ".wav" gives "audio/wav". Any other extension
// yields "".
func voiceCloneContentTypeFromExtension(filename string) string {
	ext := strings.ToLower(filepath.Ext(strings.TrimSpace(filename)))
	if ext == ".mp3" {
		return "audio/mpeg"
	}
	if ext == ".m4a" {
		return "audio/mp4"
	}
	if ext == ".wav" {
		return "audio/wav"
	}
	return ""
}
|
|
|
|
func appendImageEditMultipartList(body map[string]any, key string, values ...any) {
|
|
list := flattenImageEditMultipartValues([]any{body[key]})
|
|
list = append(list, flattenImageEditMultipartValues(values)...)
|
|
if len(list) == 0 {
|
|
return
|
|
}
|
|
body[key] = list
|
|
}
|
|
|
|
// flattenImageEditMultipartValues expands nested lists into one flat slice.
//
// Nil entries are dropped, []any entries are flattened recursively, and
// []string entries contribute their trimmed, non-blank items. Every other
// value is kept unchanged. The result is never nil.
func flattenImageEditMultipartValues(values []any) []any {
	flat := make([]any, 0, len(values))
	for i := range values {
		if values[i] == nil {
			continue
		}
		if nested, ok := values[i].([]any); ok {
			flat = append(flat, flattenImageEditMultipartValues(nested)...)
			continue
		}
		if texts, ok := values[i].([]string); ok {
			for _, raw := range texts {
				trimmed := strings.TrimSpace(raw)
				if trimmed != "" {
					flat = append(flat, trimmed)
				}
			}
			continue
		}
		flat = append(flat, values[i])
	}
	return flat
}
|