easyai-ai-gateway/apps/api/internal/clients/media_clients.go

233 lines
8.6 KiB
Go

package clients
import (
"context"
"net/http"
"strings"
)
// Media-generation provider clients. Each type only carries the *http.Client
// that its Run method hands, together with a provider-specific spec, to
// providerTaskClient.
type (
	// JimengClient runs tasks for the "jimeng" provider spec.
	JimengClient struct{ HTTPClient *http.Client }
	// BlackforestClient runs tasks for the "blackforest" provider spec.
	BlackforestClient struct{ HTTPClient *http.Client }
	// HunyuanImageClient runs tasks for the "tencent-hunyuan-image" provider spec.
	HunyuanImageClient struct{ HTTPClient *http.Client }
	// HunyuanVideoClient runs tasks for the "tencent-hunyuan-video" provider spec.
	HunyuanVideoClient struct{ HTTPClient *http.Client }
	// MinimaxClient runs tasks for the "minimax" provider spec.
	MinimaxClient struct{ HTTPClient *http.Client }
	// MidjourneyClient runs tasks for the "midjourney" provider spec.
	MidjourneyClient struct{ HTTPClient *http.Client }
	// ViduClient runs tasks for the "vidu" provider spec.
	ViduClient struct{ HTTPClient *http.Client }
	// AliyunBailianClient runs tasks for the "aliyun-bailian" provider spec.
	AliyunBailianClient struct{ HTTPClient *http.Client }
	// NewAPIClient runs tasks for the "newapi" provider spec.
	NewAPIClient struct{ HTTPClient *http.Client }
)
// Run delegates request to a providerTaskClient built from c.HTTPClient and
// jimengSpec().
func (c JimengClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       jimengSpec(),
	}
	return runner.Run(ctx, request)
}
// Run delegates request to a providerTaskClient built from c.HTTPClient and
// blackforestSpec().
func (c BlackforestClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       blackforestSpec(),
	}
	return runner.Run(ctx, request)
}
// Run delegates request to a providerTaskClient built from c.HTTPClient and
// hunyuanImageSpec().
func (c HunyuanImageClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       hunyuanImageSpec(),
	}
	return runner.Run(ctx, request)
}
// Run delegates request to a providerTaskClient built from c.HTTPClient and
// hunyuanVideoSpec().
func (c HunyuanVideoClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       hunyuanVideoSpec(),
	}
	return runner.Run(ctx, request)
}
// Run delegates request to a providerTaskClient built from c.HTTPClient and
// minimaxSpec().
func (c MinimaxClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       minimaxSpec(),
	}
	return runner.Run(ctx, request)
}
// Run delegates request to a providerTaskClient built from c.HTTPClient and
// midjourneySpec().
func (c MidjourneyClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       midjourneySpec(),
	}
	return runner.Run(ctx, request)
}
// Run delegates request to a providerTaskClient built from c.HTTPClient and
// viduSpec().
func (c ViduClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       viduSpec(),
	}
	return runner.Run(ctx, request)
}
// Run delegates request to a providerTaskClient built from c.HTTPClient and
// aliyunBailianSpec().
func (c AliyunBailianClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       aliyunBailianSpec(),
	}
	return runner.Run(ctx, request)
}
// Run delegates request to a providerTaskClient built from c.HTTPClient and
// newAPISpec().
func (c NewAPIClient) Run(ctx context.Context, request Request) (Response, error) {
	runner := providerTaskClient{
		HTTPClient: c.HTTPClient,
		Spec:       newAPISpec(),
	}
	return runner.Run(ctx, request)
}
// jimengSpec describes the Jimeng task API. Submit and poll default to the
// CVSubmitTask / CVSync2AsyncGetResult actions (Version 2022-08-31). Either
// path can be overridden through PlatformConfig ("submitPath"/"submit_path",
// "pollPath"/"poll_path").
func jimengSpec() providerTaskSpec {
	return providerTaskSpec{
		Name: "jimeng",
		SubmitPath: func(request Request, _ map[string]any) string {
			return configuredPath(request, "?Action=CVSubmitTask&Version=2022-08-31", "submitPath", "submit_path")
		},
		// The upstream task ID is not part of the default poll path.
		// NOTE(review): presumably providerTaskClient sends it in the poll
		// request body — confirm.
		PollPath: func(request Request, _ string, _ map[string]any) string {
			return configuredPath(request, "?Action=CVSync2AsyncGetResult&Version=2022-08-31", "pollPath", "poll_path")
		},
		Auth:            "bearer",
		TaskIDPaths:     []string{"data.task_id"},
		StatusPaths:     []string{"data.status"},
		SuccessStatuses: []string{"done"},
		// Besides in_queue/generating/done, the Jimeng result API reports
		// not_found and expired. Both are terminal, so they are treated as
		// failures instead of being polled indefinitely.
		// NOTE(review): the API may also report "done" with a non-zero outer
		// code — confirm that providerTaskClient checks it.
		FailureStatuses: []string{"not_found", "expired"},
		DefaultSubmitBody: func(request Request, body map[string]any) map[string]any {
			// Writing into a nil map panics; start from an empty body instead.
			if body == nil {
				body = map[string]any{}
			}
			// req_key always follows the routed candidate's upstream model name,
			// overriding any client-supplied value.
			body["req_key"] = upstreamModelName(request.Candidate)
			if body["prompt"] == nil {
				body["prompt"] = mediaPromptText(body)
			}
			return body
		},
	}
}
// blackforestSpec describes the Black Forest Labs task API. Requests are
// submitted to "/<upstream model name>", overridable via PlatformConfig
// "submitPath"/"submit_path". The submit response's polling_url is stored as
// the task ID and is used verbatim as the poll path.
func blackforestSpec() providerTaskSpec {
	return providerTaskSpec{
		Name: "blackforest",
		SubmitPath: func(request Request, _ map[string]any) string {
			return configuredPath(request, "/"+upstreamModelName(request.Candidate), "submitPath", "submit_path")
		},
		PollPath:        func(_ Request, upstreamTaskID string, _ map[string]any) string { return upstreamTaskID },
		Auth:            "x-key",
		TaskIDPaths:     []string{"polling_url"},
		StatusPaths:     []string{"status"},
		SuccessStatuses: []string{"ready"},
		// BFL's result statuses also include "Request Moderated" and "Content
		// Moderated". Both are terminal, so polling must stop on them. The
		// entries are lowercase like the existing ones.
		// NOTE(review): this assumes providerTaskClient compares statuses
		// case-insensitively — confirm.
		FailureStatuses: []string{"error", "task not found", "request moderated", "content moderated"},
	}
}
// hunyuanImageSpec describes Tencent Hunyuan image generation
// (SubmitHunyuanImageJob / QueryHunyuanImageJob, Version 2023-09-01). Both
// paths can be overridden through PlatformConfig ("submitPath"/"submit_path",
// "pollPath"/"poll_path"). The default poll path contains a ${taskId}
// placeholder.
// NOTE(review): presumably providerTaskClient substitutes the upstream job ID
// into that placeholder — confirm.
func hunyuanImageSpec() providerTaskSpec {
	return providerTaskSpec{
		Name: "tencent-hunyuan-image",
		SubmitPath: func(request Request, _ map[string]any) string {
			return configuredPath(request, "?Action=SubmitHunyuanImageJob&Version=2023-09-01", "submitPath", "submit_path")
		},
		PollPath: func(request Request, _ string, _ map[string]any) string {
			return configuredPath(request, "?Action=QueryHunyuanImageJob&Version=2023-09-01&JobId=${taskId}", "pollPath", "poll_path")
		},
		Auth:        "bearer",
		TaskIDPaths: []string{"Response.JobId"},
		// QueryHunyuanImageJob reports its state in Response.JobStatusCode
		// ("1" waiting, "2" running, "4" failed, "5" done). That path is added
		// after the original Response.Status path, so existing behaviour is kept.
		// NOTE(review): this assumes StatusPaths are tried in order until one
		// resolves — confirm.
		StatusPaths:     []string{"Response.Status", "Response.JobStatusCode"},
		SuccessStatuses: []string{"done", "5"},
		FailureStatuses: []string{"fail", "4"},
		DefaultSubmitBody: func(request Request, body map[string]any) map[string]any {
			// Writing into a nil map panics; start from an empty body instead.
			if body == nil {
				body = map[string]any{}
			}
			// Keep a caller-supplied Prompt and derive one only when it is
			// missing. This matches how the jimeng spec treats "prompt".
			if body["Prompt"] == nil {
				body["Prompt"] = mediaPromptText(body)
			}
			// Model always follows the routed candidate.
			body["Model"] = upstreamModelName(request.Candidate)
			return body
		},
	}
}
// hunyuanVideoSpec describes Tencent Hunyuan video generation
// (SubmitTextToVideoJob / QueryVideoJob, Version 2024-01-01). Both paths can
// be overridden through PlatformConfig ("submitPath"/"submit_path",
// "pollPath"/"poll_path"). The default poll path contains a ${taskId}
// placeholder.
// NOTE(review): presumably providerTaskClient substitutes the upstream job ID
// into that placeholder — confirm.
func hunyuanVideoSpec() providerTaskSpec {
	return providerTaskSpec{
		Name: "tencent-hunyuan-video",
		SubmitPath: func(request Request, _ map[string]any) string {
			return configuredPath(request, "?Action=SubmitTextToVideoJob&Version=2024-01-01", "submitPath", "submit_path")
		},
		PollPath: func(request Request, _ string, _ map[string]any) string {
			return configuredPath(request, "?Action=QueryVideoJob&Version=2024-01-01&JobId=${taskId}", "pollPath", "poll_path")
		},
		Auth:            "bearer",
		TaskIDPaths:     []string{"Response.JobId"},
		StatusPaths:     []string{"Response.Status"},
		SuccessStatuses: []string{"done"},
		FailureStatuses: []string{"fail"},
		DefaultSubmitBody: func(request Request, body map[string]any) map[string]any {
			// Writing into a nil map panics; start from an empty body instead.
			if body == nil {
				body = map[string]any{}
			}
			// Keep a caller-supplied Prompt and derive one only when it is
			// missing. This matches how the jimeng spec treats "prompt".
			if body["Prompt"] == nil {
				body["Prompt"] = mediaPromptText(body)
			}
			// Model always follows the routed candidate.
			body["Model"] = upstreamModelName(request.Candidate)
			return body
		},
	}
}
// minimaxSpec describes the MiniMax video-generation task API. Tasks are
// submitted to /video_generation and polled at
// /query/video_generation?task_id=<id>.
func minimaxSpec() providerTaskSpec {
	return providerTaskSpec{
		Name:       "minimax",
		SubmitPath: func(Request, map[string]any) string { return "/video_generation" },
		PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string {
			// NOTE(review): the ID is appended without query escaping. This is
			// presumably safe for MiniMax task IDs — confirm.
			return "/query/video_generation?task_id=" + upstreamTaskID
		},
		Auth:            "bearer",
		TaskIDPaths:     []string{"task_id"},
		StatusPaths:     []string{"status"},
		SuccessStatuses: []string{"success"},
		// MiniMax reports a failed task with status "Fail". Without "fail" in
		// this list, such tasks are never recognised as finished.
		// "failed"/"expired" are kept from the previous list. The entries are
		// lowercase like "success" above.
		// NOTE(review): this assumes case-insensitive matching in
		// providerTaskClient — confirm.
		FailureStatuses: []string{"fail", "failed", "expired"},
	}
}
// midjourneySpec describes the Midjourney-style task API. Tasks are submitted
// to /diffusion (overridable via PlatformConfig "submitPath"/"submit_path")
// and polled at /job/<id>. The task ID is read from "job_id" or "id".
func midjourneySpec() providerTaskSpec {
	return providerTaskSpec{
		Name: "midjourney",
		SubmitPath: func(request Request, _ map[string]any) string {
			return configuredPath(request, "/diffusion", "submitPath", "submit_path")
		},
		PollPath:        func(_ Request, upstreamTaskID string, _ map[string]any) string { return "/job/" + upstreamTaskID },
		Auth:            "bearer",
		TaskIDPaths:     []string{"job_id", "id"},
		StatusPaths:     []string{"status"},
		SuccessStatuses: []string{"success", "completed"},
		FailureStatuses: []string{"failed"},
		DefaultSubmitBody: func(_ Request, body map[string]any) map[string]any {
			// Writing into a nil map panics; start from an empty body instead.
			if body == nil {
				body = map[string]any{}
			}
			// Derive a prompt only when the caller supplied neither "prompt"
			// nor "text".
			if body["prompt"] == nil && body["text"] == nil {
				body["prompt"] = mediaPromptText(body)
			}
			return body
		},
	}
}
// viduSpec describes the Vidu task API. Unless PlatformConfig sets
// "submitPath"/"submit_path", the submit endpoint is "/" followed by the task
// type. The task type is taken from body["type"], then body["task_type"], and
// defaults to "text2video". Tasks are polled at /tasks/<id>/creations.
func viduSpec() providerTaskSpec {
	return providerTaskSpec{
		Name: "vidu",
		SubmitPath: func(request Request, body map[string]any) string {
			if path := configuredPath(request, "", "submitPath", "submit_path"); path != "" {
				return path
			}
			// Every task type, "multiframe" included, maps to "/"+type, so no
			// per-type branch is needed. Surrounding whitespace and slashes are
			// dropped so inputs like " img2video" or "/img2video" still produce
			// a single path segment.
			// NOTE(review): the type comes from the request body and is still
			// spliced into the URL unescaped. Consider validating it against the
			// supported Vidu task types.
			taskType := strings.Trim(strings.TrimSpace(firstNonEmptyString(body["type"], body["task_type"], "text2video")), "/")
			if taskType == "" {
				taskType = "text2video"
			}
			return "/" + taskType
		},
		PollPath: func(_ Request, upstreamTaskID string, _ map[string]any) string {
			return "/tasks/" + upstreamTaskID + "/creations"
		},
		Auth:            "token",
		TaskIDPaths:     []string{"task_id"},
		StatusPaths:     []string{"state", "status"},
		SuccessStatuses: []string{"success", "succeeded"},
		FailureStatuses: []string{"failed"},
	}
}
// aliyunBailianSpec describes the Aliyun Bailian (DashScope) video-synthesis
// task API. Tasks are submitted to
// /services/aigc/video-generation/video-synthesis and polled at /tasks/<id>.
func aliyunBailianSpec() providerTaskSpec {
	return providerTaskSpec{
		Name:            "aliyun-bailian",
		SubmitPath:      func(Request, map[string]any) string { return "/services/aigc/video-generation/video-synthesis" },
		PollPath:        func(_ Request, upstreamTaskID string, _ map[string]any) string { return "/tasks/" + upstreamTaskID },
		Auth:            "bearer",
		TaskIDPaths:     []string{"output.task_id"},
		StatusPaths:     []string{"output.task_status"},
		SuccessStatuses: []string{"succeeded", "success"},
		// DashScope's task_status can be PENDING, RUNNING, SUCCEEDED, FAILED,
		// CANCELED or UNKNOWN (task missing or expired). CANCELED and UNKNOWN
		// never progress, so they end polling as failures. The entries are
		// lowercase like the others.
		// NOTE(review): this assumes case-insensitive matching in
		// providerTaskClient — confirm.
		FailureStatuses: []string{"failed", "canceled", "unknown"},
	}
}
// newAPISpec describes a NewAPI-style video task endpoint. Tasks are submitted
// to /videos/generations and each task is polled at /videos/generations/<id>.
func newAPISpec() providerTaskSpec {
	spec := providerTaskSpec{
		Name:            "newapi",
		Auth:            "bearer",
		TaskIDPaths:     []string{"task_id"},
		StatusPaths:     []string{"status"},
		SuccessStatuses: []string{"success"},
		FailureStatuses: []string{"failure", "failed"},
	}
	spec.SubmitPath = func(Request, map[string]any) string {
		return "/videos/generations"
	}
	spec.PollPath = func(_ Request, taskID string, _ map[string]any) string {
		return "/videos/generations/" + taskID
	}
	return spec
}
// configuredPath checks keys in order against request.Candidate.PlatformConfig.
// It returns the first value that is non-blank after conversion with
// stringFromAny and whitespace trimming; that trimmed value is what gets
// returned. When no key yields such a value, fallback is returned as-is.
func configuredPath(request Request, fallback string, keys ...string) string {
	for i := range keys {
		raw := stringFromAny(request.Candidate.PlatformConfig[keys[i]])
		trimmed := strings.TrimSpace(raw)
		if trimmed == "" {
			continue
		}
		return trimmed
	}
	return fallback
}