309 lines
9.3 KiB
Go
309 lines
9.3 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
)
|
|
|
|
// ClonedVoice is a stored cloned-voice record in the shape returned to API
// callers. It is filled by scanClonedVoice from the clonedVoiceColumns
// projection. Nullable text columns arrive as "" rather than NULL.
// PlatformName comes from the LEFT JOINed integration_platforms row and is
// empty when there is none. ExpiresAt and LastUsedAt hold the Postgres
// ::text rendering of those timestamps, or "" when unset. Metadata is the
// decoded metadata jsonb object.
type ClonedVoice struct {
	ID              string         `json:"id"`
	GatewayUserID   string         `json:"gatewayUserId,omitempty"`
	UserID          string         `json:"userId"`
	GatewayTenantID string         `json:"gatewayTenantId,omitempty"`
	TenantID        string         `json:"tenantId,omitempty"`
	TenantKey       string         `json:"tenantKey,omitempty"`
	Provider        string         `json:"provider"`
	PlatformID      string         `json:"platformId,omitempty"`
	PlatformName    string         `json:"platformName,omitempty"`
	PlatformModelID string         `json:"platformModelId,omitempty"`
	Model           string         `json:"model,omitempty"`
	PreviewModel    string         `json:"previewModel,omitempty"`
	VoiceID         string         `json:"voiceId"`
	DisplayName     string         `json:"displayName,omitempty"`
	DemoAudioURL    string         `json:"demoAudioUrl,omitempty"`
	Status          string         `json:"status"`
	ExpiresAt       string         `json:"expiresAt,omitempty"`
	LastUsedAt      string         `json:"lastUsedAt,omitempty"`
	Metadata        map[string]any `json:"metadata,omitempty"`
	CreatedAt       time.Time      `json:"createdAt"`
	UpdatedAt       time.Time      `json:"updatedAt"`
}
|
|
|
|
// ClonedVoiceInput holds the values UpsertClonedVoice writes for one cloned
// voice. The upsert stores empty gateway user/tenant, tenant, platform,
// platform-model and source task/attempt ids as NULL. An empty Status is
// replaced with "active" there. ExpiresAt may be nil, which stores NULL.
type ClonedVoiceInput struct {
	// Owner and tenant identifiers.
	GatewayUserID, UserID                string
	GatewayTenantID, TenantID, TenantKey string

	// Provider/platform and source task identifiers.
	Provider, PlatformID, PlatformModelID string
	SourceTaskID, SourceAttemptID         string

	// Voice details.
	Model, PreviewModel, VoiceID, DisplayName, DemoAudioURL, Status string

	ExpiresAt *time.Time
	Metadata  map[string]any
}
|
|
|
|
// clonedVoiceColumns is the SELECT list shared by every cloned-voice query in
// this file. It expects the cloned-voice row aliased as v and
// integration_platforms LEFT JOINed as p. Its column order must match the
// Scan destinations in scanClonedVoice. Nullable columns are COALESCEd to ''
// (or '{}' for metadata) so they scan into plain strings/bytes.
// expires_at and last_used_at are returned as their ::text form.
const clonedVoiceColumns = `
v.id::text, COALESCE(v.gateway_user_id::text, ''), v.user_id,
COALESCE(v.gateway_tenant_id::text, ''), COALESCE(v.tenant_id, ''), COALESCE(v.tenant_key, ''),
v.provider, COALESCE(v.platform_id::text, ''), COALESCE(p.name, ''),
COALESCE(v.platform_model_id::text, ''), COALESCE(v.model, ''), COALESCE(v.preview_model, ''),
v.voice_id, COALESCE(v.display_name, ''), COALESCE(v.demo_audio_url, ''), v.status,
COALESCE(v.expires_at::text, ''), COALESCE(v.last_used_at::text, ''),
COALESCE(v.metadata, '{}'::jsonb), v.created_at, v.updated_at`
|
|
|
|
func (s *Store) UpsertClonedVoice(ctx context.Context, input ClonedVoiceInput) (ClonedVoice, error) {
|
|
metadata, _ := json.Marshal(emptyObjectIfNil(input.Metadata))
|
|
status := strings.TrimSpace(input.Status)
|
|
if status == "" {
|
|
status = "active"
|
|
}
|
|
return scanClonedVoice(s.pool.QueryRow(ctx, `
|
|
WITH upsert AS (
|
|
INSERT INTO gateway_cloned_voices (
|
|
gateway_user_id, user_id, gateway_tenant_id, tenant_id, tenant_key,
|
|
provider, platform_id, platform_model_id, source_task_id, source_attempt_id,
|
|
model, preview_model, voice_id, display_name, demo_audio_url, status, expires_at, metadata
|
|
)
|
|
VALUES (
|
|
NULLIF($1, '')::uuid, $2, NULLIF($3, '')::uuid, NULLIF($4, ''), NULLIF($5, ''),
|
|
$6, NULLIF($7, '')::uuid, NULLIF($8, '')::uuid, NULLIF($9, '')::uuid, NULLIF($10, '')::uuid,
|
|
$11, $12, $13, $14, $15, $16, $17, $18::jsonb
|
|
)
|
|
ON CONFLICT (platform_id, voice_id) WHERE platform_id IS NOT NULL AND voice_id <> ''
|
|
DO UPDATE SET
|
|
gateway_user_id = EXCLUDED.gateway_user_id,
|
|
user_id = EXCLUDED.user_id,
|
|
gateway_tenant_id = EXCLUDED.gateway_tenant_id,
|
|
tenant_id = EXCLUDED.tenant_id,
|
|
tenant_key = EXCLUDED.tenant_key,
|
|
provider = EXCLUDED.provider,
|
|
platform_model_id = EXCLUDED.platform_model_id,
|
|
source_task_id = EXCLUDED.source_task_id,
|
|
source_attempt_id = EXCLUDED.source_attempt_id,
|
|
model = EXCLUDED.model,
|
|
preview_model = EXCLUDED.preview_model,
|
|
display_name = EXCLUDED.display_name,
|
|
demo_audio_url = EXCLUDED.demo_audio_url,
|
|
status = EXCLUDED.status,
|
|
expires_at = EXCLUDED.expires_at,
|
|
metadata = gateway_cloned_voices.metadata || EXCLUDED.metadata,
|
|
updated_at = now()
|
|
RETURNING *
|
|
)
|
|
SELECT `+clonedVoiceColumns+`
|
|
FROM upsert v
|
|
LEFT JOIN integration_platforms p ON p.id = v.platform_id`,
|
|
input.GatewayUserID,
|
|
input.UserID,
|
|
input.GatewayTenantID,
|
|
input.TenantID,
|
|
input.TenantKey,
|
|
input.Provider,
|
|
input.PlatformID,
|
|
input.PlatformModelID,
|
|
input.SourceTaskID,
|
|
input.SourceAttemptID,
|
|
input.Model,
|
|
input.PreviewModel,
|
|
input.VoiceID,
|
|
input.DisplayName,
|
|
input.DemoAudioURL,
|
|
status,
|
|
input.ExpiresAt,
|
|
string(metadata),
|
|
))
|
|
}
|
|
|
|
func (s *Store) ListClonedVoices(ctx context.Context, user *auth.User) ([]ClonedVoice, error) {
|
|
gatewayUserID, userID := clonedVoiceUserKeys(user)
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT `+clonedVoiceColumns+`
|
|
FROM gateway_cloned_voices v
|
|
LEFT JOIN integration_platforms p ON p.id = v.platform_id
|
|
WHERE (
|
|
(
|
|
NULLIF($1, '')::uuid IS NOT NULL
|
|
AND v.gateway_user_id = NULLIF($1, '')::uuid
|
|
)
|
|
OR (
|
|
NULLIF($2, '') IS NOT NULL
|
|
AND v.user_id = $2
|
|
)
|
|
)
|
|
AND v.status <> 'deleted'
|
|
ORDER BY v.created_at DESC`, gatewayUserID, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
items := make([]ClonedVoice, 0)
|
|
for rows.Next() {
|
|
item, err := scanClonedVoice(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
return items, rows.Err()
|
|
}
|
|
|
|
func (s *Store) FindClonedVoiceForUser(ctx context.Context, user *auth.User, clonedVoiceID string, voiceID string) (ClonedVoice, bool, error) {
|
|
gatewayUserID, userID := clonedVoiceUserKeys(user)
|
|
clonedVoiceID = strings.TrimSpace(clonedVoiceID)
|
|
voiceID = strings.TrimSpace(voiceID)
|
|
if clonedVoiceID == "" && voiceID == "" {
|
|
return ClonedVoice{}, false, nil
|
|
}
|
|
item, err := scanClonedVoice(s.pool.QueryRow(ctx, `
|
|
SELECT `+clonedVoiceColumns+`
|
|
FROM gateway_cloned_voices v
|
|
LEFT JOIN integration_platforms p ON p.id = v.platform_id
|
|
WHERE (
|
|
(
|
|
NULLIF($1, '')::uuid IS NOT NULL
|
|
AND v.gateway_user_id = NULLIF($1, '')::uuid
|
|
)
|
|
OR (
|
|
NULLIF($2, '') IS NOT NULL
|
|
AND v.user_id = $2
|
|
)
|
|
)
|
|
AND (
|
|
(NULLIF($3, '')::uuid IS NOT NULL AND v.id = NULLIF($3, '')::uuid)
|
|
OR (NULLIF($4, '') IS NOT NULL AND v.voice_id = $4)
|
|
)
|
|
AND v.status NOT IN ('deleted', 'failed')
|
|
ORDER BY CASE WHEN NULLIF($3, '')::uuid IS NOT NULL AND v.id = NULLIF($3, '')::uuid THEN 0 ELSE 1 END,
|
|
v.created_at DESC
|
|
LIMIT 1`, gatewayUserID, userID, clonedVoiceID, voiceID))
|
|
if err != nil {
|
|
if IsNotFound(err) {
|
|
return ClonedVoice{}, false, nil
|
|
}
|
|
return ClonedVoice{}, false, err
|
|
}
|
|
return item, true, nil
|
|
}
|
|
|
|
func (s *Store) DeleteClonedVoiceForUser(ctx context.Context, user *auth.User, clonedVoiceID string, voiceID string) (ClonedVoice, bool, error) {
|
|
gatewayUserID, userID := clonedVoiceUserKeys(user)
|
|
clonedVoiceID = strings.TrimSpace(clonedVoiceID)
|
|
voiceID = strings.TrimSpace(voiceID)
|
|
if clonedVoiceID == "" && voiceID == "" {
|
|
return ClonedVoice{}, false, nil
|
|
}
|
|
item, err := scanClonedVoice(s.pool.QueryRow(ctx, `
|
|
WITH updated AS (
|
|
UPDATE gateway_cloned_voices v
|
|
SET status = 'deleted', updated_at = now()
|
|
WHERE (
|
|
(
|
|
NULLIF($1, '')::uuid IS NOT NULL
|
|
AND v.gateway_user_id = NULLIF($1, '')::uuid
|
|
)
|
|
OR (
|
|
NULLIF($2, '') IS NOT NULL
|
|
AND v.user_id = $2
|
|
)
|
|
)
|
|
AND (
|
|
(NULLIF($3, '')::uuid IS NOT NULL AND v.id = NULLIF($3, '')::uuid)
|
|
OR (NULLIF($4, '') IS NOT NULL AND v.voice_id = $4)
|
|
)
|
|
AND v.status <> 'deleted'
|
|
RETURNING *
|
|
)
|
|
SELECT `+clonedVoiceColumns+`
|
|
FROM updated v
|
|
LEFT JOIN integration_platforms p ON p.id = v.platform_id`, gatewayUserID, userID, clonedVoiceID, voiceID))
|
|
if err != nil {
|
|
if IsNotFound(err) {
|
|
return ClonedVoice{}, false, nil
|
|
}
|
|
return ClonedVoice{}, false, err
|
|
}
|
|
return item, true, nil
|
|
}
|
|
|
|
func (s *Store) TouchClonedVoiceUsage(ctx context.Context, clonedVoiceID string) error {
|
|
if strings.TrimSpace(clonedVoiceID) == "" {
|
|
return nil
|
|
}
|
|
_, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_cloned_voices
|
|
SET last_used_at = now(), expires_at = now() + interval '7 days', updated_at = now()
|
|
WHERE id = $1::uuid`, clonedVoiceID)
|
|
return err
|
|
}
|
|
|
|
func (s *Store) MarkClonedVoiceStatus(ctx context.Context, clonedVoiceID string, status string) error {
|
|
if strings.TrimSpace(clonedVoiceID) == "" || strings.TrimSpace(status) == "" {
|
|
return nil
|
|
}
|
|
_, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_cloned_voices
|
|
SET status = $2, updated_at = now()
|
|
WHERE id = $1::uuid`, clonedVoiceID, status)
|
|
return err
|
|
}
|
|
|
|
func clonedVoiceUserKeys(user *auth.User) (string, string) {
|
|
if user == nil {
|
|
return "", ""
|
|
}
|
|
gatewayUserID := strings.TrimSpace(user.GatewayUserID)
|
|
if gatewayUserID == "" && user.Source == "gateway" {
|
|
gatewayUserID = strings.TrimSpace(user.ID)
|
|
}
|
|
userID := strings.TrimSpace(user.ID)
|
|
return gatewayUserID, userID
|
|
}
|
|
|
|
// clonedVoiceScanner is the minimal row interface scanClonedVoice needs.
// This file passes it both the single-row result of s.pool.QueryRow and the
// rows value returned by s.pool.Query.
type clonedVoiceScanner interface {
	Scan(dest ...any) error
}
|
|
|
|
func scanClonedVoice(scanner clonedVoiceScanner) (ClonedVoice, error) {
|
|
var item ClonedVoice
|
|
var metadata []byte
|
|
if err := scanner.Scan(
|
|
&item.ID,
|
|
&item.GatewayUserID,
|
|
&item.UserID,
|
|
&item.GatewayTenantID,
|
|
&item.TenantID,
|
|
&item.TenantKey,
|
|
&item.Provider,
|
|
&item.PlatformID,
|
|
&item.PlatformName,
|
|
&item.PlatformModelID,
|
|
&item.Model,
|
|
&item.PreviewModel,
|
|
&item.VoiceID,
|
|
&item.DisplayName,
|
|
&item.DemoAudioURL,
|
|
&item.Status,
|
|
&item.ExpiresAt,
|
|
&item.LastUsedAt,
|
|
&metadata,
|
|
&item.CreatedAt,
|
|
&item.UpdatedAt,
|
|
); err != nil {
|
|
return ClonedVoice{}, err
|
|
}
|
|
item.Metadata = decodeObject(metadata)
|
|
return item, nil
|
|
}
|