easyai-ai-gateway/apps/api/internal/clients/volces.go

669 lines
19 KiB
Go

package clients
import (
"bytes"
"context"
"encoding/json"
"fmt"
"math"
"net/http"
"strings"
"time"
)
// VolcesClient runs image and video generation requests against a volces
// HTTP API. Its methods use a value receiver.
type VolcesClient struct {
// HTTPClient is passed to the package helper httpClient for every outgoing
// request. NOTE(review): presumably a nil value selects a default client;
// confirm against httpClient.
HTTPClient *http.Client
}
// Run routes request to the image or the video pipeline based on request.Kind.
//
// The API key is looked up in request.Candidate.Credentials under the names
// "apiKey", "api_key", "key" and "token". When the lookup yields an empty
// string, or when the kind is not one of "images.generations", "images.edits"
// or "videos.generations", a non-retryable *ClientError is returned.
func (c VolcesClient) Run(ctx context.Context, request Request) (Response, error) {
	token := credential(request.Candidate.Credentials, "apiKey", "api_key", "key", "token")
	if token == "" {
		return Response{}, &ClientError{Code: "missing_credentials", Message: "volces api key is required", Retryable: false}
	}

	kind := request.Kind
	if kind == "videos.generations" {
		return c.runVideo(ctx, request, token)
	}
	if kind == "images.generations" || kind == "images.edits" {
		return c.runImage(ctx, request, token)
	}
	return Response{}, &ClientError{Code: "unsupported_kind", Message: "unsupported volces request kind", Retryable: false}
}
// runImage posts the provider image body to {BaseURL}/images/generations and
// turns the decoded JSON reply into a Response.
//
// Run sends both "images.generations" and "images.edits" requests here. The
// response timestamps bracket the decoding of the HTTP response: they start
// once the HTTP client has returned, not when the request was sent.
//
// Errors:
//   - a body that cannot be JSON-encoded yields a non-retryable
//     "invalid_request" *ClientError (the encode error used to be discarded,
//     which sent an empty payload upstream);
//   - a request that cannot be constructed returns the raw error;
//   - a transport failure yields a retryable "network" *ClientError;
//   - a decode failure is passed through annotateResponseError.
func (c VolcesClient) runImage(ctx context.Context, request Request, apiKey string) (Response, error) {
	body := volcesImageBody(request)
	raw, err := json.Marshal(body)
	if err != nil {
		return Response{}, &ClientError{Code: "invalid_request", Message: "volces image request body cannot be encoded: " + err.Error(), Retryable: false}
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL(request.Candidate.BaseURL, "/images/generations"), bytes.NewReader(raw))
	if err != nil {
		return Response{}, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)
	resp, err := httpClient(c.HTTPClient).Do(req)
	if err != nil {
		return Response{}, &ClientError{Code: "network", Message: err.Error(), Retryable: true}
	}
	responseStartedAt := time.Now()
	// Read the header-derived id before the response is handed to the decoder.
	requestID := requestIDFromHTTPResponse(resp)
	result, err := decodeHTTPResponse(resp)
	responseFinishedAt := time.Now()
	if err != nil {
		return Response{}, annotateResponseError(err, requestID, responseStartedAt, responseFinishedAt)
	}
	if requestID == "" {
		// Fall back to an id carried in the decoded payload.
		requestID = requestIDFromResult(result)
	}
	return Response{
		Result:             result,
		RequestID:          requestID,
		Usage:              usageFromOpenAI(result),
		Progress:           providerProgress(request),
		ResponseStartedAt:  responseStartedAt,
		ResponseFinishedAt: responseFinishedAt,
		ResponseDurationMS: responseDurationMS(responseStartedAt, responseFinishedAt),
	}, nil
}
// runVideo submits a video generation task and polls it until it reaches a
// terminal status, the poll timeout fires, or ctx is done.
//
// Flow:
//  1. POST the provider body to /contents/generations/tasks and read the task
//     id from the "id" field of the reply; a blank id is a non-retryable
//     "invalid_response" error.
//  2. GET /contents/generations/tasks/{id} immediately and then once per poll
//     interval.
//  3. "succeeded" produces a Response. "failed", "cancelled" and "expired"
//     produce a non-retryable *ClientError built from the task's error object.
//     Any other status keeps polling.
//
// "expired" is treated as terminal so that a task the upstream has already
// given up on is reported right away instead of being polled until the local
// timeout.
//
// The timeout timer starts after the submit call returns and is only consulted
// between polls, so it does not interrupt a poll request that is in flight.
// Any error from a poll request ends the loop. Cancellation and timeout are
// reported as retryable errors.
func (c VolcesClient) runVideo(ctx context.Context, request Request, apiKey string) (Response, error) {
	body := volcesVideoBody(request)
	submitStartedAt := time.Now()
	submitResult, submitRequestID, err := c.postJSON(ctx, request.Candidate.BaseURL, "/contents/generations/tasks", apiKey, body)
	submitFinishedAt := time.Now()
	if err != nil {
		return Response{}, annotateResponseError(err, submitRequestID, submitStartedAt, submitFinishedAt)
	}
	upstreamTaskID := strings.TrimSpace(stringFromAny(submitResult["id"]))
	if upstreamTaskID == "" {
		return Response{}, &ClientError{Code: "invalid_response", Message: "volces video task id is missing", RequestID: submitRequestID, Retryable: false}
	}
	interval := volcesPollInterval(request)
	timeout := volcesPollTimeout(request)
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	var lastResult map[string]any
	for {
		// Non-blocking cancellation check before spending a request.
		select {
		case <-ctx.Done():
			return Response{}, &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: submitRequestID, Retryable: true}
		default:
		}
		pollStartedAt := time.Now()
		pollResult, pollRequestID, err := c.getJSON(ctx, request.Candidate.BaseURL, "/contents/generations/tasks/"+upstreamTaskID, apiKey)
		pollFinishedAt := time.Now()
		// Prefer the most specific identifier available for error reporting.
		requestID := firstNonEmpty(pollRequestID, submitRequestID, upstreamTaskID)
		if err != nil {
			return Response{}, annotateResponseError(err, requestID, pollStartedAt, pollFinishedAt)
		}
		lastResult = pollResult
		switch volcesTaskStatus(pollResult) {
		case "succeeded":
			result := volcesVideoSuccessResult(request, upstreamTaskID, pollResult)
			return Response{
				Result:             result,
				RequestID:          requestID,
				Usage:              volcesVideoUsage(pollResult),
				Progress:           volcesVideoProgress(request, upstreamTaskID),
				ResponseStartedAt:  submitStartedAt,
				ResponseFinishedAt: pollFinishedAt,
				ResponseDurationMS: responseDurationMS(submitStartedAt, pollFinishedAt),
			}, nil
		case "failed", "cancelled", "expired":
			return Response{}, &ClientError{
				Code:               volcesTaskErrorCode(pollResult),
				Message:            volcesTaskErrorMessage(pollResult),
				RequestID:          requestID,
				ResponseStartedAt:  submitStartedAt,
				ResponseFinishedAt: pollFinishedAt,
				ResponseDurationMS: responseDurationMS(submitStartedAt, pollFinishedAt),
				Retryable:          false,
			}
		}
		// Not terminal yet: wait for the next tick, the timeout, or cancellation.
		select {
		case <-ctx.Done():
			return Response{}, &ClientError{Code: "cancelled", Message: ctx.Err().Error(), RequestID: requestID, Retryable: true}
		case <-deadline.C:
			return Response{}, &ClientError{
				Code:      "timeout",
				Message:   fmt.Sprintf("volces video task %s did not finish before timeout; last status: %s", upstreamTaskID, volcesTaskStatus(lastResult)),
				RequestID: requestID,
				Retryable: true,
			}
		case <-ticker.C:
		}
	}
}
// postJSON sends body as JSON with a bearer token to baseURL joined with path.
//
// It returns the decoded reply, the request id taken from the HTTP response,
// and any error. The request id is returned even when decoding fails so the
// caller can attach it to the error.
//
// Errors:
//   - a body that cannot be JSON-encoded yields a non-retryable
//     "invalid_request" *ClientError (the encode error used to be discarded,
//     which sent an empty payload upstream);
//   - a request that cannot be constructed returns the raw error;
//   - a transport failure yields a retryable "network" *ClientError;
//   - whatever decodeHTTPResponse reports is returned unchanged.
func (c VolcesClient) postJSON(ctx context.Context, baseURL string, path string, apiKey string, body map[string]any) (map[string]any, string, error) {
	raw, err := json.Marshal(body)
	if err != nil {
		return nil, "", &ClientError{Code: "invalid_request", Message: "volces request body cannot be encoded: " + err.Error(), Retryable: false}
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, joinURL(baseURL, path), bytes.NewReader(raw))
	if err != nil {
		return nil, "", err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)
	resp, err := httpClient(c.HTTPClient).Do(req)
	if err != nil {
		return nil, "", &ClientError{Code: "network", Message: err.Error(), Retryable: true}
	}
	requestID := requestIDFromHTTPResponse(resp)
	result, err := decodeHTTPResponse(resp)
	return result, requestID, err
}
// getJSON issues an authenticated GET to baseURL joined with path.
//
// It returns the decoded reply, the request id taken from the HTTP response,
// and any error. A request that cannot be constructed returns the raw error; a
// transport failure is wrapped in a retryable "network" *ClientError; the
// result of decodeHTTPResponse is returned unchanged together with the id.
func (c VolcesClient) getJSON(ctx context.Context, baseURL string, path string, apiKey string) (map[string]any, string, error) {
	target := joinURL(baseURL, path)
	httpReq, buildErr := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
	if buildErr != nil {
		return nil, "", buildErr
	}
	httpReq.Header.Set("Authorization", "Bearer "+apiKey)

	httpResp, doErr := httpClient(c.HTTPClient).Do(httpReq)
	if doErr != nil {
		return nil, "", &ClientError{Code: "network", Message: doErr.Error(), Retryable: true}
	}

	// The id is read before the response is handed to the decoder.
	upstreamID := requestIDFromHTTPResponse(httpResp)
	decoded, decodeErr := decodeHTTPResponse(httpResp)
	return decoded, upstreamID, decodeErr
}
// volcesImageBody builds the JSON payload for an image request from a cleaned
// copy of request.Body.
//
// Adjustments applied to the copy:
//   - "model" is always set to the candidate's model name;
//   - "watermark" defaults to false when the key is absent;
//   - "seed" defaults to -1 when absent, for "images.generations" only;
//   - a non-blank "resolution" is copied into "size";
//   - a positive width/height pair (see widthHeightSize) then overrides "size";
//   - "sequential_image_generation" is set to "auto" when it is unset or nil
//     and the candidate advertises multi-image output.
func volcesImageBody(request Request) map[string]any {
	payload := cleanProviderBody(request.Body)
	payload["model"] = request.Candidate.ModelName

	if _, has := payload["watermark"]; !has {
		payload["watermark"] = false
	}
	if _, has := payload["seed"]; !has && request.Kind == "images.generations" {
		payload["seed"] = -1
	}

	// Size precedence, lowest to highest: caller's "size", "resolution",
	// explicit width/height.
	resolution := strings.TrimSpace(stringFromAny(payload["resolution"]))
	if resolution != "" {
		payload["size"] = resolution
	}
	if dims := widthHeightSize(payload); dims != "" {
		payload["size"] = dims
	}

	if payload["sequential_image_generation"] == nil && supportsMultipleOutputs(request, request.ModelType) {
		payload["sequential_image_generation"] = "auto"
	}
	return payload
}
// volcesVideoBody builds the JSON payload for a video task from a cleaned copy
// of request.Body.
//
// The "content" list is taken from the body when it holds at least one object
// item; otherwise it is synthesized from the convenience fields (prompt,
// frames, reference media). The list then goes through, in order: shot
// timeline merging, role normalization, and appending of "--flag value"
// parameters to the text item. Finally the convenience fields are removed from
// the payload.
func volcesVideoBody(request Request) map[string]any {
	payload := cleanProviderBody(request.Body)
	payload["model"] = request.Candidate.ModelName

	items := contentItems(payload["content"])
	if len(items) == 0 {
		items = buildVolcesContentFromBody(payload)
	}

	// Order matters: the timeline and the parameters are both written into the
	// main text item.
	appendMultiShotTimeline(&items)
	normalizeVolcesContentRoles(items)
	appendVolcesVideoParams(&items, payload)

	payload["content"] = items
	stripVolcesVideoConvenienceFields(payload)
	return payload
}
// cleanProviderBody returns the result of cloneBody(body) with the
// gateway-only control keys removed: run/simulation/test mode switches and the
// poll interval/timeout overrides.
func cleanProviderBody(body map[string]any) map[string]any {
	cleaned := cloneBody(body)
	// Deleting entries while ranging over a map is permitted in Go.
	for key := range cleaned {
		switch key {
		case "runMode", "mode", "simulation", "testMode",
			"simulationProfile", "testProfile",
			"pollIntervalMs", "poll_interval_ms",
			"pollTimeoutSeconds", "poll_timeout_seconds":
			delete(cleaned, key)
		}
	}
	return cleaned
}
// buildVolcesContentFromBody synthesizes a content list from the flat
// convenience fields of body.
//
// Items are emitted in this order: the text prompt ("prompt" or "input"), the
// first frame, the last frame, reference images, reference videos, reference
// audio. For each media group only the first non-empty alias among the
// accepted keys is used. Blank URLs are skipped and stored URLs are trimmed.
// When nothing was produced, a single empty text item is returned so the list
// is never empty.
func buildVolcesContentFromBody(body map[string]any) []map[string]any {
	items := make([]map[string]any, 0)
	if prompt := firstNonEmptyStringValue(body, "prompt", "input"); prompt != "" {
		items = append(items, map[string]any{"type": "text", "text": prompt})
	}

	// addMedia appends one media item. The kind doubles as the key of the
	// nested {"url": ...} object; kinds other than the three media kinds are
	// ignored.
	addMedia := func(kind, role, rawURL string) {
		trimmed := strings.TrimSpace(rawURL)
		if trimmed == "" {
			return
		}
		if kind != "image_url" && kind != "video_url" && kind != "audio_url" {
			return
		}
		items = append(items, map[string]any{
			"type": kind,
			"role": role,
			kind:   map[string]any{"url": trimmed},
		})
	}

	addMedia("image_url", "first_frame", firstNonEmptyStringValue(body, "first_frame", "firstFrame"))
	addMedia("image_url", "last_frame", firstNonEmptyStringValue(body, "last_frame", "lastFrame"))

	references := []struct {
		kind string
		role string
		urls []string
	}{
		{"image_url", "reference_image", firstNonEmptyStringListFromAny(body["image"], body["images"], body["image_url"], body["imageUrl"], body["image_urls"], body["imageUrls"], body["reference_image"], body["referenceImage"])},
		{"video_url", "reference_video", firstNonEmptyStringListFromAny(body["video"], body["video_url"], body["videoUrl"], body["reference_video"], body["referenceVideo"])},
		{"audio_url", "reference_audio", firstNonEmptyStringListFromAny(body["audio_url"], body["audioUrl"], body["reference_audio"], body["referenceAudio"])},
	}
	for _, group := range references {
		for _, link := range group.urls {
			addMedia(group.kind, group.role, link)
		}
	}

	if len(items) == 0 {
		items = append(items, map[string]any{"type": "text", "text": ""})
	}
	return items
}
// stripVolcesVideoConvenienceFields removes, in place, the flat prompt, frame
// and reference-media aliases from body. All other keys are left untouched.
func stripVolcesVideoConvenienceFields(body map[string]any) {
	// Deleting entries while ranging over a map is permitted in Go.
	for key := range body {
		switch key {
		case "prompt", "input",
			"image", "images", "image_url", "imageUrl", "image_urls", "imageUrls",
			"reference_image", "referenceImage",
			"first_frame", "firstFrame", "last_frame", "lastFrame",
			"video", "video_url", "videoUrl", "reference_video", "referenceVideo",
			"audio_url", "audioUrl", "reference_audio", "referenceAudio":
			delete(body, key)
		}
	}
}
// contentItems converts a decoded JSON array into a slice of shallow-copied
// objects.
//
// It returns nil when value is not a []any. Elements that are not
// map[string]any are dropped. Each kept map is copied one level deep, so
// changing top-level keys of the result does not affect the input; nested
// values are still shared.
func contentItems(value any) []map[string]any {
	list, isList := value.([]any)
	if !isList {
		return nil
	}
	items := make([]map[string]any, 0, len(list))
	for i := range list {
		src, isMap := list[i].(map[string]any)
		if !isMap {
			continue
		}
		dup := make(map[string]any, len(src))
		for k, v := range src {
			dup[k] = v
		}
		items = append(items, dup)
	}
	return items
}
// normalizeVolcesContentRoles rewrites the "role" of media items in place.
//
// Video and audio items always get "reference_video" / "reference_audio".
// Image items keep a role that trims to "first_frame" or "last_frame" and get
// "reference_image" otherwise. Items of any other type are not modified.
func normalizeVolcesContentRoles(content []map[string]any) {
	for i := range content {
		entry := content[i]
		kind := strings.TrimSpace(stringFromAny(entry["type"]))
		if kind == "video_url" {
			entry["role"] = "reference_video"
			continue
		}
		if kind == "audio_url" {
			entry["role"] = "reference_audio"
			continue
		}
		if kind != "image_url" {
			continue
		}
		role := strings.TrimSpace(stringFromAny(entry["role"]))
		if role == "first_frame" || role == "last_frame" {
			continue
		}
		entry["role"] = "reference_image"
	}
}
// appendVolcesVideoParams appends "--flag value" generation parameters to the
// main text item of content, creating that item when needed.
//
// Flags are considered in a fixed order: dur, ratio, fps, watermark, seed, cf,
// rs. Each value comes from the first present alias in body; watermark falls
// back to false and seed to -1. A flag is skipped when its rendered value is
// empty or when the text built so far already contains "--" followed by the
// flag name. The text is trimmed before flags are added.
func appendVolcesVideoParams(content *[]map[string]any, body map[string]any) {
	target := ensureTextContent(content)
	prompt := strings.TrimSpace(stringFromAny(target["text"]))

	order := []string{"dur", "ratio", "fps", "watermark", "seed", "cf", "rs"}
	settings := map[string]any{
		"dur":       firstPresent(body["duration"], body["dur"]),
		"ratio":     firstPresent(body["aspect_ratio"], body["aspectRatio"], body["ratio"]),
		"fps":       firstPresent(body["framespersecond"], body["framesPerSecond"], body["fps"]),
		"watermark": firstPresent(body["watermark"], false),
		"seed":      firstPresent(body["seed"], -1),
		"cf":        firstPresent(body["camerafixed"], body["cameraFixed"]),
		"rs":        firstPresent(body["resolution"], body["size"]),
	}

	for _, flag := range order {
		rendered := volcesParamString(settings[flag])
		if rendered == "" {
			continue
		}
		marker := "--" + flag
		// Checked against the growing text, so a caller-supplied flag wins.
		if strings.Contains(prompt, marker) {
			continue
		}
		if prompt == "" {
			prompt = marker + " " + rendered
		} else {
			prompt = prompt + " " + marker + " " + rendered
		}
	}
	target["text"] = prompt
}
// appendMultiShotTimeline folds per-shot prompt items into the main text item
// as an auto-generated timeline.
//
// A shot item is a "text" item whose role is "shot_prompt" or that carries a
// non-nil "shot_index". Shot items are always removed from *content; those
// with blank text are dropped silently. The remaining shots are ordered by
// shot_index (defaulting to the item's position in the list), ties keeping
// their original order, and laid out back to back. Each shot lasts its
// "duration", or 5 when that is missing or not positive. The resulting lines
// are appended to the main text item under a fixed header unless the text
// already contains that header.
//
// The content list is filtered into a fresh slice. The previous in-place
// filter overwrote the caller's backing array without shrinking the slice when
// every shot prompt was blank, which left duplicated trailing items.
func appendMultiShotTimeline(content *[]map[string]any) {
	type shot struct {
		index    int
		text     string
		duration float64
	}
	items := *content
	shots := make([]shot, 0)
	remaining := make([]map[string]any, 0, len(items))
	for index, item := range items {
		if stringFromAny(item["type"]) != "text" {
			remaining = append(remaining, item)
			continue
		}
		role := stringFromAny(item["role"])
		if role != "shot_prompt" && item["shot_index"] == nil {
			remaining = append(remaining, item)
			continue
		}
		text := strings.TrimSpace(stringFromAny(item["text"]))
		if text == "" {
			continue
		}
		shotIndex := numericValue(item["shot_index"], float64(index))
		shots = append(shots, shot{index: int(math.Floor(shotIndex)), text: text, duration: numericValue(item["duration"], 5)})
	}
	// Only replace the caller's slice when something was actually removed, so
	// the common no-shot case leaves it untouched.
	if len(remaining) != len(items) {
		*content = remaining
	}
	if len(shots) == 0 {
		return
	}
	// Stable insertion sort: shots sharing an index keep their list order.
	for i := 1; i < len(shots); i++ {
		for j := i; j > 0 && shots[j].index < shots[j-1].index; j-- {
			shots[j], shots[j-1] = shots[j-1], shots[j]
		}
	}
	cursor := 0.0
	lines := make([]string, 0, len(shots))
	for position, s := range shots {
		duration := s.duration
		if duration <= 0 {
			duration = 5
		}
		start, end := cursor, cursor+duration
		cursor = end
		// Shot numbers are 1-based; fall back to the ordinal position when the
		// supplied index would give a non-positive number.
		number := s.index + 1
		if number <= 0 {
			number = position + 1
		}
		lines = append(lines, fmt.Sprintf("Shot %d, %gs~%gs: %s", number, start, end, s.text))
	}
	textItem := ensureTextContent(content)
	current := stringFromAny(textItem["text"])
	const prefix = "Additional shot timeline (auto-generated):"
	if strings.Contains(current, prefix) {
		return
	}
	separator := ""
	if strings.TrimSpace(current) != "" {
		separator = "\n\n"
	}
	textItem["text"] = current + separator + prefix + "\n" + strings.Join(lines, "\n")
}
// ensureTextContent returns the main text item of *content, which is the first
// "text" item that is not a shot prompt (no "shot_index" and a role other than
// "shot_prompt"). When there is none, an empty text item is inserted at the
// front of *content and returned. The returned map is the one stored in the
// slice, so writes to it are visible through *content.
func ensureTextContent(content *[]map[string]any) map[string]any {
	existing := *content
	for i := range existing {
		candidate := existing[i]
		if stringFromAny(candidate["type"]) != "text" {
			continue
		}
		if candidate["shot_index"] != nil {
			continue
		}
		if stringFromAny(candidate["role"]) == "shot_prompt" {
			continue
		}
		return candidate
	}

	created := map[string]any{"type": "text", "text": ""}
	merged := make([]map[string]any, 0, len(existing)+1)
	merged = append(merged, created)
	merged = append(merged, existing...)
	*content = merged
	return created
}
// supportsMultipleOutputs reports whether any of the candidate's capability
// entries enables "output_multiple_images".
//
// The entries consulted are, in order: capabilityName, request.ModelType,
// "image_generate" and "image_edit". Empty names and entries that are not
// objects are skipped.
func supportsMultipleOutputs(request Request, capabilityName string) bool {
	names := [...]string{capabilityName, request.ModelType, "image_generate", "image_edit"}
	for i := 0; i < len(names); i++ {
		name := names[i]
		if name == "" {
			continue
		}
		entry, isObject := request.Candidate.Capabilities[name].(map[string]any)
		if isObject && boolFromAny(entry["output_multiple_images"]) {
			return true
		}
	}
	return false
}
// widthHeightSize renders body's "width" and "height" as "WxH" with both
// values rounded to the nearest integer. It returns "" when either value is
// missing, unparseable, zero or negative.
func widthHeightSize(body map[string]any) string {
	dims := [2]float64{numericValue(body["width"], 0), numericValue(body["height"], 0)}
	for _, d := range dims {
		if d <= 0 {
			return ""
		}
	}
	w, h := int(math.Round(dims[0])), int(math.Round(dims[1]))
	return fmt.Sprintf("%dx%d", w, h)
}
// volcesTaskStatus returns the task's "status" field, trimmed and lower-cased.
func volcesTaskStatus(result map[string]any) string {
	raw := stringFromAny(result["status"])
	trimmed := strings.TrimSpace(raw)
	return strings.ToLower(trimmed)
}
// volcesTaskErrorCode picks an error code for a task that did not succeed.
// Preference order: the trimmed "code" of the task's "error" object, then the
// normalized task status, then the constant "volces_task_failed".
func volcesTaskErrorCode(result map[string]any) string {
	details, _ := result["error"].(map[string]any)
	candidates := [2]string{
		strings.TrimSpace(stringFromAny(details["code"])),
		volcesTaskStatus(result),
	}
	for _, candidate := range candidates {
		if candidate != "" {
			return candidate
		}
	}
	return "volces_task_failed"
}
// volcesTaskErrorMessage picks a human-readable message for a task that did
// not succeed. Preference order: the trimmed "message" of the task's "error"
// object, then "volces video task " followed by the normalized status, then
// the constant "volces video task failed".
func volcesTaskErrorMessage(result map[string]any) string {
	details, _ := result["error"].(map[string]any)
	upstream := strings.TrimSpace(stringFromAny(details["message"]))
	if upstream != "" {
		return upstream
	}

	status := volcesTaskStatus(result)
	if status == "" {
		return "volces video task failed"
	}
	return "volces video task " + status
}
// volcesVideoSuccessResult converts a succeeded task reply into the gateway's
// "video.generation" result object.
//
// "created" comes from the reply's "created_at", falling back to the current
// Unix time when that is zero. "data" holds one {"url", "type": "video"} entry
// when the reply's content carries a non-blank "video_url" and is an empty
// list otherwise. The unmodified reply is attached under "raw".
func volcesVideoSuccessResult(request Request, upstreamTaskID string, raw map[string]any) map[string]any {
	createdAt := intFromAny(raw["created_at"])
	if createdAt == 0 {
		createdAt = int(nowUnix())
	}

	// Start from an empty, non-nil list so that "data" is always a list.
	outputs := []any{}
	payload, _ := raw["content"].(map[string]any)
	if link := strings.TrimSpace(stringFromAny(payload["video_url"])); link != "" {
		outputs = append(outputs, map[string]any{"url": link, "type": "video"})
	}

	result := make(map[string]any, 8)
	result["id"] = upstreamTaskID
	result["object"] = "video.generation"
	result["created"] = createdAt
	result["model"] = request.Candidate.ModelName
	result["status"] = "succeeded"
	result["upstream_task_id"] = upstreamTaskID
	result["data"] = outputs
	result["raw"] = raw
	return result
}
// volcesVideoUsage reads token usage from the reply's "usage" object.
// Output tokens come from "completion_tokens". Total tokens come from
// "total_tokens", falling back to the output token count when that is zero.
func volcesVideoUsage(raw map[string]any) Usage {
	stats, _ := raw["usage"].(map[string]any)
	completion := intFromAny(stats["completion_tokens"])
	overall := intFromAny(stats["total_tokens"])
	if overall != 0 {
		return Usage{OutputTokens: completion, TotalTokens: overall}
	}
	return Usage{OutputTokens: completion, TotalTokens: completion}
}
// volcesVideoProgress returns providerProgress(request) with one extra
// "polling_result" entry at progress 0.9 that carries the upstream task id in
// its payload.
func volcesVideoProgress(request Request, upstreamTaskID string) []Progress {
	completed := Progress{
		Phase:    "polling_result",
		Progress: 0.9,
		Message:  "volces video task completed",
		Payload:  map[string]any{"upstreamTaskId": upstreamTaskID},
	}
	return append(providerProgress(request), completed)
}
// volcesPollInterval returns the delay between task status polls.
//
// The value in milliseconds is taken from the first present source among the
// platform config key "volcesPollIntervalMs" and the request body keys
// "pollIntervalMs" and "poll_interval_ms". It defaults to 5000 and is raised
// to 100 when smaller.
func volcesPollInterval(request Request) time.Duration {
	configured := firstPresent(
		request.Candidate.PlatformConfig["volcesPollIntervalMs"],
		request.Body["pollIntervalMs"],
		request.Body["poll_interval_ms"],
	)
	millis := numericValue(configured, 5000)
	if millis < 100 {
		return 100 * time.Millisecond
	}
	// The conversion drops any fractional millisecond.
	return time.Duration(millis) * time.Millisecond
}
// volcesPollTimeout returns how long a video task may be polled before giving
// up.
//
// The value in seconds is taken from the first present source among the
// platform config key "volcesPollTimeoutSeconds" and the request body keys
// "pollTimeoutSeconds" and "poll_timeout_seconds". It defaults to 600, and
// values below 1 are replaced by 600 as well.
func volcesPollTimeout(request Request) time.Duration {
	configured := firstPresent(
		request.Candidate.PlatformConfig["volcesPollTimeoutSeconds"],
		request.Body["pollTimeoutSeconds"],
		request.Body["poll_timeout_seconds"],
	)
	secs := numericValue(configured, 600)
	if secs < 1 {
		return 600 * time.Second
	}
	// The conversion drops any fractional second.
	return time.Duration(secs) * time.Second
}
// firstNonEmpty returns the first argument that is not blank, with surrounding
// whitespace removed, or "" when every argument is blank.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if trimmed := strings.TrimSpace(values[i]); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
// firstNonEmptyStringValue looks up keys in body in order and returns the
// first value whose string form is not blank, trimmed. It returns "" when no
// key yields one.
func firstNonEmptyStringValue(body map[string]any, keys ...string) string {
	for i := 0; i < len(keys); i++ {
		candidate := strings.TrimSpace(stringFromAny(body[keys[i]]))
		if candidate == "" {
			continue
		}
		return candidate
	}
	return ""
}
// stringListFromAny normalizes value into a list of non-blank, trimmed
// strings.
//
// Accepted shapes:
//   - string: a one-element list, or nil when the string is blank;
//   - []any: each element is converted with stringFromAny, blanks are dropped;
//   - []string: blanks are dropped.
//
// Any other type yields nil. The single-string form is trimmed like the list
// forms; it used to be returned with its surrounding whitespace intact.
func stringListFromAny(value any) []string {
	switch typed := value.(type) {
	case string:
		if trimmed := strings.TrimSpace(typed); trimmed != "" {
			return []string{trimmed}
		}
		return nil
	case []any:
		out := make([]string, 0, len(typed))
		for _, item := range typed {
			if value := strings.TrimSpace(stringFromAny(item)); value != "" {
				out = append(out, value)
			}
		}
		return out
	case []string:
		out := make([]string, 0, len(typed))
		for _, item := range typed {
			if value := strings.TrimSpace(item); value != "" {
				out = append(out, value)
			}
		}
		return out
	default:
		return nil
	}
}
// firstNonEmptyStringListFromAny returns the string list of the first argument
// that normalizes (via stringListFromAny) to at least one entry, or nil when
// none does. Lists from different arguments are never merged.
func firstNonEmptyStringListFromAny(values ...any) []string {
	for i := 0; i < len(values); i++ {
		if list := stringListFromAny(values[i]); len(list) != 0 {
			return list
		}
	}
	return nil
}
// volcesParamString renders a parameter value for use after a "--flag" in the
// prompt text.
//
//   - nil renders as "", which callers treat as "omit this flag";
//   - strings are trimmed;
//   - booleans render as "true" / "false";
//   - int and int64 render in decimal;
//   - float64 renders as an integer when it is integral and fits in int64,
//     and in %g form otherwise;
//   - every other type renders with %v.
//
// The int64 range check avoids converting an out-of-range float (for example
// 1e30) to int64, whose result is implementation-dependent and printed as a
// meaningless integer.
func volcesParamString(value any) string {
	switch typed := value.(type) {
	case nil:
		return ""
	case string:
		return strings.TrimSpace(typed)
	case bool:
		if typed {
			return "true"
		}
		return "false"
	case int:
		return fmt.Sprintf("%d", typed)
	case int64:
		return fmt.Sprintf("%d", typed)
	case float64:
		// NaN fails the equality test; infinities fail the range test.
		if typed == math.Trunc(typed) && typed >= -1<<63 && typed < 1<<63 {
			return fmt.Sprintf("%d", int64(typed))
		}
		return fmt.Sprintf("%g", typed)
	default:
		return fmt.Sprintf("%v", typed)
	}
}
// numericValue converts value to float64, returning fallback when it cannot.
//
// Accepted inputs are all built-in integer and floating-point types,
// json.Number, and strings. Strings are trimmed and scanned with the fmt "%f"
// verb; scanning stops at the first character that cannot belong to the
// number, so a value with a unit suffix is read up to the suffix. Any other
// type, or a string or json.Number that fails to parse, yields fallback.
//
// Previously only int, int64, float64 and string were recognized, and other
// numeric types silently fell back.
func numericValue(value any, fallback float64) float64 {
	switch typed := value.(type) {
	case int:
		return float64(typed)
	case int8:
		return float64(typed)
	case int16:
		return float64(typed)
	case int32:
		return float64(typed)
	case int64:
		return float64(typed)
	case uint:
		return float64(typed)
	case uint8:
		return float64(typed)
	case uint16:
		return float64(typed)
	case uint32:
		return float64(typed)
	case uint64:
		return float64(typed)
	case float32:
		return float64(typed)
	case float64:
		return typed
	case json.Number:
		// Produced by decoders configured with UseNumber.
		if parsed, err := typed.Float64(); err == nil {
			return parsed
		}
		return fallback
	case string:
		var parsed float64
		if _, err := fmt.Sscanf(strings.TrimSpace(typed), "%f", &parsed); err == nil {
			return parsed
		}
		return fallback
	default:
		return fallback
	}
}
// boolFromAny interprets value as a boolean flag.
//
// It is true for the bool true, for strings equal to "true" or "1" after
// trimming and lower-casing, and for the numeric value 1 given as float64, int
// or int64. Everything else, including nil, is false.
//
// int64 is accepted so that this helper recognizes the same integer types as
// numericValue and volcesParamString; it used to report false for int64(1).
func boolFromAny(value any) bool {
	switch typed := value.(type) {
	case bool:
		return typed
	case string:
		normalized := strings.ToLower(strings.TrimSpace(typed))
		return normalized == "true" || normalized == "1"
	case float64:
		return typed == 1
	case int:
		return typed == 1
	case int64:
		return typed == 1
	default:
		return false
	}
}