easyai-ai-gateway/apps/api/internal/store/cloned_voices.go

309 lines
9.3 KiB
Go

package store
import (
"context"
"encoding/json"
"strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
)
// ClonedVoice is a cloned-voice record as read back from
// gateway_cloned_voices (joined with integration_platforms to obtain
// PlatformName). It is populated by scanClonedVoice from the column list in
// clonedVoiceColumns, where nullable columns are COALESCEd, so absent
// optional values arrive as empty strings rather than NULL.
type ClonedVoice struct {
	// Row identity and ownership. GatewayUserID/GatewayTenantID are UUID
	// columns rendered as text; UserID is the caller-facing user key.
	ID              string `json:"id"`
	GatewayUserID   string `json:"gatewayUserId,omitempty"`
	UserID          string `json:"userId"`
	GatewayTenantID string `json:"gatewayTenantId,omitempty"`
	TenantID        string `json:"tenantId,omitempty"`
	TenantKey       string `json:"tenantKey,omitempty"`
	// Provider and integration platform the voice was cloned on.
	Provider        string `json:"provider"`
	PlatformID      string `json:"platformId,omitempty"`
	PlatformName    string `json:"platformName,omitempty"`
	PlatformModelID string `json:"platformModelId,omitempty"`
	Model           string `json:"model,omitempty"`
	PreviewModel    string `json:"previewModel,omitempty"`
	// Provider-side voice identifier and display details. Status values
	// written in this file are "active" (upsert default) and "deleted" (soft
	// delete); FindClonedVoiceForUser also skips "failed".
	VoiceID      string `json:"voiceId"`
	DisplayName  string `json:"displayName,omitempty"`
	DemoAudioURL string `json:"demoAudioUrl,omitempty"`
	Status       string `json:"status"`
	// ExpiresAt and LastUsedAt carry PostgreSQL's text rendering of the
	// timestamps (via ::text), or "" when NULL — they are not RFC 3339.
	ExpiresAt  string         `json:"expiresAt,omitempty"`
	LastUsedAt string         `json:"lastUsedAt,omitempty"`
	Metadata   map[string]any `json:"metadata,omitempty"`
	CreatedAt  time.Time      `json:"createdAt"`
	UpdatedAt  time.Time      `json:"updatedAt"`
}
// ClonedVoiceInput holds the values UpsertClonedVoice writes for one cloned
// voice. Field order matches the column order of the INSERT statement.
type ClonedVoiceInput struct {
	// Ownership. GatewayUserID and GatewayTenantID are cast to uuid; blank
	// GatewayUserID, GatewayTenantID, TenantID and TenantKey become NULL.
	GatewayUserID, UserID                string
	GatewayTenantID, TenantID, TenantKey string
	// Provider, platform and provenance. The platform, platform-model and
	// source IDs are cast to uuid and stored as NULL when blank.
	Provider                      string
	PlatformID, PlatformModelID   string
	SourceTaskID, SourceAttemptID string
	Model, PreviewModel           string
	// Voice details. A blank Status is stored as "active".
	VoiceID, DisplayName, DemoAudioURL string
	Status                             string
	// A nil ExpiresAt is stored as NULL. Metadata is merged into any existing
	// metadata when the upsert hits an existing row.
	ExpiresAt *time.Time
	Metadata  map[string]any
}
// clonedVoiceColumns is the shared SELECT list for cloned-voice queries.
// It expects the cloned-voice row aliased as v and integration_platforms
// aliased as p (source of the platform name). The column order must match the
// Scan order in scanClonedVoice. Nullable columns are COALESCEd to '' (or
// '{}' for metadata) so they scan into plain strings, and expires_at /
// last_used_at are rendered with ::text.
const clonedVoiceColumns = `
v.id::text, COALESCE(v.gateway_user_id::text, ''), v.user_id,
COALESCE(v.gateway_tenant_id::text, ''), COALESCE(v.tenant_id, ''), COALESCE(v.tenant_key, ''),
v.provider, COALESCE(v.platform_id::text, ''), COALESCE(p.name, ''),
COALESCE(v.platform_model_id::text, ''), COALESCE(v.model, ''), COALESCE(v.preview_model, ''),
v.voice_id, COALESCE(v.display_name, ''), COALESCE(v.demo_audio_url, ''), v.status,
COALESCE(v.expires_at::text, ''), COALESCE(v.last_used_at::text, ''),
COALESCE(v.metadata, '{}'::jsonb), v.created_at, v.updated_at`
// UpsertClonedVoice stores a cloned voice and returns the saved row joined
// with its platform name.
//
// The statement conflicts on (platform_id, voice_id) for rows with a non-NULL
// platform_id and a non-empty voice_id (presumably a partial unique index —
// confirm in the migrations). On conflict, the ownership, provenance, model,
// display, status and expiry columns are overwritten with input's values, and
// input.Metadata is merged on top of the stored metadata.
//
// A blank input.Status defaults to "active". Blank UUID-typed IDs and blank
// tenant fields are stored as NULL. The voice ID is trimmed so that it matches
// the trimmed lookups in FindClonedVoiceForUser and DeleteClonedVoiceForUser.
// An error is returned if input.Metadata cannot be JSON-encoded or if the
// statement fails.
func (s *Store) UpsertClonedVoice(ctx context.Context, input ClonedVoiceInput) (ClonedVoice, error) {
	metadata, err := json.Marshal(emptyObjectIfNil(input.Metadata))
	if err != nil {
		// Fail fast: an unencodable value (e.g. a channel or NaN) would
		// otherwise reach the ::jsonb cast as an empty string and fail there
		// with a far less helpful error.
		return ClonedVoice{}, err
	}
	status := strings.TrimSpace(input.Status)
	if status == "" {
		status = "active"
	}
	voiceID := strings.TrimSpace(input.VoiceID)
	// The metadata merge COALESCEs the stored value because jsonb || NULL
	// yields NULL, which would silently discard the incoming metadata.
	return scanClonedVoice(s.pool.QueryRow(ctx, `
WITH upsert AS (
INSERT INTO gateway_cloned_voices (
gateway_user_id, user_id, gateway_tenant_id, tenant_id, tenant_key,
provider, platform_id, platform_model_id, source_task_id, source_attempt_id,
model, preview_model, voice_id, display_name, demo_audio_url, status, expires_at, metadata
)
VALUES (
NULLIF($1, '')::uuid, $2, NULLIF($3, '')::uuid, NULLIF($4, ''), NULLIF($5, ''),
$6, NULLIF($7, '')::uuid, NULLIF($8, '')::uuid, NULLIF($9, '')::uuid, NULLIF($10, '')::uuid,
$11, $12, $13, $14, $15, $16, $17, $18::jsonb
)
ON CONFLICT (platform_id, voice_id) WHERE platform_id IS NOT NULL AND voice_id <> ''
DO UPDATE SET
gateway_user_id = EXCLUDED.gateway_user_id,
user_id = EXCLUDED.user_id,
gateway_tenant_id = EXCLUDED.gateway_tenant_id,
tenant_id = EXCLUDED.tenant_id,
tenant_key = EXCLUDED.tenant_key,
provider = EXCLUDED.provider,
platform_model_id = EXCLUDED.platform_model_id,
source_task_id = EXCLUDED.source_task_id,
source_attempt_id = EXCLUDED.source_attempt_id,
model = EXCLUDED.model,
preview_model = EXCLUDED.preview_model,
display_name = EXCLUDED.display_name,
demo_audio_url = EXCLUDED.demo_audio_url,
status = EXCLUDED.status,
expires_at = EXCLUDED.expires_at,
metadata = COALESCE(gateway_cloned_voices.metadata, '{}'::jsonb) || EXCLUDED.metadata,
updated_at = now()
RETURNING *
)
SELECT `+clonedVoiceColumns+`
FROM upsert v
LEFT JOIN integration_platforms p ON p.id = v.platform_id`,
		input.GatewayUserID,
		input.UserID,
		input.GatewayTenantID,
		input.TenantID,
		input.TenantKey,
		input.Provider,
		input.PlatformID,
		input.PlatformModelID,
		input.SourceTaskID,
		input.SourceAttemptID,
		input.Model,
		input.PreviewModel,
		voiceID,
		input.DisplayName,
		input.DemoAudioURL,
		status,
		input.ExpiresAt,
		string(metadata),
	))
}
// ListClonedVoices returns every cloned voice owned by user whose status is
// not "deleted", newest first.
//
// A row is owned when its gateway_user_id equals the caller's gateway user
// UUID, or when its user_id equals the caller's user ID. Both keys come from
// clonedVoiceUserKeys, and a blank key never matches. On success the slice is
// non-nil, so it JSON-encodes as [] even when empty.
func (s *Store) ListClonedVoices(ctx context.Context, user *auth.User) ([]ClonedVoice, error) {
	gatewayUserID, userID := clonedVoiceUserKeys(user)
	const query = `
SELECT ` + clonedVoiceColumns + `
FROM gateway_cloned_voices v
LEFT JOIN integration_platforms p ON p.id = v.platform_id
WHERE (
(
NULLIF($1, '')::uuid IS NOT NULL
AND v.gateway_user_id = NULLIF($1, '')::uuid
)
OR (
NULLIF($2, '') IS NOT NULL
AND v.user_id = $2
)
)
AND v.status <> 'deleted'
ORDER BY v.created_at DESC`

	rows, queryErr := s.pool.Query(ctx, query, gatewayUserID, userID)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	voices := []ClonedVoice{}
	for rows.Next() {
		voice, scanErr := scanClonedVoice(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		voices = append(voices, voice)
	}
	// Iteration errors surface here; the rows collected so far are returned
	// alongside them.
	return voices, rows.Err()
}
// FindClonedVoiceForUser resolves one usable cloned voice owned by user.
//
// The voice is matched by its row ID (clonedVoiceID) or by the provider's
// voice ID (voiceID), after trimming both. Rows with status "deleted" or
// "failed" are ignored. When several rows match, a row-ID match wins, and
// ties are broken by newest created_at. It returns (voice, true, nil) on a
// hit, and (zero, false, nil) when both IDs are blank or nothing matches.
// Other query errors are returned as-is.
func (s *Store) FindClonedVoiceForUser(ctx context.Context, user *auth.User, clonedVoiceID string, voiceID string) (ClonedVoice, bool, error) {
	gatewayUserID, userID := clonedVoiceUserKeys(user)
	rowID := strings.TrimSpace(clonedVoiceID)
	providerVoiceID := strings.TrimSpace(voiceID)
	if rowID == "" && providerVoiceID == "" {
		return ClonedVoice{}, false, nil
	}

	const query = `
SELECT ` + clonedVoiceColumns + `
FROM gateway_cloned_voices v
LEFT JOIN integration_platforms p ON p.id = v.platform_id
WHERE (
(
NULLIF($1, '')::uuid IS NOT NULL
AND v.gateway_user_id = NULLIF($1, '')::uuid
)
OR (
NULLIF($2, '') IS NOT NULL
AND v.user_id = $2
)
)
AND (
(NULLIF($3, '')::uuid IS NOT NULL AND v.id = NULLIF($3, '')::uuid)
OR (NULLIF($4, '') IS NOT NULL AND v.voice_id = $4)
)
AND v.status NOT IN ('deleted', 'failed')
ORDER BY CASE WHEN NULLIF($3, '')::uuid IS NOT NULL AND v.id = NULLIF($3, '')::uuid THEN 0 ELSE 1 END,
v.created_at DESC
LIMIT 1`

	row := s.pool.QueryRow(ctx, query, gatewayUserID, userID, rowID, providerVoiceID)
	found, err := scanClonedVoice(row)
	switch {
	case err == nil:
		return found, true, nil
	case IsNotFound(err):
		return ClonedVoice{}, false, nil
	default:
		return ClonedVoice{}, false, err
	}
}
// DeleteClonedVoiceForUser soft-deletes cloned voices owned by user by setting
// status to "deleted" and bumping updated_at.
//
// Rows are matched by row ID (clonedVoiceID) or by provider voice ID
// (voiceID), after trimming, and rows that are already deleted are skipped.
// It returns (voice, true, nil) for an updated row, and (zero, false, nil)
// when both IDs are blank or nothing was updated.
//
// NOTE(review): the UPDATE is not limited to one row. Every row matching
// either ID is marked deleted, while QueryRow presumably reports only the
// first. Confirm that this is intended when both IDs are supplied.
func (s *Store) DeleteClonedVoiceForUser(ctx context.Context, user *auth.User, clonedVoiceID string, voiceID string) (ClonedVoice, bool, error) {
	gatewayUserID, userID := clonedVoiceUserKeys(user)
	rowID := strings.TrimSpace(clonedVoiceID)
	providerVoiceID := strings.TrimSpace(voiceID)
	if rowID == "" && providerVoiceID == "" {
		return ClonedVoice{}, false, nil
	}

	const query = `
WITH updated AS (
UPDATE gateway_cloned_voices v
SET status = 'deleted', updated_at = now()
WHERE (
(
NULLIF($1, '')::uuid IS NOT NULL
AND v.gateway_user_id = NULLIF($1, '')::uuid
)
OR (
NULLIF($2, '') IS NOT NULL
AND v.user_id = $2
)
)
AND (
(NULLIF($3, '')::uuid IS NOT NULL AND v.id = NULLIF($3, '')::uuid)
OR (NULLIF($4, '') IS NOT NULL AND v.voice_id = $4)
)
AND v.status <> 'deleted'
RETURNING *
)
SELECT ` + clonedVoiceColumns + `
FROM updated v
LEFT JOIN integration_platforms p ON p.id = v.platform_id`

	row := s.pool.QueryRow(ctx, query, gatewayUserID, userID, rowID, providerVoiceID)
	deleted, err := scanClonedVoice(row)
	switch {
	case err == nil:
		return deleted, true, nil
	case IsNotFound(err):
		return ClonedVoice{}, false, nil
	default:
		return ClonedVoice{}, false, err
	}
}
// TouchClonedVoiceUsage records that the cloned voice with the given row ID
// was just used. It sets last_used_at to now, pushes expires_at to seven days
// from now, and bumps updated_at. A blank ID is a no-op.
//
// The ID is trimmed before it is sent to the database. PostgreSQL's uuid input
// rejects surrounding whitespace, and the blank check already operates on the
// trimmed value, so the same value is used for both.
func (s *Store) TouchClonedVoiceUsage(ctx context.Context, clonedVoiceID string) error {
	clonedVoiceID = strings.TrimSpace(clonedVoiceID)
	if clonedVoiceID == "" {
		return nil
	}
	_, err := s.pool.Exec(ctx, `
UPDATE gateway_cloned_voices
SET last_used_at = now(), expires_at = now() + interval '7 days', updated_at = now()
WHERE id = $1::uuid`, clonedVoiceID)
	return err
}
// MarkClonedVoiceStatus sets the status of the cloned voice with the given
// row ID and bumps updated_at. It is a no-op when either argument is blank.
//
// Both arguments are trimmed before use. The ID is trimmed so that the ::uuid
// cast does not reject padding. The status is trimmed so that values such as
// "failed" compare exactly against the 'deleted'/'failed' filters used by the
// lookup queries, matching the trimmed statuses UpsertClonedVoice writes.
func (s *Store) MarkClonedVoiceStatus(ctx context.Context, clonedVoiceID string, status string) error {
	clonedVoiceID = strings.TrimSpace(clonedVoiceID)
	status = strings.TrimSpace(status)
	if clonedVoiceID == "" || status == "" {
		return nil
	}
	_, err := s.pool.Exec(ctx, `
UPDATE gateway_cloned_voices
SET status = $2, updated_at = now()
WHERE id = $1::uuid`, clonedVoiceID, status)
	return err
}
// clonedVoiceUserKeys derives the two ownership keys used to scope
// cloned-voice queries to a caller: (gateway user UUID string, user ID).
// Both are whitespace-trimmed. When the user has no explicit GatewayUserID
// but its Source is "gateway", the user's own ID doubles as the gateway key.
// A nil user yields two empty keys, which the queries treat as "no match".
func clonedVoiceUserKeys(user *auth.User) (string, string) {
	if user == nil {
		return "", ""
	}
	id := strings.TrimSpace(user.ID)
	gatewayID := strings.TrimSpace(user.GatewayUserID)
	if gatewayID == "" && user.Source == "gateway" {
		gatewayID = id
	}
	return gatewayID, id
}
// clonedVoiceScanner is the single Scan method that scanClonedVoice needs.
// In this file it is satisfied both by the single-row result of
// s.pool.QueryRow and by the row cursor returned from s.pool.Query.
type clonedVoiceScanner interface {
	Scan(dest ...any) error
}
// scanClonedVoice reads one row shaped by clonedVoiceColumns into a
// ClonedVoice. The destination order below must match that column list
// exactly. The metadata column is scanned as raw bytes and then decoded with
// decodeObject. On a scan error it returns the zero ClonedVoice and the error
// unchanged, so callers can still test it with IsNotFound.
func scanClonedVoice(scanner clonedVoiceScanner) (ClonedVoice, error) {
	var (
		voice       ClonedVoice
		rawMetadata []byte
	)
	targets := []any{
		// ownership
		&voice.ID, &voice.GatewayUserID, &voice.UserID,
		&voice.GatewayTenantID, &voice.TenantID, &voice.TenantKey,
		// provider / platform
		&voice.Provider, &voice.PlatformID, &voice.PlatformName,
		&voice.PlatformModelID, &voice.Model, &voice.PreviewModel,
		// voice details
		&voice.VoiceID, &voice.DisplayName, &voice.DemoAudioURL, &voice.Status,
		// lifecycle
		&voice.ExpiresAt, &voice.LastUsedAt,
		&rawMetadata, &voice.CreatedAt, &voice.UpdatedAt,
	}
	if err := scanner.Scan(targets...); err != nil {
		return ClonedVoice{}, err
	}
	voice.Metadata = decodeObject(rawMetadata)
	return voice, nil
}