223 lines
8.3 KiB
Go
223 lines
8.3 KiB
Go
package runner
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
"unicode"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
)
|
|
|
|
type clonedVoiceBinding struct {
|
|
Voice store.ClonedVoice
|
|
Found bool
|
|
Explicit bool
|
|
}
|
|
|
|
func validateVoiceCloneRequest(body map[string]any) error {
|
|
voiceID := firstNonEmptyString(stringFromMap(body, "voice_id"), stringFromMap(body, "voiceId"))
|
|
if !validMiniMaxVoiceID(voiceID) {
|
|
return fmt.Errorf("voice_id must be 8-256 chars, start with an English letter, contain only letters, digits, '-' or '_', and not end with '-' or '_'")
|
|
}
|
|
if body["file_id"] == nil && body["fileId"] == nil &&
|
|
stringFromAny(body["audio"]) == "" &&
|
|
stringFromAny(body["file"]) == "" &&
|
|
stringFromAny(body["source_audio"]) == "" &&
|
|
stringFromAny(body["sourceAudio"]) == "" &&
|
|
stringFromMap(body, "audio_url") == "" &&
|
|
stringFromMap(body, "audioUrl") == "" {
|
|
return fmt.Errorf("file_id or audio is required")
|
|
}
|
|
if hasVoiceClonePromptAudio(body) && firstNonEmptyString(stringFromMap(body, "prompt_text"), stringFromMap(body, "promptText")) == "" {
|
|
return fmt.Errorf("prompt_text is required when prompt audio is provided")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validMiniMaxVoiceID(value string) bool {
|
|
value = strings.TrimSpace(value)
|
|
if len(value) < 8 || len(value) > 256 {
|
|
return false
|
|
}
|
|
for index, r := range value {
|
|
if index == 0 && !isASCIILetter(r) {
|
|
return false
|
|
}
|
|
if !(isASCIILetter(r) || unicode.IsDigit(r) || r == '-' || r == '_') {
|
|
return false
|
|
}
|
|
}
|
|
return !strings.HasSuffix(value, "-") && !strings.HasSuffix(value, "_")
|
|
}
|
|
|
|
// isASCIILetter reports whether r is one of the 52 unaccented English letters
// a-z or A-Z.
func isASCIILetter(r rune) bool {
	switch {
	case 'a' <= r && r <= 'z':
		return true
	case 'A' <= r && r <= 'Z':
		return true
	default:
		return false
	}
}
|
|
|
|
func hasVoiceClonePromptAudio(body map[string]any) bool {
|
|
return body["prompt_file_id"] != nil ||
|
|
body["promptFileId"] != nil ||
|
|
stringFromAny(body["prompt_audio"]) != "" ||
|
|
stringFromAny(body["promptAudio"]) != "" ||
|
|
stringFromMap(body, "prompt_audio_url") != "" ||
|
|
stringFromMap(body, "promptAudioUrl") != ""
|
|
}
|
|
|
|
func (s *Service) resolveClonedVoiceBinding(ctx context.Context, user *auth.User, kind string, body map[string]any) (map[string]any, clonedVoiceBinding, error) {
|
|
if kind != "speech.generations" {
|
|
return body, clonedVoiceBinding{}, nil
|
|
}
|
|
clonedVoiceID := firstNonEmptyString(stringFromMap(body, "cloned_voice_id"), stringFromMap(body, "clonedVoiceId"))
|
|
voiceID := firstNonEmptyString(stringFromMap(body, "voice_id"), stringFromMap(body, "voiceId"))
|
|
if clonedVoiceID == "" && voiceID == "" {
|
|
return body, clonedVoiceBinding{}, nil
|
|
}
|
|
if clonedVoiceID != "" && !looksLikeUUID(clonedVoiceID) {
|
|
return body, clonedVoiceBinding{}, &clients.ClientError{Code: "bad_request", Message: "cloned_voice_id must be a UUID", StatusCode: 400, Retryable: false}
|
|
}
|
|
voice, found, err := s.store.FindClonedVoiceForUser(ctx, user, clonedVoiceID, voiceID)
|
|
if err != nil {
|
|
return body, clonedVoiceBinding{}, err
|
|
}
|
|
if !found {
|
|
if clonedVoiceID != "" {
|
|
return body, clonedVoiceBinding{}, &clients.ClientError{Code: "cloned_voice_not_found", Message: "cloned voice not found", StatusCode: 404, Retryable: false}
|
|
}
|
|
return body, clonedVoiceBinding{}, nil
|
|
}
|
|
if strings.TrimSpace(voice.Status) != "" && voice.Status != "active" {
|
|
return body, clonedVoiceBinding{}, &clients.ClientError{Code: "cloned_voice_unavailable", Message: "cloned voice is not active", StatusCode: 400, Retryable: false}
|
|
}
|
|
if voice.ExpiresAt != "" {
|
|
if expiresAt, err := time.Parse(time.RFC3339Nano, voice.ExpiresAt); err == nil && !expiresAt.After(time.Now()) {
|
|
_ = s.store.MarkClonedVoiceStatus(context.WithoutCancel(ctx), voice.ID, "expired")
|
|
return body, clonedVoiceBinding{}, &clients.ClientError{Code: "cloned_voice_expired", Message: "cloned voice has expired", StatusCode: 400, Retryable: false}
|
|
}
|
|
}
|
|
out := cloneMap(body)
|
|
out["voice_id"] = voice.VoiceID
|
|
out["cloned_voice_id"] = voice.ID
|
|
return out, clonedVoiceBinding{Voice: voice, Found: true, Explicit: clonedVoiceID != ""}, nil
|
|
}
|
|
|
|
func filterCandidatesByClonedVoiceBinding(candidates []store.RuntimeModelCandidate, binding clonedVoiceBinding) ([]store.RuntimeModelCandidate, error) {
|
|
if !binding.Found {
|
|
return candidates, nil
|
|
}
|
|
filtered := make([]store.RuntimeModelCandidate, 0, len(candidates))
|
|
preferred := make([]store.RuntimeModelCandidate, 0, 1)
|
|
for _, candidate := range candidates {
|
|
if strings.TrimSpace(candidate.PlatformID) != binding.Voice.PlatformID {
|
|
continue
|
|
}
|
|
if binding.Voice.PlatformModelID != "" && candidate.PlatformModelID == binding.Voice.PlatformModelID {
|
|
preferred = append(preferred, candidate)
|
|
continue
|
|
}
|
|
filtered = append(filtered, candidate)
|
|
}
|
|
if len(preferred) > 0 {
|
|
filtered = append(preferred, filtered...)
|
|
}
|
|
if len(filtered) == 0 {
|
|
return nil, &store.ModelCandidateUnavailableError{
|
|
Code: "cloned_voice_platform_unavailable",
|
|
Message: "cloned voice is bound to a platform that has no enabled candidate for the requested speech model",
|
|
Details: map[string]any{
|
|
"clonedVoiceId": binding.Voice.ID,
|
|
"voiceId": binding.Voice.VoiceID,
|
|
"platformId": binding.Voice.PlatformID,
|
|
"platformModelId": binding.Voice.PlatformModelID,
|
|
},
|
|
}
|
|
}
|
|
return filtered, nil
|
|
}
|
|
|
|
func (s *Service) persistVoiceCloneResult(ctx context.Context, task store.GatewayTask, user *auth.User, candidate store.RuntimeModelCandidate, attemptID string, body map[string]any, result map[string]any) (store.ClonedVoice, error) {
|
|
voiceID := firstNonEmptyString(stringFromAny(result["voice_id"]), stringFromMap(body, "voice_id"), stringFromMap(body, "voiceId"))
|
|
demoAudioURL := firstNonEmptyString(stringFromAny(result["demo_audio"]), firstAudioURLFromResult(result))
|
|
previewModel := firstNonEmptyString(stringFromMap(body, "preview_model"), stringFromMap(body, "previewModel"), stringFromAny(result["preview_model"]))
|
|
expiresAt := time.Now().Add(7 * 24 * time.Hour)
|
|
return s.store.UpsertClonedVoice(ctx, store.ClonedVoiceInput{
|
|
GatewayUserID: task.GatewayUserID,
|
|
UserID: task.UserID,
|
|
GatewayTenantID: task.GatewayTenantID,
|
|
TenantID: task.TenantID,
|
|
TenantKey: task.TenantKey,
|
|
Provider: candidate.Provider,
|
|
PlatformID: candidate.PlatformID,
|
|
PlatformModelID: candidate.PlatformModelID,
|
|
SourceTaskID: task.ID,
|
|
SourceAttemptID: attemptID,
|
|
Model: task.Model,
|
|
PreviewModel: previewModel,
|
|
VoiceID: voiceID,
|
|
DisplayName: firstNonEmptyString(stringFromMap(body, "display_name"), stringFromMap(body, "displayName"), voiceID),
|
|
DemoAudioURL: demoAudioURL,
|
|
Status: "active",
|
|
ExpiresAt: &expiresAt,
|
|
Metadata: map[string]any{
|
|
"request": map[string]any{
|
|
"textValidation": body["text_validation"],
|
|
"languageBoost": body["language_boost"],
|
|
"needNoiseReduction": body["need_noise_reduction"],
|
|
"needVolumeNormalization": body["need_volume_normalization"],
|
|
"aigcWatermark": body["aigc_watermark"],
|
|
},
|
|
"rawData": result["raw_data"],
|
|
},
|
|
})
|
|
}
|
|
|
|
func (s *Service) touchClonedVoiceUsage(ctx context.Context, user *auth.User, body map[string]any, candidate store.RuntimeModelCandidate) {
|
|
clonedVoiceID := firstNonEmptyString(stringFromMap(body, "cloned_voice_id"), stringFromMap(body, "clonedVoiceId"))
|
|
voiceID := firstNonEmptyString(stringFromMap(body, "voice_id"), stringFromMap(body, "voiceId"))
|
|
voice, found, err := s.store.FindClonedVoiceForUser(ctx, user, clonedVoiceID, voiceID)
|
|
if err != nil || !found || voice.PlatformID != candidate.PlatformID {
|
|
return
|
|
}
|
|
_ = s.store.TouchClonedVoiceUsage(ctx, voice.ID)
|
|
}
|
|
|
|
func firstAudioURLFromResult(result map[string]any) string {
|
|
items, _ := result["data"].([]any)
|
|
for _, raw := range items {
|
|
item, _ := raw.(map[string]any)
|
|
if item == nil {
|
|
continue
|
|
}
|
|
if itemType := strings.ToLower(strings.TrimSpace(stringFromAny(item["type"]))); itemType != "" && itemType != "audio" {
|
|
continue
|
|
}
|
|
if url := stringFromAny(item["url"]); url != "" {
|
|
return url
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// looksLikeUUID reports whether value, ignoring surrounding whitespace, has the
// canonical 36-character UUID shape: hexadecimal digits (either case) with
// hyphens at offsets 8, 13, 18 and 23. It checks the layout only, not the
// version or variant bits.
func looksLikeUUID(value string) bool {
	candidate := strings.TrimSpace(value)
	if len(candidate) != 36 {
		return false
	}
	for pos := 0; pos < len(candidate); pos++ {
		ch := candidate[pos]
		if pos == 8 || pos == 13 || pos == 18 || pos == 23 {
			if ch != '-' {
				return false
			}
			continue
		}
		isHex := ('0' <= ch && ch <= '9') || ('a' <= ch && ch <= 'f') || ('A' <= ch && ch <= 'F')
		if !isHex {
			return false
		}
	}
	return true
}
|