// Source: easyai-ai-gateway/apps/api/internal/clients/media_clients.go (Go, 491 lines, 18 KiB)

package clients
import (
"context"
"encoding/base64"
"encoding/hex"
"net/http"
"net/url"
"strings"
"time"
)
// Provider media clients. Each type only carries the HTTP client that its Run
// method hands to the shared provider task client.
type (
	// JimengClient runs requests with the spec returned by jimengSpec.
	JimengClient struct {
		HTTPClient *http.Client
	}

	// BlackforestClient runs requests with the spec returned by blackforestSpec.
	BlackforestClient struct {
		HTTPClient *http.Client
	}

	// HunyuanImageClient runs requests with the spec returned by hunyuanImageSpec.
	HunyuanImageClient struct {
		HTTPClient *http.Client
	}

	// HunyuanVideoClient runs requests with the spec returned by hunyuanVideoSpec.
	HunyuanVideoClient struct {
		HTTPClient *http.Client
	}

	// MinimaxClient runs video requests with the spec returned by minimaxSpec
	// and handles "speech.generations" requests through runSpeech.
	MinimaxClient struct {
		HTTPClient *http.Client
	}

	// MidjourneyClient runs requests with the spec returned by midjourneySpec.
	MidjourneyClient struct {
		HTTPClient *http.Client
	}

	// ViduClient runs requests with the spec returned by viduSpec.
	ViduClient struct {
		HTTPClient *http.Client
	}

	// AliyunBailianClient runs requests with the spec returned by aliyunBailianSpec.
	AliyunBailianClient struct {
		HTTPClient *http.Client
	}

	// NewAPIClient runs requests with the spec returned by newAPISpec.
	NewAPIClient struct {
		HTTPClient *http.Client
	}

	// SunoClient runs requests with the spec returned by sunoSpec.
	SunoClient struct {
		HTTPClient *http.Client
	}
)
// Run executes request through the shared provider task client, configured
// with this client's HTTP client and the Jimeng task spec.
func (c JimengClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       jimengSpec(),
	}
	return runner.Run(ctx, request)
}
// Run executes request through the shared provider task client, configured
// with this client's HTTP client and the Blackforest task spec.
func (c BlackforestClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       blackforestSpec(),
	}
	return runner.Run(ctx, request)
}
// Run executes request through the shared provider task client, configured
// with this client's HTTP client and the Hunyuan image task spec.
func (c HunyuanImageClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       hunyuanImageSpec(),
	}
	return runner.Run(ctx, request)
}
// Run executes request through the shared provider task client, configured
// with this client's HTTP client and the Hunyuan video task spec.
func (c HunyuanVideoClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       hunyuanVideoSpec(),
	}
	return runner.Run(ctx, request)
}
// Run dispatches request by kind: "speech.generations" requests go to
// runSpeech, and every other kind is executed through the shared provider task
// client with the MiniMax video task spec.
func (c MinimaxClient) Run(ctx context.Context, request Request) (Response, error) {
	switch request.Kind {
	case "speech.generations":
		return c.runSpeech(ctx, request)
	default:
		runner := providerTaskClient{
			HTTPClient: c.HTTPClient,
			Spec:       minimaxSpec(),
		}
		return runner.Run(ctx, request)
	}
}
// Run executes request through the shared provider task client, configured
// with this client's HTTP client and the Midjourney task spec.
func (c MidjourneyClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       midjourneySpec(),
	}
	return runner.Run(ctx, request)
}
// Run executes request through the shared provider task client, configured
// with this client's HTTP client and the Vidu task spec.
func (c ViduClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       viduSpec(),
	}
	return runner.Run(ctx, request)
}
// Run executes request through the shared provider task client, configured
// with this client's HTTP client and the Aliyun Bailian task spec.
func (c AliyunBailianClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       aliyunBailianSpec(),
	}
	return runner.Run(ctx, request)
}
// Run executes request through the shared provider task client, configured
// with this client's HTTP client and the NewAPI task spec.
func (c NewAPIClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       newAPISpec(),
	}
	return runner.Run(ctx, request)
}
// Run executes request through the shared provider task client, configured
// with this client's HTTP client and the Suno task spec.
func (c SunoClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       sunoSpec(),
	}
	return runner.Run(ctx, request)
}
// jimengSpec builds the provider task spec used by JimengClient.
//
// The submit and poll paths default to the CVSubmitTask and
// CVSync2AsyncGetResult query strings; either can be overridden through the
// candidate's platform config ("submitPath"/"submit_path" and
// "pollPath"/"poll_path").
//
// NOTE(review): unlike most other specs in this file, no FailureStatuses are
// declared here — confirm that is intentional.
func jimengSpec() providerTaskSpec {
	const (
		defaultSubmitPath = "?Action=CVSubmitTask&Version=2022-08-31"
		defaultPollPath   = "?Action=CVSync2AsyncGetResult&Version=2022-08-31"
	)

	spec := providerTaskSpec{
		Name:            "jimeng",
		Auth:            "bearer",
		TaskIDPaths:     []string{"data.task_id"},
		StatusPaths:     []string{"data.status"},
		SuccessStatuses: []string{"done"},
	}
	spec.SubmitPath = func(request Request, _ map[string]any) string {
		return configuredPath(request, defaultSubmitPath, "submitPath", "submit_path")
	}
	spec.PollPath = func(request Request, _ string, _ map[string]any) string {
		return configuredPath(request, defaultPollPath, "pollPath", "poll_path")
	}
	// The upstream model name is always sent as req_key; a prompt is only
	// derived from the body when the caller did not provide one.
	spec.DefaultSubmitBody = func(request Request, body map[string]any) map[string]any {
		body["req_key"] = upstreamModelName(request.Candidate)
		if existing := body["prompt"]; existing == nil {
			body["prompt"] = mediaPromptText(body)
		}
		return body
	}
	return spec
}
// blackforestSpec builds the provider task spec used by BlackforestClient.
//
// The submit path defaults to "/<upstream model name>" and can be overridden
// through the candidate's platform config. The task ID is read from the
// "polling_url" field of the submit response and is returned unchanged as the
// poll path.
func blackforestSpec() providerTaskSpec {
	submitPath := func(request Request, _ map[string]any) string {
		modelPath := "/" + upstreamModelName(request.Candidate)
		return configuredPath(request, modelPath, "submitPath", "submit_path")
	}
	pollPath := func(_ Request, upstreamTaskID string, _ map[string]any) string {
		return upstreamTaskID
	}

	return providerTaskSpec{
		Name:            "blackforest",
		SubmitPath:      submitPath,
		PollPath:        pollPath,
		Auth:            "x-key",
		TaskIDPaths:     []string{"polling_url"},
		StatusPaths:     []string{"status"},
		SuccessStatuses: []string{"ready"},
		FailureStatuses: []string{"error", "task not found"},
	}
}
// hunyuanImageSpec builds the provider task spec used by HunyuanImageClient.
//
// Both paths can be overridden through the candidate's platform config. The
// default poll path carries a literal "${taskId}" placeholder and ignores the
// task ID argument.
// NOTE(review): presumably the placeholder is substituted by the task client —
// confirm.
func hunyuanImageSpec() providerTaskSpec {
	const (
		defaultSubmitPath = "?Action=SubmitHunyuanImageJob&Version=2023-09-01"
		defaultPollPath   = "?Action=QueryHunyuanImageJob&Version=2023-09-01&JobId=${taskId}"
	)

	spec := providerTaskSpec{
		Name:            "tencent-hunyuan-image",
		Auth:            "bearer",
		TaskIDPaths:     []string{"Response.JobId"},
		StatusPaths:     []string{"Response.Status"},
		SuccessStatuses: []string{"done"},
		FailureStatuses: []string{"fail"},
	}
	spec.SubmitPath = func(request Request, _ map[string]any) string {
		return configuredPath(request, defaultSubmitPath, "submitPath", "submit_path")
	}
	spec.PollPath = func(request Request, _ string, _ map[string]any) string {
		return configuredPath(request, defaultPollPath, "pollPath", "poll_path")
	}
	// Prompt and Model are always overwritten with the derived values.
	spec.DefaultSubmitBody = func(request Request, body map[string]any) map[string]any {
		prompt := mediaPromptText(body)
		body["Prompt"] = prompt
		body["Model"] = upstreamModelName(request.Candidate)
		return body
	}
	return spec
}
// hunyuanVideoSpec builds the provider task spec used by HunyuanVideoClient.
//
// Both paths can be overridden through the candidate's platform config. The
// default poll path carries a literal "${taskId}" placeholder and ignores the
// task ID argument.
// NOTE(review): presumably the placeholder is substituted by the task client —
// confirm.
func hunyuanVideoSpec() providerTaskSpec {
	const (
		defaultSubmitPath = "?Action=SubmitTextToVideoJob&Version=2024-01-01"
		defaultPollPath   = "?Action=QueryVideoJob&Version=2024-01-01&JobId=${taskId}"
	)

	spec := providerTaskSpec{
		Name:            "tencent-hunyuan-video",
		Auth:            "bearer",
		TaskIDPaths:     []string{"Response.JobId"},
		StatusPaths:     []string{"Response.Status"},
		SuccessStatuses: []string{"done"},
		FailureStatuses: []string{"fail"},
	}
	spec.SubmitPath = func(request Request, _ map[string]any) string {
		return configuredPath(request, defaultSubmitPath, "submitPath", "submit_path")
	}
	spec.PollPath = func(request Request, _ string, _ map[string]any) string {
		return configuredPath(request, defaultPollPath, "pollPath", "poll_path")
	}
	// Prompt and Model are always overwritten with the derived values.
	spec.DefaultSubmitBody = func(request Request, body map[string]any) map[string]any {
		prompt := mediaPromptText(body)
		body["Prompt"] = prompt
		body["Model"] = upstreamModelName(request.Candidate)
		return body
	}
	return spec
}
// minimaxSpec builds the provider task spec used by MinimaxClient for video
// generation.
//
// Tasks are submitted to "/video_generation" and polled at
// "/query/video_generation?task_id=<id>". The submit body is normalized by
// minimaxVideoPayload, and a successful result is post-processed by
// minimaxResolveVideoFile.
func minimaxSpec() providerTaskSpec {
	return providerTaskSpec{
		Name:       "minimax",
		SubmitPath: func(Request, map[string]any) string { return "/video_generation" },
		PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string {
			// Escape the task ID so reserved characters cannot alter the
			// query string, matching how file_id is escaped when the
			// finished video file is retrieved.
			return "/query/video_generation?task_id=" + url.QueryEscape(upstreamTaskID)
		},
		Auth:              "bearer",
		TaskIDPaths:       []string{"task_id"},
		StatusPaths:       []string{"status"},
		SuccessStatuses:   []string{"success"},
		FailureStatuses:   []string{"failed", "expired"},
		DefaultSubmitBody: minimaxVideoPayload,
		ResolveSuccess:    minimaxResolveVideoFile,
	}
}
// minimaxVideoPayload normalizes body in place into the payload submitted for
// MiniMax video generation and returns it.
//
// It sets "model" to the upstream model name, fills "prompt" from the body when
// it is blank, fills "first_frame_image" / "last_frame_image" from the frame
// aliases resolved by minimaxFrameImages (without overriding non-blank
// values), normalizes "resolution", and finally removes the alias and helper
// keys so they are not forwarded upstream.
func minimaxVideoPayload(request Request, body map[string]any) map[string]any {
	model := upstreamModelName(request.Candidate)
	body["model"] = model
	if strings.TrimSpace(stringFromAny(body["prompt"])) == "" {
		body["prompt"] = mediaPromptText(body)
	}

	firstFrame, lastFrame := minimaxFrameImages(body)
	if firstFrame != "" && strings.TrimSpace(stringFromAny(body["first_frame_image"])) == "" {
		body["first_frame_image"] = firstFrame
	}
	if lastFrame != "" && strings.TrimSpace(stringFromAny(body["last_frame_image"])) == "" {
		body["last_frame_image"] = lastFrame
	}

	if resolution := minimaxVideoResolution(model, stringFromAny(body["resolution"])); resolution != "" {
		body["resolution"] = resolution
	}

	// Drop every alias that has been folded into the canonical fields above.
	// "firstFrameImage" and "lastFrameImage" are read by minimaxFrameImages
	// like the other aliases, so they are removed here as well.
	for _, key := range []string{
		"content", "input", "_paramWarnings",
		"first_frame", "firstFrame", "firstFrameImage",
		"last_frame", "lastFrame", "lastFrameImage",
		"image", "images", "image_url", "imageUrl", "image_urls", "imageUrls",
		"reference_image", "referenceImage", "duration_seconds", "durationSeconds",
	} {
		delete(body, key)
	}
	return body
}
// minimaxFrameImages resolves the first- and last-frame image URLs from body.
//
// Sources are consulted in this order, and an earlier source always wins:
//  1. the explicit frame keys (snake_case and camelCase variants);
//  2. for the first frame only, the first entry of the generic image keys;
//  3. "image_url" items in body["content"]: an item with role "last_frame"
//     fills the last frame, and any other role fills the first frame.
func minimaxFrameImages(body map[string]any) (string, string) {
	firstFrame := firstNonEmptyStringValue(body, "first_frame_image", "firstFrameImage", "first_frame", "firstFrame")
	lastFrame := firstNonEmptyStringValue(body, "last_frame_image", "lastFrameImage", "last_frame", "lastFrame")

	images := firstNonEmptyStringListFromAny(body["image"], body["images"], body["image_url"], body["imageUrl"], body["image_urls"], body["imageUrls"])
	if firstFrame == "" && len(images) > 0 {
		firstFrame = images[0]
	}

	for _, part := range contentItems(body["content"]) {
		if kind := strings.TrimSpace(stringFromAny(part["type"])); kind != "image_url" {
			continue
		}
		imageURL := minimaxNestedURL(part, "image_url")
		if imageURL == "" {
			continue
		}
		role := strings.TrimSpace(stringFromAny(part["role"]))
		if role == "last_frame" {
			if lastFrame == "" {
				lastFrame = imageURL
			}
			continue
		}
		// "first_frame" and any other (or missing) role feed the first frame.
		if firstFrame == "" {
			firstFrame = imageURL
		}
	}
	return firstFrame, lastFrame
}
// minimaxNestedURL extracts a URL stored under item[key]. The value is first
// read directly as a string; when that is blank, it is read as a map and its
// "url" entry is returned instead. The result is whitespace-trimmed.
func minimaxNestedURL(item map[string]any, key string) string {
	raw := item[key]
	direct := strings.TrimSpace(stringFromAny(raw))
	if direct != "" {
		return direct
	}
	return strings.TrimSpace(stringFromAny(mapFromAny(raw)["url"]))
}
// minimaxVideoResolution normalizes a requested resolution into the upper-case
// "<n>P" form.
//
// Rules, applied after trimming, lower-casing and removing spaces:
//   - blank input yields "";
//   - for models whose name contains "hailuo-2.3", "720" / "720p" become "768P";
//   - any value ending in "p" has that suffix replaced by "P";
//   - the bare numbers 512, 720, 768 and 1080 get a "P" appended;
//   - anything else is returned as given (trimmed only).
func minimaxVideoResolution(model string, resolution string) string {
	trimmed := strings.TrimSpace(resolution)
	if trimmed == "" {
		return ""
	}

	key := strings.ToLower(strings.ReplaceAll(trimmed, " ", ""))
	bare := strings.TrimSuffix(key, "p")
	hadSuffix := bare != key

	if bare == "720" && strings.Contains(strings.ToLower(model), "hailuo-2.3") {
		return "768P"
	}
	if hadSuffix {
		return bare + "P"
	}
	switch bare {
	case "512", "720", "768", "1080":
		return bare + "P"
	}
	return trimmed
}
// minimaxResolveVideoFile turns the "file_id" of a finished MiniMax video task
// into a downloadable URL.
//
// The result is returned untouched when it already yields provider task data
// or carries no "file_id". Otherwise the file is looked up via
// "/files/retrieve", and a copy of result is returned with "video_url", "file"
// (when present) and "file_retrieve" added. The second return value is the
// request ID of the retrieve call, if any.
//
// An error is returned when the retrieve call fails, when its response is a
// provider task failure, or when no download URL can be found in it.
func minimaxResolveVideoFile(ctx context.Context, client *http.Client, request Request, result map[string]any) (map[string]any, string, error) {
	if existing := providerTaskData(request, result); len(existing) > 0 {
		return result, "", nil
	}
	fileID := strings.TrimSpace(stringFromPathValue(valueAtPath(result, "file_id")))
	if fileID == "" {
		return result, "", nil
	}

	retrieveURL := providerURL(request.Candidate.BaseURL, "/files/retrieve?file_id="+url.QueryEscape(fileID))
	fetched, requestID, err := providerGetJSON(ctx, client, retrieveURL, request.Candidate.Credentials, "bearer")
	if err != nil {
		return nil, requestID, err
	}

	if spec := minimaxSpec(); isProviderTaskFailure(spec, fetched) {
		failureID := firstNonEmptyString(requestID, requestIDFromResult(fetched))
		return nil, requestID, providerTaskFailure(minimaxSpec(), fetched, failureID, time.Now())
	}

	// Accept the download URL under any of the known response shapes.
	downloadURL := firstNonEmptyString(
		valueAtPath(fetched, "file.download_url"),
		valueAtPath(fetched, "file.downloadUrl"),
		valueAtPath(fetched, "download_url"),
		valueAtPath(fetched, "downloadUrl"),
		valueAtPath(fetched, "url"),
	)
	if downloadURL == "" {
		missing := &ClientError{
			Code:      "invalid_response",
			Message:   "minimax video download url is missing",
			RequestID: requestID,
			Retryable: false,
		}
		return nil, requestID, missing
	}

	resolved := cloneMapAny(result)
	resolved["video_url"] = downloadURL
	if fileInfo, ok := fetched["file"].(map[string]any); ok {
		resolved["file"] = cloneMapAny(fileInfo)
	}
	resolved["file_retrieve"] = cloneMapAny(fetched)
	return resolved, firstNonEmptyString(requestID, requestIDFromResult(fetched)), nil
}
// runSpeech performs a synchronous MiniMax speech request.
//
// It posts the payload built by minimaxSpeechPayload to "/t2a_v2", reads the
// hex-encoded audio at "data.audio", and returns it as a single "audio" data
// item whose content is a base64 data URI labelled audio/mpeg. The untouched
// upstream response is kept under "raw_data".
//
// A missing or undecodable audio field yields a non-retryable
// "invalid_response" ClientError carrying the response timing.
func (c MinimaxClient) runSpeech(ctx context.Context, request Request) (Response, error) {
	startedAt := time.Now()
	speechBody := minimaxSpeechPayload(request)
	result, requestID, err := providerPostJSON(ctx, httpClient(request.HTTPClient, c.HTTPClient), providerURL(request.Candidate.BaseURL, "/t2a_v2"), speechBody, request.Candidate.Credentials, "bearer")
	finishedAt := time.Now()
	if err != nil {
		return Response{}, annotateResponseError(err, requestID, startedAt, finishedAt)
	}

	// Prefer the transport-level request ID, falling back to the response body.
	resolvedID := firstNonEmptyString(requestID, requestIDFromResult(result))
	durationMS := responseDurationMS(startedAt, finishedAt)
	invalidResponse := func(message string) error {
		return &ClientError{
			Code:               "invalid_response",
			Message:            message,
			RequestID:          resolvedID,
			ResponseStartedAt:  startedAt,
			ResponseFinishedAt: finishedAt,
			ResponseDurationMS: durationMS,
			Retryable:          false,
		}
	}

	audioHex := strings.TrimSpace(stringFromPathValue(valueAtPath(result, "data.audio")))
	if audioHex == "" {
		reason := firstNonEmptyString(valueAtPath(result, "base_resp.status_msg"), valueAtPath(result, "message"), "minimax speech audio is missing")
		return Response{}, invalidResponse(reason)
	}
	audioBytes, decodeErr := hex.DecodeString(audioHex)
	if decodeErr != nil {
		return Response{}, invalidResponse("minimax speech audio hex is invalid: " + decodeErr.Error())
	}

	audioItem := map[string]any{
		"type":      "audio",
		"content":   "data:audio/mpeg;base64," + base64.StdEncoding.EncodeToString(audioBytes),
		"mime_type": "audio/mpeg",
		"uploaded":  false,
	}
	normalized := cloneMapAny(result)
	normalized["status"] = "success"
	normalized["created"] = time.Now().UnixMilli()
	normalized["model"] = request.Model
	normalized["raw_data"] = cloneMapAny(result)
	normalized["data"] = []any{audioItem}

	return Response{
		Result:             normalized,
		RequestID:          resolvedID,
		Progress:           providerProgress(request),
		ResponseStartedAt:  startedAt,
		ResponseFinishedAt: finishedAt,
		ResponseDurationMS: durationMS,
	}, nil
}
// minimaxSpeechPayload builds the request body for the MiniMax speech endpoint
// from a copy of request.Body.
//
// "model" is set to the upstream model name, and the flat voice fields
// (voice_id/voiceId, speed, vol/volume, pitch, emotion) are moved into a nested
// "voice_setting" object. A "voice_setting" object already present in the body
// is used as the base instead of being discarded: flat fields take precedence
// over its entries, and its entries take precedence over the defaults
// (speed 1, vol 1, pitch 0). The flat fields are removed from the returned body.
func minimaxSpeechPayload(request Request) map[string]any {
	body := cloneBody(request.Body)
	body["model"] = upstreamModelName(request.Candidate)

	// Copy the caller's nested settings into a fresh map so that the original
	// request body is never mutated through a shared nested map.
	existing := mapFromAny(body["voice_setting"])
	voiceSetting := make(map[string]any, len(existing)+5)
	for key, value := range existing {
		voiceSetting[key] = value
	}

	voiceSetting["voice_id"] = firstNonEmptyString(body["voice_id"], body["voiceId"], existing["voice_id"])
	voiceSetting["speed"] = firstPresent(body["speed"], existing["speed"], float64(1))
	voiceSetting["vol"] = firstPresent(body["vol"], body["volume"], existing["vol"], float64(1))
	voiceSetting["pitch"] = firstPresent(body["pitch"], existing["pitch"], float64(0))
	if emotion := firstNonEmptyString(body["emotion"]); emotion != "" {
		voiceSetting["emotion"] = emotion
	}

	for _, key := range []string{"voice_id", "voiceId", "speed", "vol", "volume", "pitch", "emotion"} {
		delete(body, key)
	}
	body["voice_setting"] = voiceSetting
	return body
}
// sunoSpec builds the provider task spec used by SunoClient.
//
// Tasks are submitted to "/generator/suno" and polled at "/v2/sunoinfo?id=<id>".
// The submit body always gets task "create" and the mapped model name;
// "customMode" and "makeInstrumental" default to false when unset.
func sunoSpec() providerTaskSpec {
	return providerTaskSpec{
		Name:       "suno",
		SubmitPath: func(Request, map[string]any) string { return "/generator/suno" },
		PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string {
			// Escape the task ID so reserved characters cannot alter the
			// query string.
			return "/v2/sunoinfo?id=" + url.QueryEscape(upstreamTaskID)
		},
		Auth:            "bearer",
		TaskIDPaths:     []string{"data"},
		StatusPaths:     []string{"data.status"},
		SuccessStatuses: []string{"succeeded", "complete", "completed"},
		FailureStatuses: []string{"failed"},
		DefaultSubmitBody: func(request Request, body map[string]any) map[string]any {
			body["task"] = "create"
			body["model"] = sunoMappedModel(upstreamModelName(request.Candidate))
			if body["customMode"] == nil {
				body["customMode"] = false
			}
			if body["makeInstrumental"] == nil {
				body["makeInstrumental"] = false
			}
			return body
		},
	}
}
// sunoMappedModel translates a "chirp-*" model name into the short version tag
// sent upstream. The lookup ignores surrounding whitespace; a name that is not
// recognized is returned exactly as given.
//
// NOTE(review): chirp-v3-0 and chirp-v3-5 map to "v40", the same as
// chirp-v4-0 — confirm this is intentional.
func sunoMappedModel(model string) string {
	name := strings.TrimSpace(model)
	switch name {
	case "chirp-v3-0", "chirp-v3-5", "chirp-v4-0":
		return "v40"
	case "chirp-v4-5":
		return "v45"
	case "chirp-v4-5+":
		return "v45+"
	case "chirp-v5-0":
		return "v50"
	}
	return model
}
// midjourneySpec builds the provider task spec used by MidjourneyClient.
//
// The submit path defaults to "/diffusion" and can be overridden through the
// candidate's platform config; tasks are polled at "/job/<id>". A prompt is
// derived from the body only when neither "prompt" nor "text" is set.
func midjourneySpec() providerTaskSpec {
	submitPath := func(request Request, _ map[string]any) string {
		return configuredPath(request, "/diffusion", "submitPath", "submit_path")
	}
	pollPath := func(_ Request, upstreamTaskID string, _ map[string]any) string {
		return "/job/" + upstreamTaskID
	}
	submitBody := func(request Request, body map[string]any) map[string]any {
		if body["prompt"] != nil || body["text"] != nil {
			return body
		}
		body["prompt"] = mediaPromptText(body)
		return body
	}

	return providerTaskSpec{
		Name:              "midjourney",
		SubmitPath:        submitPath,
		PollPath:          pollPath,
		Auth:              "bearer",
		TaskIDPaths:       []string{"job_id", "id"},
		StatusPaths:       []string{"status"},
		SuccessStatuses:   []string{"success", "completed"},
		FailureStatuses:   []string{"failed"},
		DefaultSubmitBody: submitBody,
	}
}
// viduSpec builds the provider task spec used by ViduClient.
//
// A submit path configured on the candidate wins; otherwise the path is
// "/<task type>", where the task type comes from body "type" or "task_type" and
// defaults to "text2video". Tasks are polled at "/tasks/<id>/creations".
func viduSpec() providerTaskSpec {
	submitPath := func(request Request, body map[string]any) string {
		override := configuredPath(request, "", "submitPath", "submit_path")
		if override != "" {
			return override
		}
		// The "multiframe" task type needs no special case: prefixing the
		// type with "/" already yields "/multiframe".
		return "/" + firstNonEmptyString(body["type"], body["task_type"], "text2video")
	}
	pollPath := func(_ Request, upstreamTaskID string, _ map[string]any) string {
		return "/tasks/" + upstreamTaskID + "/creations"
	}

	return providerTaskSpec{
		Name:            "vidu",
		SubmitPath:      submitPath,
		PollPath:        pollPath,
		Auth:            "token",
		TaskIDPaths:     []string{"task_id"},
		StatusPaths:     []string{"state", "status"},
		SuccessStatuses: []string{"success", "succeeded"},
		FailureStatuses: []string{"failed"},
	}
}
// aliyunBailianSpec builds the provider task spec used by AliyunBailianClient.
// Tasks are submitted to the fixed video-synthesis path and polled at
// "/tasks/<id>".
func aliyunBailianSpec() providerTaskSpec {
	const submitPath = "/services/aigc/video-generation/video-synthesis"

	spec := providerTaskSpec{
		Name:            "aliyun-bailian",
		Auth:            "bearer",
		TaskIDPaths:     []string{"output.task_id"},
		StatusPaths:     []string{"output.task_status"},
		SuccessStatuses: []string{"succeeded", "success"},
		FailureStatuses: []string{"failed"},
	}
	spec.SubmitPath = func(Request, map[string]any) string {
		return submitPath
	}
	spec.PollPath = func(_ Request, upstreamTaskID string, _ map[string]any) string {
		return "/tasks/" + upstreamTaskID
	}
	return spec
}
// newAPISpec builds the provider task spec used by NewAPIClient. Tasks are
// submitted to "/videos/generations" and polled at "/videos/generations/<id>".
func newAPISpec() providerTaskSpec {
	const generationsPath = "/videos/generations"

	spec := providerTaskSpec{
		Name:            "newapi",
		Auth:            "bearer",
		TaskIDPaths:     []string{"task_id"},
		StatusPaths:     []string{"status"},
		SuccessStatuses: []string{"success"},
		FailureStatuses: []string{"failure", "failed"},
	}
	spec.SubmitPath = func(Request, map[string]any) string {
		return generationsPath
	}
	spec.PollPath = func(_ Request, upstreamTaskID string, _ map[string]any) string {
		return generationsPath + "/" + upstreamTaskID
	}
	return spec
}
// configuredPath looks up keys, in order, in the candidate's platform config
// and returns the first value that is non-blank after trimming. When none of
// the keys yields a value, fallback is returned.
func configuredPath(request Request, fallback string, keys ...string) string {
	config := request.Candidate.PlatformConfig
	for i := range keys {
		configured := strings.TrimSpace(stringFromAny(config[keys[i]]))
		if configured == "" {
			continue
		}
		return configured
	}
	return fallback
}