easyai-ai-gateway/apps/api/internal/store/wallet.go

563 lines
20 KiB
Go

package store
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/jackc/pgx/v5"
)
// GatewayWalletAccount is one row of gateway_wallet_accounts: a wallet owned by
// a single gateway user in a single currency. Accounts are keyed by
// (gateway_user_id, currency), the ON CONFLICT target used when they are created.
type GatewayWalletAccount struct {
ID string `json:"id"`
GatewayTenantID string `json:"gatewayTenantId,omitempty"`
GatewayUserID string `json:"gatewayUserId"`
// Tenant/user identifiers copied from gateway_users when the account is created
// or back-filled; empty when the source row has none.
TenantID string `json:"tenantId,omitempty"`
TenantKey string `json:"tenantKey,omitempty"`
UserID string `json:"userId,omitempty"`
// Currency code; an empty request currency is normalized to "resource".
Currency string `json:"currency"`
// Balance is the stored balance; WalletAvailability treats
// Balance - FrozenBalance as the spendable amount.
Balance float64 `json:"balance"`
FrozenBalance float64 `json:"frozenBalance"`
// TotalRecharged accumulates credits (admin balance increases add to it).
TotalRecharged float64 `json:"totalRecharged"`
// TotalSpent is read from total_spent; nothing in this file updates it.
TotalSpent float64 `json:"totalSpent"`
// Status of the account, e.g. "active".
Status string `json:"status"`
// Metadata is the decoded JSON metadata column.
Metadata map[string]any `json:"metadata,omitempty"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// GatewayWalletTransaction is one ledger entry of gateway_wallet_transactions,
// recording a balance movement on a wallet account.
type GatewayWalletTransaction struct {
ID string `json:"id"`
AccountID string `json:"accountId"`
// Currency of the owning account; only populated by ListWalletTransactions,
// which joins the account row.
Currency string `json:"currency,omitempty"`
GatewayTenantID string `json:"gatewayTenantId,omitempty"`
GatewayUserID string `json:"gatewayUserId,omitempty"`
// Direction is "credit" or "debit" for entries written by this file.
Direction string `json:"direction"`
// TransactionType classifies the entry, e.g. "admin_adjust".
TransactionType string `json:"transactionType"`
// Amount is the magnitude of the movement (admin adjustments store the
// absolute delta; the sign is carried by Direction).
Amount float64 `json:"amount"`
BalanceBefore float64 `json:"balanceBefore"`
BalanceAfter float64 `json:"balanceAfter"`
IdempotencyKey string `json:"idempotencyKey,omitempty"`
// ReferenceType/ReferenceID link the entry to its cause, e.g.
// ("gateway_task", task id) or ("gateway_user", user id).
ReferenceType string `json:"referenceType,omitempty"`
ReferenceID string `json:"referenceId,omitempty"`
// Metadata is the stored JSON metadata; in listings it is enriched with
// task, routing and billing details of the referenced gateway task.
Metadata map[string]any `json:"metadata,omitempty"`
CreatedAt time.Time `json:"createdAt"`
}
// WalletAvailability is the outcome of a balance-sufficiency check.
type WalletAvailability struct {
// Account is the checked wallet (zero value when the caller has no local user).
Account GatewayWalletAccount `json:"account"`
Currency string `json:"currency"`
// RequiredAmount is the requested amount (rounded when an account was checked).
RequiredAmount float64 `json:"requiredAmount"`
// AvailableAmount is Balance minus FrozenBalance, rounded.
AvailableAmount float64 `json:"availableAmount"`
// Enough reports whether AvailableAmount covers RequiredAmount.
Enough bool `json:"enough"`
}
// WalletSummary lists a user's wallet accounts.
type WalletSummary struct {
// Accounts is never empty; the requested currency's account comes first.
Accounts []GatewayWalletAccount `json:"accounts"`
// PrimaryAccount is the first element of Accounts.
PrimaryAccount GatewayWalletAccount `json:"primaryAccount"`
}
// WalletTransactionListFilter narrows and pages ListWalletTransactions.
type WalletTransactionListFilter struct {
// Query is a case-insensitive substring matched against transaction, task,
// attempt, platform and model fields; blank disables it.
Query string
// Direction and TransactionType are exact matches; blank disables them.
Direction string
TransactionType string
// CreatedFrom/CreatedTo are inclusive bounds on created_at; nil disables them.
CreatedFrom *time.Time
CreatedTo *time.Time
// Page is 1-based (defaults to 1); PageSize defaults to 50, capped at 100.
Page int
PageSize int
}
// WalletTransactionListResult is one page of wallet transactions.
type WalletTransactionListResult struct {
Items []GatewayWalletTransaction
// Total counts all matching transactions, not just this page.
Total int
// Page and PageSize are the effective values after defaulting/clamping.
Page int
PageSize int
}
// WalletBalanceAdjustmentInput describes an administrative balance override.
type WalletBalanceAdjustmentInput struct {
GatewayUserID string `json:"gatewayUserId"`
// Currency selects the account; blank means "resource".
Currency string `json:"currency"`
// Balance is the absolute target balance (not a delta); must be >= 0.
Balance float64 `json:"balance"`
// Reason is stored in the transaction metadata; a default is used when blank.
Reason string `json:"reason"`
// IdempotencyKey is stored on the ledger entry (NULL when blank).
IdempotencyKey string `json:"idempotencyKey"`
// Metadata is merged into the ledger entry's metadata.
Metadata map[string]any `json:"metadata"`
}
// WalletAdjustmentResult reports an applied balance adjustment.
type WalletAdjustmentResult struct {
// Account is the account after the adjustment.
Account GatewayWalletAccount `json:"account"`
// Before is the locked snapshot taken before the adjustment.
Before GatewayWalletAccount `json:"before"`
// Transaction is the ledger entry recording the adjustment.
Transaction GatewayWalletTransaction `json:"transaction"`
}
// WalletAvailability checks whether the caller's wallet in currency can cover
// requiredAmount, where the spendable amount is Balance - FrozenBalance.
//
// A caller without a local gateway user is treated as having enough balance:
// the result has Enough=true and no error is returned. The wallet account is
// created on demand. When the balance is insufficient, the fully populated
// result is returned together with an error wrapping
// ErrInsufficientWalletBalance.
func (s *Store) WalletAvailability(ctx context.Context, user *auth.User, currency string, requiredAmount float64) (WalletAvailability, error) {
	userID := localGatewayUserID(user)
	if userID == "" {
		return WalletAvailability{
			Currency:       normalizeWalletCurrency(currency),
			RequiredAmount: requiredAmount,
			Enough:         true,
		}, nil
	}

	acct, err := s.ensureWalletAccount(ctx, s.pool, userID, currency)
	if err != nil {
		return WalletAvailability{}, err
	}

	spendable := roundMoney(acct.Balance - acct.FrozenBalance)
	// A tiny tolerance keeps floating-point noise from rejecting an exact match.
	enough := spendable+0.000001 >= requiredAmount
	out := WalletAvailability{
		Account:         acct,
		Currency:        acct.Currency,
		RequiredAmount:  roundMoney(requiredAmount),
		AvailableAmount: spendable,
		Enough:          enough,
	}
	if enough {
		return out, nil
	}
	return out, fmt.Errorf("%w: required %.6f %s, available %.6f", ErrInsufficientWalletBalance, requiredAmount, acct.Currency, spendable)
}
// GetWalletSummary returns every wallet account owned by the caller.
//
// The account for the requested currency is created on demand and listed
// first. The "resource" account comes next, then the rest by currency code.
// PrimaryAccount is the first listed account.
//
// A caller without a local gateway user receives a single synthetic "active"
// account in the normalized currency. The returned Accounts slice is never
// empty.
func (s *Store) GetWalletSummary(ctx context.Context, user *auth.User, currency string) (WalletSummary, error) {
	userID := localGatewayUserID(user)
	if userID == "" {
		placeholder := GatewayWalletAccount{Currency: normalizeWalletCurrency(currency), Status: "active"}
		return WalletSummary{Accounts: []GatewayWalletAccount{placeholder}, PrimaryAccount: placeholder}, nil
	}

	primary, err := s.ensureWalletAccount(ctx, s.pool, userID, currency)
	if err != nil {
		return WalletSummary{}, err
	}

	const listAccountsSQL = `
SELECT id::text, COALESCE(gateway_tenant_id::text, ''), gateway_user_id::text,
COALESCE(tenant_id, ''), COALESCE(tenant_key, ''), COALESCE(user_id, ''),
currency, balance::float8, frozen_balance::float8, total_recharged::float8,
total_spent::float8, status, metadata, created_at, updated_at
FROM gateway_wallet_accounts
WHERE gateway_user_id = $1::uuid
ORDER BY CASE WHEN currency = $2 THEN 0 WHEN currency = 'resource' THEN 1 ELSE 2 END, currency ASC`
	rows, err := s.pool.Query(ctx, listAccountsSQL, userID, primary.Currency)
	if err != nil {
		return WalletSummary{}, err
	}
	defer rows.Close()

	var accounts []GatewayWalletAccount
	for rows.Next() {
		acct, scanErr := scanWalletAccount(rows)
		if scanErr != nil {
			return WalletSummary{}, scanErr
		}
		accounts = append(accounts, acct)
	}
	if err := rows.Err(); err != nil {
		return WalletSummary{}, err
	}

	// Fall back to the just-ensured account so the summary is never empty.
	if len(accounts) == 0 {
		return WalletSummary{Accounts: []GatewayWalletAccount{primary}, PrimaryAccount: primary}, nil
	}
	return WalletSummary{Accounts: accounts, PrimaryAccount: accounts[0]}, nil
}
// ListWalletTransactions returns one page of the caller's wallet ledger
// entries, newest first, plus the total number of matching entries.
//
// Entries whose reference_type is 'gateway_task' are joined to that task, its
// latest attempt, and the attempt's platform and platform model. These joins
// serve two purposes:
//   - filter.Query can match on the joined task, routing and model fields;
//   - the returned Metadata is enriched with task, routing, usage and billing
//     details.
//
// Filters and paging:
//   - Page defaults to 1. PageSize defaults to 50 and is capped at 100.
//   - filter.Query is a case-insensitive substring match. LIKE metacharacters
//     in it (%, _ and the escape character \) match literally.
//   - Direction and TransactionType are exact matches.
//   - CreatedFrom and CreatedTo are inclusive bounds.
//
// A caller without a local gateway user gets an empty page.
func (s *Store) ListWalletTransactions(ctx context.Context, user *auth.User, filter WalletTransactionListFilter) (WalletTransactionListResult, error) {
	page := filter.Page
	if page <= 0 {
		page = 1
	}
	pageSize := filter.PageSize
	if pageSize <= 0 {
		pageSize = 50
	}
	if pageSize > 100 {
		pageSize = 100
	}
	gatewayUserID := localGatewayUserID(user)
	if gatewayUserID == "" {
		return WalletTransactionListResult{Items: []GatewayWalletTransaction{}, Page: page, PageSize: pageSize}, nil
	}
	// An empty pattern disables the free-text filter ($2 in whereSQL).
	queryPattern := ""
	if query := strings.TrimSpace(filter.Query); query != "" {
		// Escape LIKE metacharacters so "%" and "_" typed by the user match
		// literally instead of acting as wildcards. Backslash is PostgreSQL's
		// default LIKE/ILIKE escape character, so it must be escaped as well.
		likeEscaper := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`)
		queryPattern = "%" + likeEscaper.Replace(query) + "%"
	}
	args := []any{
		gatewayUserID,
		queryPattern,
		strings.TrimSpace(filter.Direction),
		strings.TrimSpace(filter.TransactionType),
		nullableTaskListTime(filter.CreatedFrom),
		nullableTaskListTime(filter.CreatedTo),
	}
	// Shared by the count and page queries; both must join task, attempt,
	// platform and platform_model because the search references them.
	whereSQL := `
WHERE a.gateway_user_id = $1::uuid
AND (
NULLIF($2, '') IS NULL
OR t.id::text ILIKE $2
OR COALESCE(t.reference_id, '') ILIKE $2
OR COALESCE(t.reference_type, '') ILIKE $2
OR COALESCE(t.idempotency_key, '') ILIKE $2
OR t.transaction_type ILIKE $2
OR t.direction ILIKE $2
OR COALESCE(task.id::text, '') ILIKE $2
OR COALESCE(task.request_id, '') ILIKE $2
OR COALESCE(task.kind, '') ILIKE $2
OR COALESCE(task.model, '') ILIKE $2
OR COALESCE(task.requested_model, '') ILIKE $2
OR COALESCE(task.resolved_model, '') ILIKE $2
OR COALESCE(task.model_type, '') ILIKE $2
OR COALESCE(task.api_key_id, '') ILIKE $2
OR COALESCE(task.api_key_name, '') ILIKE $2
OR COALESCE(task.api_key_prefix, '') ILIKE $2
OR COALESCE(task.status, '') ILIKE $2
OR COALESCE(task.billing_summary->>'currency', '') ILIKE $2
OR COALESCE(task.billing_summary->>'totalAmount', '') ILIKE $2
OR COALESCE(attempt.client_id, '') ILIKE $2
OR COALESCE(attempt.request_id, '') ILIKE $2
OR COALESCE(platform.provider, '') ILIKE $2
OR COALESCE(platform.platform_key, '') ILIKE $2
OR COALESCE(platform.name, '') ILIKE $2
OR COALESCE(platform_model.model_name, '') ILIKE $2
OR COALESCE(platform_model.provider_model_name, '') ILIKE $2
OR COALESCE(platform_model.model_alias, '') ILIKE $2
OR COALESCE(platform_model.display_name, '') ILIKE $2
OR COALESCE(task.metrics->>'provider', '') ILIKE $2
OR COALESCE(task.metrics->>'platformName', '') ILIKE $2
OR COALESCE(task.metrics->>'modelAlias', '') ILIKE $2
OR COALESCE(task.metrics->>'providerModel', '') ILIKE $2
)
AND (NULLIF($3, '') IS NULL OR t.direction = $3)
AND (NULLIF($4, '') IS NULL OR t.transaction_type = $4)
AND ($5::timestamptz IS NULL OR t.created_at >= $5::timestamptz)
AND ($6::timestamptz IS NULL OR t.created_at <= $6::timestamptz)`
	var total int
	if err := s.pool.QueryRow(ctx, `
SELECT count(*)
FROM gateway_wallet_transactions t
JOIN gateway_wallet_accounts a ON a.id = t.account_id
LEFT JOIN gateway_tasks task ON t.reference_type = 'gateway_task' AND t.reference_id = task.id::text
LEFT JOIN LATERAL (
SELECT platform_id, platform_model_id, client_id, request_id
FROM gateway_task_attempts
WHERE task_id = task.id
ORDER BY attempt_no DESC, started_at DESC
LIMIT 1
) attempt ON true
LEFT JOIN integration_platforms platform ON platform.id = attempt.platform_id
LEFT JOIN platform_models platform_model ON platform_model.id = attempt.platform_model_id
`+whereSQL, args...).Scan(&total); err != nil {
		return WalletTransactionListResult{}, err
	}
	offset := (page - 1) * pageSize
	// $7 = LIMIT, $8 = OFFSET. args has len == cap, so append copies it.
	queryArgs := append(args, pageSize, offset)
	rows, err := s.pool.Query(ctx, `
SELECT t.id::text, t.account_id::text, a.currency, COALESCE(t.gateway_tenant_id::text, ''),
COALESCE(t.gateway_user_id::text, ''), t.direction, t.transaction_type,
t.amount::float8, t.balance_before::float8, t.balance_after::float8,
COALESCE(t.idempotency_key, ''), COALESCE(t.reference_type, ''),
COALESCE(t.reference_id, ''),
t.metadata || jsonb_strip_nulls(jsonb_build_object(
'taskId', task.id::text,
'kind', task.kind,
'model', task.model,
'requestedModel', task.requested_model,
'resolvedModel', task.resolved_model,
'modelType', task.model_type,
'taskStatus', task.status,
'runMode', task.run_mode,
'requestId', COALESCE(task.request_id, attempt.request_id),
'apiKeyId', task.api_key_id,
'apiKeyName', task.api_key_name,
'apiKeyPrefix', task.api_key_prefix,
'provider', COALESCE(platform.provider, task.metrics->>'provider'),
'platformId', COALESCE(platform.id::text, task.metrics->>'platformId'),
'platformName', COALESCE(platform.name, task.metrics->>'platformName'),
'platformKey', platform.platform_key,
'platformModelId', COALESCE(platform_model.id::text, task.metrics->>'platformModelId'),
'platformModelName', platform_model.model_name,
'platformModelAlias', platform_model.model_alias,
'providerModel', COALESCE(platform_model.provider_model_name, task.metrics->>'providerModel'),
'clientId', attempt.client_id,
'usage', CASE WHEN task.id IS NULL THEN NULL ELSE COALESCE(task.usage, attempt.usage, '{}'::jsonb) END,
'billings', CASE WHEN task.id IS NULL THEN NULL ELSE COALESCE(task.billings, '[]'::jsonb) END,
'billingSummary', CASE WHEN task.id IS NULL THEN NULL ELSE COALESCE(task.billing_summary, '{}'::jsonb) END,
'finalChargeAmount', CASE WHEN task.id IS NULL THEN NULL ELSE COALESCE(task.final_charge_amount, 0)::float8 END,
'responseStartedAt', COALESCE(task.response_started_at::text, attempt.response_started_at::text),
'responseFinishedAt', COALESCE(task.response_finished_at::text, attempt.response_finished_at::text),
'responseDurationMs', COALESCE(task.response_duration_ms, attempt.response_duration_ms)
)), t.created_at
FROM gateway_wallet_transactions t
JOIN gateway_wallet_accounts a ON a.id = t.account_id
LEFT JOIN gateway_tasks task ON t.reference_type = 'gateway_task' AND t.reference_id = task.id::text
LEFT JOIN LATERAL (
SELECT platform_id, platform_model_id, client_id, request_id, usage, response_started_at,
response_finished_at, response_duration_ms
FROM gateway_task_attempts
WHERE task_id = task.id
ORDER BY attempt_no DESC, started_at DESC
LIMIT 1
) attempt ON true
LEFT JOIN integration_platforms platform ON platform.id = attempt.platform_id
LEFT JOIN platform_models platform_model ON platform_model.id = attempt.platform_model_id
`+whereSQL+`
ORDER BY t.created_at DESC, t.id DESC
LIMIT $7 OFFSET $8`, queryArgs...)
	if err != nil {
		return WalletTransactionListResult{}, err
	}
	defer rows.Close()
	items := make([]GatewayWalletTransaction, 0)
	for rows.Next() {
		item, err := scanWalletTransactionWithCurrency(rows)
		if err != nil {
			return WalletTransactionListResult{}, err
		}
		items = append(items, item)
	}
	if err := rows.Err(); err != nil {
		return WalletTransactionListResult{}, err
	}
	return WalletTransactionListResult{Items: items, Total: total, Page: page, PageSize: pageSize}, nil
}
// SetUserWalletBalance applies SetUserWalletBalanceTx inside its own
// transaction opened via s.InTx.
//
// On any error the zero WalletAdjustmentResult is returned. This includes an
// error reported by InTx after the callback succeeded (presumably a failed
// commit — confirm against InTx). Callers therefore never receive the result
// of an adjustment that was not persisted.
func (s *Store) SetUserWalletBalance(ctx context.Context, input WalletBalanceAdjustmentInput) (WalletAdjustmentResult, error) {
	var result WalletAdjustmentResult
	if err := s.InTx(ctx, func(tx Tx) error {
		next, err := s.SetUserWalletBalanceTx(ctx, tx, input)
		if err != nil {
			return err
		}
		result = next
		return nil
	}); err != nil {
		return WalletAdjustmentResult{}, err
	}
	return result, nil
}
// SetUserWalletBalanceTx sets a user's wallet balance to an absolute value
// within the caller's transaction. The difference is recorded as an
// "admin_adjust" ledger entry referencing the gateway user.
//
// Steps:
//   - The account is created on demand and locked with SELECT ... FOR UPDATE
//     before it is read, so concurrent adjustments of one account serialize.
//   - A positive delta is written as a "credit" and also added to
//     total_recharged. A negative delta is written as a "debit". The ledger
//     Amount is always the absolute delta.
//
// Errors:
//   - ErrLocalUserRequired when GatewayUserID is blank.
//   - An error when Balance is negative, NaN or infinite.
//   - ErrWalletBalanceUnchanged when the rounded target equals the current
//     balance.
//   - Any database or metadata-encoding error.
func (s *Store) SetUserWalletBalanceTx(ctx context.Context, tx Tx, input WalletBalanceAdjustmentInput) (WalletAdjustmentResult, error) {
	input.GatewayUserID = strings.TrimSpace(input.GatewayUserID)
	if input.GatewayUserID == "" {
		return WalletAdjustmentResult{}, ErrLocalUserRequired
	}
	if input.Balance < 0 {
		return WalletAdjustmentResult{}, fmt.Errorf("wallet balance cannot be negative")
	}
	// x*0 is 0 for every finite float64 but NaN for NaN and +Inf. NaN slips
	// past the "< 0" check above and PostgreSQL numeric can store NaN, so
	// reject non-finite targets explicitly.
	if input.Balance*0 != 0 {
		return WalletAdjustmentResult{}, fmt.Errorf("wallet balance must be a finite number")
	}
	account, err := s.ensureWalletAccount(ctx, tx, input.GatewayUserID, input.Currency)
	if err != nil {
		return WalletAdjustmentResult{}, err
	}
	// Lock the row so the read-modify-write below cannot interleave with
	// another adjustment of the same account.
	locked, err := scanWalletAccount(tx.QueryRow(ctx, `
SELECT id::text, COALESCE(gateway_tenant_id::text, ''), gateway_user_id::text,
COALESCE(tenant_id, ''), COALESCE(tenant_key, ''), COALESCE(user_id, ''),
currency, balance::float8, frozen_balance::float8, total_recharged::float8,
total_spent::float8, status, metadata, created_at, updated_at
FROM gateway_wallet_accounts
WHERE id = $1::uuid
FOR UPDATE`, account.ID))
	if err != nil {
		return WalletAdjustmentResult{}, err
	}
	before := locked
	nextBalance := roundMoney(input.Balance)
	delta := roundMoney(nextBalance - locked.Balance)
	if delta == 0 {
		return WalletAdjustmentResult{}, ErrWalletBalanceUnchanged
	}
	direction := "credit"
	amount := delta
	if delta < 0 {
		direction = "debit"
		amount = -delta
	}
	reason := strings.TrimSpace(input.Reason)
	if reason == "" {
		reason = "后台余额调整"
	}
	metadata := mergeObjects(input.Metadata, map[string]any{
		"reason":          reason,
		"previousBalance": roundMoney(before.Balance),
		"targetBalance":   nextBalance,
	})
	// Caller-supplied metadata may hold values encoding/json cannot encode.
	// Fail clearly instead of sending an empty string as jsonb.
	metadataJSON, err := json.Marshal(emptyObjectIfNil(metadata))
	if err != nil {
		return WalletAdjustmentResult{}, fmt.Errorf("encode wallet adjustment metadata: %w", err)
	}
	// $4 is cast explicitly. Otherwise PostgreSQL resolves the untyped
	// parameter from the integer literal in the ELSE branch, so fractional
	// credits would not be added to total_recharged exactly. RETURNING hands
	// back the row as persisted, including the database-side updated_at.
	updated, err := scanWalletAccount(tx.QueryRow(ctx, `
UPDATE gateway_wallet_accounts
SET balance = $2,
total_recharged = total_recharged + CASE WHEN $3 = 'credit' THEN $4::numeric ELSE 0 END,
updated_at = now()
WHERE id = $1::uuid
RETURNING id::text, COALESCE(gateway_tenant_id::text, ''), gateway_user_id::text,
COALESCE(tenant_id, ''), COALESCE(tenant_key, ''), COALESCE(user_id, ''),
currency, balance::float8, frozen_balance::float8, total_recharged::float8,
total_spent::float8, status, metadata, created_at, updated_at`,
		locked.ID,
		nextBalance,
		direction,
		amount,
	))
	if err != nil {
		return WalletAdjustmentResult{}, err
	}
	transaction, err := scanWalletTransaction(tx.QueryRow(ctx, `
INSERT INTO gateway_wallet_transactions (
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
)
VALUES (
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, $4, 'admin_adjust',
$5, $6, $7, NULLIF($8, ''), 'gateway_user', $9, $10::jsonb
)
RETURNING id::text, account_id::text, COALESCE(gateway_tenant_id::text, ''), COALESCE(gateway_user_id::text, ''),
direction, transaction_type, amount::float8, balance_before::float8, balance_after::float8,
COALESCE(idempotency_key, ''), COALESCE(reference_type, ''), COALESCE(reference_id, ''),
metadata, created_at`,
		locked.ID,
		locked.GatewayTenantID,
		locked.GatewayUserID,
		direction,
		amount,
		roundMoney(before.Balance),
		nextBalance,
		strings.TrimSpace(input.IdempotencyKey),
		locked.GatewayUserID,
		string(metadataJSON),
	))
	if err != nil {
		return WalletAdjustmentResult{}, err
	}
	return WalletAdjustmentResult{Account: updated, Before: before, Transaction: transaction}, nil
}
// ensureWalletAccount returns the wallet account for (gatewayUserID, currency).
// If the account does not exist yet, it is created from the live gateway_users
// row.
//
// currency is normalized with normalizeWalletCurrency (blank becomes
// "resource").
//
// For an existing account, the upsert only back-fills tenant and user
// identifiers that are still missing. The conflict update runs only when
// gateway_users actually supplies a value for a missing field. Accounts of
// tenant-less users are therefore not rewritten, and their updated_at is not
// bumped, on every call.
//
// When gateway_users has no live (non-deleted) row for gatewayUserID, nothing
// is inserted and pgx.ErrNoRows is returned.
func (s *Store) ensureWalletAccount(ctx context.Context, q Tx, gatewayUserID string, currency string) (GatewayWalletAccount, error) {
	currency = normalizeWalletCurrency(currency)
	if _, err := q.Exec(ctx, `
INSERT INTO gateway_wallet_accounts (
gateway_tenant_id, gateway_user_id, tenant_id, tenant_key, user_id, currency
)
SELECT gateway_tenant_id, id, NULLIF(tenant_id, ''), NULLIF(tenant_key, ''),
COALESCE(NULLIF(external_user_id, ''), id::text), $2
FROM gateway_users
WHERE id = $1::uuid
AND deleted_at IS NULL
ON CONFLICT (gateway_user_id, currency) DO UPDATE
SET gateway_tenant_id = COALESCE(gateway_wallet_accounts.gateway_tenant_id, EXCLUDED.gateway_tenant_id),
tenant_id = COALESCE(NULLIF(gateway_wallet_accounts.tenant_id, ''), EXCLUDED.tenant_id),
tenant_key = COALESCE(NULLIF(gateway_wallet_accounts.tenant_key, ''), EXCLUDED.tenant_key),
user_id = COALESCE(NULLIF(gateway_wallet_accounts.user_id, ''), EXCLUDED.user_id),
updated_at = now()
WHERE (gateway_wallet_accounts.gateway_tenant_id IS NULL AND EXCLUDED.gateway_tenant_id IS NOT NULL)
OR (COALESCE(gateway_wallet_accounts.tenant_id, '') = '' AND EXCLUDED.tenant_id IS NOT NULL)
OR (COALESCE(gateway_wallet_accounts.tenant_key, '') = '' AND EXCLUDED.tenant_key IS NOT NULL)
OR (COALESCE(gateway_wallet_accounts.user_id, '') = '' AND EXCLUDED.user_id IS NOT NULL)`, gatewayUserID, currency); err != nil {
		return GatewayWalletAccount{}, err
	}
	account, err := scanWalletAccount(q.QueryRow(ctx, `
SELECT id::text, COALESCE(gateway_tenant_id::text, ''), gateway_user_id::text,
COALESCE(tenant_id, ''), COALESCE(tenant_key, ''), COALESCE(user_id, ''),
currency, balance::float8, frozen_balance::float8, total_recharged::float8,
total_spent::float8, status, metadata, created_at, updated_at
FROM gateway_wallet_accounts
WHERE gateway_user_id = $1::uuid
AND currency = $2`, gatewayUserID, currency))
	if err != nil {
		// No row: the user is missing or soft-deleted, so the INSERT above
		// selected nothing.
		if err == pgx.ErrNoRows {
			return GatewayWalletAccount{}, pgx.ErrNoRows
		}
		return GatewayWalletAccount{}, err
	}
	return account, nil
}
// normalizeWalletCurrency strips surrounding whitespace from currency. When
// nothing remains, it returns the default wallet currency "resource".
func normalizeWalletCurrency(currency string) string {
	if trimmed := strings.TrimSpace(currency); trimmed != "" {
		return trimmed
	}
	return "resource"
}
// scanWalletAccount reads one wallet-account row into a GatewayWalletAccount.
//
// Columns are expected in the order used by this file's SELECT/RETURNING
// lists: id, gateway_tenant_id, gateway_user_id, tenant_id, tenant_key,
// user_id, currency, balance, frozen_balance, total_recharged, total_spent,
// status, metadata, created_at, updated_at.
//
// The raw metadata bytes are decoded with decodeObject. On a scan error the
// zero value is returned with the error.
func scanWalletAccount(row scanner) (GatewayWalletAccount, error) {
	var (
		acct    GatewayWalletAccount
		rawMeta []byte
	)
	dest := []any{
		&acct.ID, &acct.GatewayTenantID, &acct.GatewayUserID,
		&acct.TenantID, &acct.TenantKey, &acct.UserID,
		&acct.Currency, &acct.Balance, &acct.FrozenBalance, &acct.TotalRecharged,
		&acct.TotalSpent, &acct.Status, &rawMeta, &acct.CreatedAt, &acct.UpdatedAt,
	}
	if err := row.Scan(dest...); err != nil {
		return GatewayWalletAccount{}, err
	}
	acct.Metadata = decodeObject(rawMeta)
	return acct, nil
}
// scanWalletTransaction reads one ledger row (without currency) into a
// GatewayWalletTransaction.
//
// Columns are expected in this order: id, account_id, gateway_tenant_id,
// gateway_user_id, direction, transaction_type, amount, balance_before,
// balance_after, idempotency_key, reference_type, reference_id, metadata,
// created_at. Currency is left empty. On a scan error the zero value is
// returned with the error.
func scanWalletTransaction(row scanner) (GatewayWalletTransaction, error) {
	var (
		txn     GatewayWalletTransaction
		rawMeta []byte
	)
	dest := []any{
		&txn.ID, &txn.AccountID, &txn.GatewayTenantID, &txn.GatewayUserID,
		&txn.Direction, &txn.TransactionType,
		&txn.Amount, &txn.BalanceBefore, &txn.BalanceAfter,
		&txn.IdempotencyKey, &txn.ReferenceType, &txn.ReferenceID,
		&rawMeta, &txn.CreatedAt,
	}
	if err := row.Scan(dest...); err != nil {
		return GatewayWalletTransaction{}, err
	}
	txn.Metadata = decodeObject(rawMeta)
	return txn, nil
}
// scanWalletTransactionWithCurrency reads one ledger row that also carries
// the owning account's currency as its third column.
//
// The remaining columns follow the scanWalletTransaction order: id,
// account_id, currency, gateway_tenant_id, gateway_user_id, direction,
// transaction_type, amount, balance_before, balance_after, idempotency_key,
// reference_type, reference_id, metadata, created_at. On a scan error the
// zero value is returned with the error.
func scanWalletTransactionWithCurrency(row scanner) (GatewayWalletTransaction, error) {
	var (
		txn     GatewayWalletTransaction
		rawMeta []byte
	)
	dest := []any{
		&txn.ID, &txn.AccountID, &txn.Currency,
		&txn.GatewayTenantID, &txn.GatewayUserID,
		&txn.Direction, &txn.TransactionType,
		&txn.Amount, &txn.BalanceBefore, &txn.BalanceAfter,
		&txn.IdempotencyKey, &txn.ReferenceType, &txn.ReferenceID,
		&rawMeta, &txn.CreatedAt,
	}
	if err := row.Scan(dest...); err != nil {
		return GatewayWalletTransaction{}, err
	}
	txn.Metadata = decodeObject(rawMeta)
	return txn, nil
}
// decodeWalletAccounts decodes a JSON array of wallet accounts.
//
// It is best-effort by design: empty input or malformed JSON yields nil
// rather than an error or a partially decoded slice.
func decodeWalletAccounts(data []byte) []GatewayWalletAccount {
	var accounts []GatewayWalletAccount
	if len(data) == 0 || json.Unmarshal(data, &accounts) != nil {
		return nil
	}
	return accounts
}