easyai-ai-gateway/apps/api/internal/runner/voice_clone.go

285 lines
11 KiB
Go

package runner
import (
"context"
"fmt"
"strings"
"time"
"unicode"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// clonedVoiceBinding describes the outcome of matching a speech request
// against the caller's stored cloned voices.
type clonedVoiceBinding struct {
// Voice is the stored cloned-voice record the request resolved to.
Voice store.ClonedVoice
// Found is true when the request matched a stored cloned voice.
Found bool
// Explicit is true when the request carried a cloned_voice_id (or
// clonedVoiceId) value, as opposed to matching by voice_id alone.
Explicit bool
}
// DeletedClonedVoiceResult is the value returned by Service.DeleteClonedVoice.
type DeletedClonedVoiceResult struct {
// Voice is the record reported by the store's delete call; when the store
// reports nothing was removed, it is the looked-up record with Status set
// to "deleted".
Voice store.ClonedVoice `json:"voice"`
// Upstream is the Result field of the provider's delete response.
Upstream map[string]any `json:"upstream,omitempty"`
// RequestID is the RequestID field of the provider's delete response.
RequestID string `json:"requestId,omitempty"`
}
// validateVoiceCloneRequest checks a voice-clone request body before it is
// dispatched. It returns an error when:
//   - voice_id / voiceId does not satisfy validMiniMaxVoiceID;
//   - none of the source-audio keys (file_id, fileId, audio, file,
//     source_audio, sourceAudio, audio_url, audioUrl) is present;
//   - prompt audio is supplied without prompt_text / promptText.
//
// It returns nil when the body passes all three checks.
func validateVoiceCloneRequest(body map[string]any) error {
	id := firstNonEmptyString(stringFromMap(body, "voice_id"), stringFromMap(body, "voiceId"))
	if !validMiniMaxVoiceID(id) {
		return fmt.Errorf("voice_id must be 8-256 chars, start with an English letter, contain only letters, digits, '-' or '_', and not end with '-' or '_'")
	}

	// Any one of these keys is enough to identify the audio to clone from.
	hasSource := body["file_id"] != nil ||
		body["fileId"] != nil ||
		stringFromAny(body["audio"]) != "" ||
		stringFromAny(body["file"]) != "" ||
		stringFromAny(body["source_audio"]) != "" ||
		stringFromAny(body["sourceAudio"]) != "" ||
		stringFromMap(body, "audio_url") != "" ||
		stringFromMap(body, "audioUrl") != ""
	if !hasSource {
		return fmt.Errorf("file_id or audio is required")
	}

	if !hasVoiceClonePromptAudio(body) {
		return nil
	}
	promptText := firstNonEmptyString(stringFromMap(body, "prompt_text"), stringFromMap(body, "promptText"))
	if promptText == "" {
		return fmt.Errorf("prompt_text is required when prompt audio is provided")
	}
	return nil
}
// validMiniMaxVoiceID reports whether value, after trimming surrounding
// whitespace, is an acceptable voice identifier: 8-256 bytes long, starting
// with an ASCII letter, made up only of ASCII letters, ASCII digits, '-' or
// '_', and not ending in '-' or '_'.
func validMiniMaxVoiceID(value string) bool {
	value = strings.TrimSpace(value)
	// Only ASCII is accepted below, so the byte length equals the character
	// count for every value that can pass.
	if len(value) < 8 || len(value) > 256 {
		return false
	}
	for index, r := range value {
		if isASCIILetter(r) {
			continue
		}
		if index == 0 {
			// The first character must be a letter.
			return false
		}
		// unicode.IsDigit alone also matches non-ASCII decimal digits (for
		// example Arabic-Indic or fullwidth forms), so bound it to ASCII.
		isDigit := r <= unicode.MaxASCII && unicode.IsDigit(r)
		if !isDigit && r != '-' && r != '_' {
			return false
		}
	}
	return !strings.HasSuffix(value, "-") && !strings.HasSuffix(value, "_")
}

// isASCIILetter reports whether r is one of a-z or A-Z.
func isASCIILetter(r rune) bool {
	return (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
}
// hasVoiceClonePromptAudio reports whether the request body carries prompt
// audio under any of its accepted keys: a non-nil prompt_file_id /
// promptFileId, or a non-empty prompt_audio, promptAudio, prompt_audio_url
// or promptAudioUrl.
func hasVoiceClonePromptAudio(body map[string]any) bool {
	if body["prompt_file_id"] != nil || body["promptFileId"] != nil {
		return true
	}
	for _, key := range []string{"prompt_audio", "promptAudio"} {
		if stringFromAny(body[key]) != "" {
			return true
		}
	}
	for _, key := range []string{"prompt_audio_url", "promptAudioUrl"} {
		if stringFromMap(body, key) != "" {
			return true
		}
	}
	return false
}
// resolveClonedVoiceBinding matches a speech request against the caller's
// stored cloned voices.
//
// Requests whose kind is not "speech.generations", or that carry neither a
// cloned_voice_id nor a voice_id, are returned unchanged with an empty
// binding. The same happens when only voice_id was given and the store has no
// matching record. A cloned_voice_id that is not UUID-shaped, or that has no
// matching record, yields a *clients.ClientError (400 and 404 respectively).
// A matched voice whose status is set and not "active", or whose ExpiresAt
// parses and is not in the future, is rejected with a 400; in the expiry case
// the record is also marked "expired" on a best-effort basis.
//
// On success the returned map is a copy of body with voice_id and
// cloned_voice_id overwritten from the stored record, and the binding has
// Found set.
func (s *Service) resolveClonedVoiceBinding(ctx context.Context, user *auth.User, kind string, body map[string]any) (map[string]any, clonedVoiceBinding, error) {
	var unbound clonedVoiceBinding
	reject := func(code, message string, status int) (map[string]any, clonedVoiceBinding, error) {
		return body, unbound, &clients.ClientError{Code: code, Message: message, StatusCode: status, Retryable: false}
	}

	if kind != "speech.generations" {
		return body, unbound, nil
	}

	clonedVoiceID := firstNonEmptyString(stringFromMap(body, "cloned_voice_id"), stringFromMap(body, "clonedVoiceId"))
	voiceID := firstNonEmptyString(stringFromMap(body, "voice_id"), stringFromMap(body, "voiceId"))
	explicit := clonedVoiceID != ""
	switch {
	case !explicit && voiceID == "":
		return body, unbound, nil
	case explicit && !looksLikeUUID(clonedVoiceID):
		return reject("bad_request", "cloned_voice_id must be a UUID", 400)
	}

	voice, found, err := s.store.FindClonedVoiceForUser(ctx, user, clonedVoiceID, voiceID)
	switch {
	case err != nil:
		return body, unbound, err
	case !found && explicit:
		return reject("cloned_voice_not_found", "cloned voice not found", 404)
	case !found:
		// A bare voice_id with no stored record is passed through as-is.
		return body, unbound, nil
	}

	if strings.TrimSpace(voice.Status) != "" && voice.Status != "active" {
		return reject("cloned_voice_unavailable", "cloned voice is not active", 400)
	}
	if voice.ExpiresAt != "" {
		// An unparseable timestamp is treated as "not expired".
		deadline, parseErr := time.Parse(time.RFC3339Nano, voice.ExpiresAt)
		if parseErr == nil && !deadline.After(time.Now()) {
			// Best effort: the status update runs on a context detached from
			// the request's cancellation and its error is ignored.
			_ = s.store.MarkClonedVoiceStatus(context.WithoutCancel(ctx), voice.ID, "expired")
			return reject("cloned_voice_expired", "cloned voice has expired", 400)
		}
	}

	rewritten := cloneMap(body)
	rewritten["voice_id"] = voice.VoiceID
	rewritten["cloned_voice_id"] = voice.ID
	return rewritten, clonedVoiceBinding{Voice: voice, Found: true, Explicit: explicit}, nil
}
// filterCandidatesByClonedVoiceBinding narrows candidates to those on the
// platform the bound cloned voice belongs to. Candidates whose
// PlatformModelID equals the voice's (when the voice has one) are placed
// first; the remaining same-platform candidates follow in their original
// order. When the binding was not found the input slice is returned as-is.
// If no candidate survives, a *store.ModelCandidateUnavailableError with
// code "cloned_voice_platform_unavailable" is returned.
func filterCandidatesByClonedVoiceBinding(candidates []store.RuntimeModelCandidate, binding clonedVoiceBinding) ([]store.RuntimeModelCandidate, error) {
	if !binding.Found {
		return candidates, nil
	}

	voice := binding.Voice
	exactModel := make([]store.RuntimeModelCandidate, 0, 1)
	samePlatform := make([]store.RuntimeModelCandidate, 0, len(candidates))
	for _, c := range candidates {
		switch {
		case strings.TrimSpace(c.PlatformID) != voice.PlatformID:
			// Different platform: drop the candidate.
		case voice.PlatformModelID != "" && c.PlatformModelID == voice.PlatformModelID:
			exactModel = append(exactModel, c)
		default:
			samePlatform = append(samePlatform, c)
		}
	}

	ordered := samePlatform
	if len(exactModel) > 0 {
		ordered = append(exactModel, samePlatform...)
	}
	if len(ordered) > 0 {
		return ordered, nil
	}

	return nil, &store.ModelCandidateUnavailableError{
		Code:    "cloned_voice_platform_unavailable",
		Message: "cloned voice is bound to a platform that has no enabled candidate for the requested speech model",
		Details: map[string]any{
			"clonedVoiceId":   voice.ID,
			"voiceId":         voice.VoiceID,
			"platformId":      voice.PlatformID,
			"platformModelId": voice.PlatformModelID,
		},
	}
}
// persistVoiceCloneResult upserts a cloned-voice record built from a
// finished voice-clone task.
//
// The voice ID comes from the result's voice_id, falling back to the request
// body. The demo audio URL comes from the result's demo_audio, falling back
// to the first audio URL in the result data. The preview model comes from
// the request body, falling back to the result. The record is stored with
// status "active" and an expiry seven days after the call, and keeps a
// subset of request options plus the result's raw_data as metadata.
//
// The user parameter is not read by this method.
func (s *Service) persistVoiceCloneResult(ctx context.Context, task store.GatewayTask, user *auth.User, candidate store.RuntimeModelCandidate, attemptID string, body map[string]any, result map[string]any) (store.ClonedVoice, error) {
	const retention = 7 * 24 * time.Hour

	voiceID := firstNonEmptyString(stringFromAny(result["voice_id"]), stringFromMap(body, "voice_id"), stringFromMap(body, "voiceId"))
	demoAudioURL := firstNonEmptyString(stringFromAny(result["demo_audio"]), firstAudioURLFromResult(result))
	previewModel := firstNonEmptyString(stringFromMap(body, "preview_model"), stringFromMap(body, "previewModel"), stringFromAny(result["preview_model"]))
	expiresAt := time.Now().Add(retention)
	displayName := firstNonEmptyString(stringFromMap(body, "display_name"), stringFromMap(body, "displayName"), voiceID)

	requestOptions := map[string]any{
		"textValidation":          body["text_validation"],
		"languageBoost":           body["language_boost"],
		"needNoiseReduction":      body["need_noise_reduction"],
		"needVolumeNormalization": body["need_volume_normalization"],
		"aigcWatermark":           body["aigc_watermark"],
	}

	input := store.ClonedVoiceInput{
		GatewayUserID:   task.GatewayUserID,
		UserID:          task.UserID,
		GatewayTenantID: task.GatewayTenantID,
		TenantID:        task.TenantID,
		TenantKey:       task.TenantKey,
		Provider:        candidate.Provider,
		PlatformID:      candidate.PlatformID,
		PlatformModelID: candidate.PlatformModelID,
		SourceTaskID:    task.ID,
		SourceAttemptID: attemptID,
		Model:           task.Model,
		PreviewModel:    previewModel,
		VoiceID:         voiceID,
		DisplayName:     displayName,
		DemoAudioURL:    demoAudioURL,
		Status:          "active",
		ExpiresAt:       &expiresAt,
		Metadata: map[string]any{
			"request": requestOptions,
			"rawData": result["raw_data"],
		},
	}
	return s.store.UpsertClonedVoice(ctx, input)
}
// touchClonedVoiceUsage records that a stored cloned voice was used by a
// request served through candidate.
//
// The voice is looked up from the cloned_voice_id / voice_id keys of body.
// Nothing is recorded when the body names no voice, the lookup fails or
// finds nothing, or the stored voice belongs to a different platform than
// candidate. The method is best effort and reports no error to the caller.
func (s *Service) touchClonedVoiceUsage(ctx context.Context, user *auth.User, body map[string]any, candidate store.RuntimeModelCandidate) {
	clonedVoiceID := firstNonEmptyString(stringFromMap(body, "cloned_voice_id"), stringFromMap(body, "clonedVoiceId"))
	voiceID := firstNonEmptyString(stringFromMap(body, "voice_id"), stringFromMap(body, "voiceId"))
	// Mirror resolveClonedVoiceBinding: with no identifier at all there is
	// nothing to look up, so skip the store query entirely.
	if clonedVoiceID == "" && voiceID == "" {
		return
	}
	voice, found, err := s.store.FindClonedVoiceForUser(ctx, user, clonedVoiceID, voiceID)
	if err != nil || !found || voice.PlatformID != candidate.PlatformID {
		return
	}
	// Usage tracking must never fail the request that triggered it, so the
	// store error is intentionally dropped.
	_ = s.store.TouchClonedVoiceUsage(ctx, voice.ID)
}
// DeleteClonedVoice deletes one of the caller's cloned voices, first at the
// upstream provider and then in the local store.
//
// rawID is trimmed; a UUID-shaped value is treated as the stored record's ID,
// anything else as the provider voice ID. The method returns a
// *clients.ClientError when rawID is empty (400), the voice is not found
// (404), no runtime candidate exists for the voice's platform binding (400),
// or the provider client does not implement clients.VoiceCloneDeleter (400).
// Errors from the store, HTTP client construction and the upstream delete
// call are returned as-is.
//
// If the store reports that nothing was removed after the upstream delete
// succeeded, the looked-up record is returned with Status set to "deleted".
func (s *Service) DeleteClonedVoice(ctx context.Context, user *auth.User, rawID string) (DeletedClonedVoiceResult, error) {
	var none DeletedClonedVoiceResult
	fail := func(code, message string, status int) (DeletedClonedVoiceResult, error) {
		return none, &clients.ClientError{Code: code, Message: message, StatusCode: status, Retryable: false}
	}

	id := strings.TrimSpace(rawID)
	if id == "" {
		return fail("bad_request", "voice id is required", 400)
	}

	// Exactly one of the two lookup keys is populated.
	var clonedVoiceID, voiceID string
	if looksLikeUUID(id) {
		clonedVoiceID = id
	} else {
		voiceID = id
	}

	voice, found, err := s.store.FindClonedVoiceForUser(ctx, user, clonedVoiceID, voiceID)
	if err != nil {
		return none, err
	}
	if !found {
		return fail("cloned_voice_not_found", "cloned voice not found", 404)
	}

	candidate, hasCandidate, err := s.store.GetRuntimeModelCandidateForVoiceCloneDeletion(ctx, voice.PlatformModelID, voice.PlatformID)
	if err != nil {
		return none, err
	}
	if !hasCandidate {
		return fail("cloned_voice_platform_unavailable", "cloned voice platform binding is unavailable", 400)
	}

	requestHTTPClient, err := s.httpClientForCandidate(candidate, false)
	if err != nil {
		return none, err
	}
	deleter, canDelete := s.clientFor(candidate, false).(clients.VoiceCloneDeleter)
	if !canDelete {
		return fail("unsupported_operation", "voice clone deletion is not supported by this provider", 400)
	}

	// Upstream first: the local record is only removed once the provider
	// call has succeeded.
	request := clients.VoiceCloneDeleteRequest{
		VoiceID:    voice.VoiceID,
		Candidate:  candidate,
		HTTPClient: requestHTTPClient,
	}
	upstream, err := deleter.DeleteVoiceClone(ctx, request)
	if err != nil {
		return none, err
	}

	removed, wasRemoved, err := s.store.DeleteClonedVoiceForUser(ctx, user, voice.ID, voice.VoiceID)
	if err != nil {
		return none, err
	}
	if !wasRemoved {
		removed = voice
		removed.Status = "deleted"
	}

	return DeletedClonedVoiceResult{
		Voice:     removed,
		Upstream:  upstream.Result,
		RequestID: upstream.RequestID,
	}, nil
}
// firstAudioURLFromResult scans result["data"] (expected to be a []any of
// map[string]any items) and returns the first non-empty "url" belonging to
// an item whose "type" is either absent/blank or, ignoring case and
// surrounding whitespace, equal to "audio". It returns "" when there is no
// such item.
func firstAudioURLFromResult(result map[string]any) string {
	entries, _ := result["data"].([]any)
	for _, entry := range entries {
		fields, ok := entry.(map[string]any)
		if !ok || fields == nil {
			continue
		}
		kind := strings.ToLower(strings.TrimSpace(stringFromAny(fields["type"])))
		if kind != "" && kind != "audio" {
			continue
		}
		link := stringFromAny(fields["url"])
		if link == "" {
			continue
		}
		return link
	}
	return ""
}
// looksLikeUUID reports whether value, after trimming surrounding
// whitespace, has the textual 8-4-4-4-12 UUID shape: 36 bytes with '-' at
// offsets 8, 13, 18 and 23 and hexadecimal digits (either case) everywhere
// else. No version or variant bits are checked.
func looksLikeUUID(value string) bool {
	candidate := strings.TrimSpace(value)
	if len(candidate) != 36 {
		return false
	}
	// Every accepted character is ASCII, so a byte-wise scan is sufficient:
	// any byte of a multi-byte sequence fails both checks below.
	for i := 0; i < len(candidate); i++ {
		c := candidate[i]
		if i == 8 || i == 13 || i == 18 || i == 23 {
			if c != '-' {
				return false
			}
			continue
		}
		isHex := (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')
		if !isHex {
			return false
		}
	}
	return true
}