- 在 API 接口定义中为 video_url 和 audio_url 类型添加 mime_type 字段 - 实现 Google Gemini 客户端对视频和音频内容的支持,包括媒体类型检测和数据传输 - 添加 Gemini 客户端测试用例验证多媒体内容转换功能 - 重构 Playground 页面的媒体上传逻辑以支持 MIME 类型传递 - 实现钱包计费预留机制,确保任务执行前余额充足 - 添加钱包冻结余额管理,防止并发操作导致的超扣问题 - 实现计费预留释放逻辑,处理任务失败或取消情况下的资金返还 - 优化数据库事务处理,确保计费操作的原子性和一致性 - 添加数据库集成测试验证迁移脚本执行流程 - 统一 Google Gemini 相关模型提供商标识符映射
172 lines
5.8 KiB
Go
172 lines
5.8 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
)
|
|
|
|
func TestReserveTaskBillingSerializesConcurrentWalletReservations(t *testing.T) {
|
|
databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL"))
|
|
if databaseURL == "" {
|
|
t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet reservation integration test")
|
|
}
|
|
|
|
ctx := context.Background()
|
|
db, err := Connect(ctx, databaseURL)
|
|
if err != nil {
|
|
t.Fatalf("connect store: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
tenantID, userID := seedWalletReservationUser(t, ctx, db)
|
|
if _, err := db.SetUserWalletBalance(ctx, WalletBalanceAdjustmentInput{
|
|
GatewayUserID: userID,
|
|
Currency: "resource",
|
|
Balance: 10,
|
|
Reason: "seed wallet reservation test",
|
|
}); err != nil {
|
|
t.Fatalf("seed wallet balance: %v", err)
|
|
}
|
|
|
|
firstTaskID := newWalletReservationTestUUID(t, ctx, db)
|
|
secondTaskID := newWalletReservationTestUUID(t, ctx, db)
|
|
billings := []any{map[string]any{"currency": "resource", "amount": float64(10)}}
|
|
user := &auth.User{GatewayUserID: userID, GatewayTenantID: tenantID}
|
|
tasks := []GatewayTask{
|
|
{ID: firstTaskID, GatewayUserID: userID, GatewayTenantID: tenantID, Kind: "images.generations", Model: "mock-image"},
|
|
{ID: secondTaskID, GatewayUserID: userID, GatewayTenantID: tenantID, Kind: "videos.generations", Model: "mock-video"},
|
|
}
|
|
|
|
type reserveResult struct {
|
|
reservations []WalletBillingReservation
|
|
err error
|
|
}
|
|
results := make(chan reserveResult, len(tasks))
|
|
var wg sync.WaitGroup
|
|
for _, task := range tasks {
|
|
task := task
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
reservations, err := db.ReserveTaskBilling(ctx, task, user, billings)
|
|
results <- reserveResult{reservations: reservations, err: err}
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
close(results)
|
|
|
|
var successReservations []WalletBillingReservation
|
|
successCount := 0
|
|
insufficientCount := 0
|
|
for result := range results {
|
|
if result.err == nil {
|
|
successCount++
|
|
successReservations = result.reservations
|
|
continue
|
|
}
|
|
if errors.Is(result.err, ErrInsufficientWalletBalance) {
|
|
insufficientCount++
|
|
continue
|
|
}
|
|
t.Fatalf("unexpected reservation error: %v", result.err)
|
|
}
|
|
if successCount != 1 || insufficientCount != 1 {
|
|
t.Fatalf("expected one successful reservation and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount)
|
|
}
|
|
if len(successReservations) != 1 || !walletFloatNear(successReservations[0].Amount, 10) {
|
|
t.Fatalf("unexpected successful reservations: %+v", successReservations)
|
|
}
|
|
|
|
balance, frozen, spent := readWalletReservationAccount(t, ctx, db, userID)
|
|
if !walletFloatNear(balance, 10) || !walletFloatNear(frozen, 10) || !walletFloatNear(spent, 0) {
|
|
t.Fatalf("reservation should freeze balance without spending it, balance=%f frozen=%f spent=%f", balance, frozen, spent)
|
|
}
|
|
|
|
settleTask := GatewayTask{
|
|
ID: successReservations[0].TaskID,
|
|
GatewayUserID: userID,
|
|
GatewayTenantID: tenantID,
|
|
Kind: "images.generations",
|
|
Model: "mock-image",
|
|
ResolvedModel: "mock-image",
|
|
Billings: billings,
|
|
BillingSummary: map[string]any{"currency": "resource", "totalAmount": float64(10)},
|
|
FinalChargeAmount: 10,
|
|
}
|
|
if err := db.SettleTaskBilling(ctx, settleTask); err != nil {
|
|
t.Fatalf("settle reserved task billing: %v", err)
|
|
}
|
|
if err := db.SettleTaskBilling(ctx, settleTask); err != nil {
|
|
t.Fatalf("settle reserved task billing should be idempotent: %v", err)
|
|
}
|
|
balance, frozen, spent = readWalletReservationAccount(t, ctx, db, userID)
|
|
if !walletFloatNear(balance, 0) || !walletFloatNear(frozen, 0) || !walletFloatNear(spent, 10) {
|
|
t.Fatalf("settlement should release reservation and debit once, balance=%f frozen=%f spent=%f", balance, frozen, spent)
|
|
}
|
|
}
|
|
|
|
func seedWalletReservationUser(t *testing.T, ctx context.Context, db *Store) (string, string) {
|
|
t.Helper()
|
|
suffix := strconv.FormatInt(time.Now().UnixNano(), 10)
|
|
var tenantID string
|
|
if err := db.pool.QueryRow(ctx, `
|
|
INSERT INTO gateway_tenants (tenant_key, name)
|
|
VALUES ($1, $2)
|
|
RETURNING id::text`, "wallet-reservation-"+suffix, "Wallet Reservation Test "+suffix).Scan(&tenantID); err != nil {
|
|
t.Fatalf("insert test tenant: %v", err)
|
|
}
|
|
var userID string
|
|
if err := db.pool.QueryRow(ctx, `
|
|
INSERT INTO gateway_users (user_key, username, gateway_tenant_id, tenant_key, roles)
|
|
VALUES ($1, $2, $3::uuid, $4, '["basic"]'::jsonb)
|
|
RETURNING id::text`, "wallet-reservation-user-"+suffix, "wallet_reservation_"+suffix, tenantID, "wallet-reservation-"+suffix).Scan(&userID); err != nil {
|
|
t.Fatalf("insert test user: %v", err)
|
|
}
|
|
t.Cleanup(func() {
|
|
cleanupCtx := context.Background()
|
|
_, _ = db.pool.Exec(cleanupCtx, `DELETE FROM gateway_users WHERE id = $1::uuid`, userID)
|
|
_, _ = db.pool.Exec(cleanupCtx, `DELETE FROM gateway_tenants WHERE id = $1::uuid`, tenantID)
|
|
})
|
|
return tenantID, userID
|
|
}
|
|
|
|
func newWalletReservationTestUUID(t *testing.T, ctx context.Context, db *Store) string {
|
|
t.Helper()
|
|
var id string
|
|
if err := db.pool.QueryRow(ctx, `SELECT gen_random_uuid()::text`).Scan(&id); err != nil {
|
|
t.Fatalf("generate uuid: %v", err)
|
|
}
|
|
return id
|
|
}
|
|
|
|
func readWalletReservationAccount(t *testing.T, ctx context.Context, db *Store, userID string) (float64, float64, float64) {
|
|
t.Helper()
|
|
var balance float64
|
|
var frozen float64
|
|
var spent float64
|
|
if err := db.pool.QueryRow(ctx, `
|
|
SELECT balance::float8, frozen_balance::float8, total_spent::float8
|
|
FROM gateway_wallet_accounts
|
|
WHERE gateway_user_id = $1::uuid
|
|
AND currency = 'resource'`, userID).Scan(&balance, &frozen, &spent); err != nil {
|
|
t.Fatalf("read wallet account: %v", err)
|
|
}
|
|
return balance, frozen, spent
|
|
}
|
|
|
|
// walletFloatNear reports whether a and b differ by strictly less than 1e-6.
// It returns false if either argument is NaN.
func walletFloatNear(a float64, b float64) bool {
	const tolerance = 0.000001
	diff := a - b
	// Equivalent to |diff| < tolerance without a separate sign flip.
	return diff > -tolerance && diff < tolerance
}
|