package store

import (
	"context"
	"errors"
	"os"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
)

// TestReserveTaskBillingSerializesConcurrentWalletReservations is an
// integration test that runs only when AI_GATEWAY_TEST_DATABASE_URL is set.
//
// It seeds a wallet with a "resource" balance of 10, then issues two
// concurrent ReserveTaskBilling calls that each ask for 10. Exactly one call
// must succeed and the other must fail with ErrInsufficientWalletBalance.
// Afterwards it settles the successful reservation twice and checks that the
// wallet is debited only once.
func TestReserveTaskBillingSerializesConcurrentWalletReservations(t *testing.T) {
	databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL"))
	if databaseURL == "" {
		t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet reservation integration test")
	}

	ctx := context.Background()
	db, err := Connect(ctx, databaseURL)
	if err != nil {
		t.Fatalf("connect store: %v", err)
	}
	// Close through t.Cleanup instead of defer. Deferred calls run when this
	// function returns, which is BEFORE any t.Cleanup callback, so a deferred
	// Close would shut the store down ahead of the row-deleting cleanups that
	// seedWalletReservationUser registers (presumably making them fail on a
	// closed pool). Cleanups run last-registered-first, so registering Close
	// first makes it run last.
	t.Cleanup(func() { db.Close() })

	tenantID, userID := seedWalletReservationUser(t, ctx, db)
	if _, err := db.SetUserWalletBalance(ctx, WalletBalanceAdjustmentInput{
		GatewayUserID: userID,
		Currency:      "resource",
		Balance:       10,
		Reason:        "seed wallet reservation test",
	}); err != nil {
		t.Fatalf("seed wallet balance: %v", err)
	}

	firstTaskID := newWalletReservationTestUUID(t, ctx, db)
	secondTaskID := newWalletReservationTestUUID(t, ctx, db)

	// Each task requests the full seeded balance, so only one can be reserved.
	billings := []any{map[string]any{"currency": "resource", "amount": float64(10)}}
	user := &auth.User{GatewayUserID: userID, GatewayTenantID: tenantID}
	tasks := []GatewayTask{
		{ID: firstTaskID, GatewayUserID: userID, GatewayTenantID: tenantID, Kind: "images.generations", Model: "mock-image"},
		{ID: secondTaskID, GatewayUserID: userID, GatewayTenantID: tenantID, Kind: "videos.generations", Model: "mock-video"},
	}

	type reserveResult struct {
		reservations []WalletBillingReservation
		err          error
	}

	// The channel is buffered to len(tasks) so no goroutine blocks on send
	// while the main goroutine is still waiting on the WaitGroup.
	results := make(chan reserveResult, len(tasks))
	var wg sync.WaitGroup
	for _, task := range tasks {
		task := task // per-iteration copy for toolchains older than Go 1.22
		wg.Add(1)
		go func() {
			defer wg.Done()
			reservations, err := db.ReserveTaskBilling(ctx, task, user, billings)
			results <- reserveResult{reservations: reservations, err: err}
		}()
	}
	wg.Wait()
	close(results)

	var successReservations []WalletBillingReservation
	successCount := 0
	insufficientCount := 0
	for result := range results {
		if result.err == nil {
			successCount++
			successReservations = result.reservations
			continue
		}
		if errors.Is(result.err, ErrInsufficientWalletBalance) {
			insufficientCount++
			continue
		}
		t.Fatalf("unexpected reservation error: %v", result.err)
	}
	if successCount != 1 || insufficientCount != 1 {
		t.Fatalf("expected one successful reservation and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount)
	}
	if len(successReservations) != 1 || !walletFloatNear(successReservations[0].Amount, 10) {
		t.Fatalf("unexpected successful reservations: %+v", successReservations)
	}

	// After reserving: the amount is frozen, but nothing is spent yet.
	balance, frozen, spent := readWalletReservationAccount(t, ctx, db, userID)
	if !walletFloatNear(balance, 10) || !walletFloatNear(frozen, 10) || !walletFloatNear(spent, 0) {
		t.Fatalf("reservation should freeze balance without spending it, balance=%f frozen=%f spent=%f", balance, frozen, spent)
	}

	settleTask := GatewayTask{
		ID:                successReservations[0].TaskID,
		GatewayUserID:     userID,
		GatewayTenantID:   tenantID,
		Kind:              "images.generations",
		Model:             "mock-image",
		ResolvedModel:     "mock-image",
		Billings:          billings,
		BillingSummary:    map[string]any{"currency": "resource", "totalAmount": float64(10)},
		FinalChargeAmount: 10,
	}
	if err := db.SettleTaskBilling(ctx, settleTask); err != nil {
		t.Fatalf("settle reserved task billing: %v", err)
	}
	// Settling the same task a second time must neither fail nor debit again.
	if err := db.SettleTaskBilling(ctx, settleTask); err != nil {
		t.Fatalf("settle reserved task billing should be idempotent: %v", err)
	}

	balance, frozen, spent = readWalletReservationAccount(t, ctx, db, userID)
	if !walletFloatNear(balance, 0) || !walletFloatNear(frozen, 0) || !walletFloatNear(spent, 10) {
		t.Fatalf("settlement should release reservation and debit once, balance=%f frozen=%f spent=%f", balance, frozen, spent)
	}
}

// seedWalletReservationUser inserts a throwaway tenant and a user belonging to
// it, and returns (tenantID, userID) as text UUIDs. Keys and names carry a
// suffix derived from the current UnixNano timestamp.
//
// Each row's deletion is registered with t.Cleanup as soon as its insert
// succeeds, so a failure in a later step does not leave earlier rows behind.
// Cleanups run last-registered-first: the user is deleted before the tenant.
// Deletion is best-effort; failures are logged and do not fail the test.
func seedWalletReservationUser(t *testing.T, ctx context.Context, db *Store) (string, string) {
	t.Helper()

	suffix := strconv.FormatInt(time.Now().UnixNano(), 10)

	var tenantID string
	if err := db.pool.QueryRow(ctx, ` INSERT INTO gateway_tenants (tenant_key, name) VALUES ($1, $2) RETURNING id::text`,
		"wallet-reservation-"+suffix,
		"Wallet Reservation Test "+suffix).Scan(&tenantID); err != nil {
		t.Fatalf("insert test tenant: %v", err)
	}
	t.Cleanup(func() {
		// Use a fresh context: cleanup must not depend on the test's ctx.
		if _, err := db.pool.Exec(context.Background(), `DELETE FROM gateway_tenants WHERE id = $1::uuid`, tenantID); err != nil {
			t.Logf("cleanup test tenant %s: %v", tenantID, err)
		}
	})

	var userID string
	if err := db.pool.QueryRow(ctx, ` INSERT INTO gateway_users (user_key, username, gateway_tenant_id, tenant_key, roles) VALUES ($1, $2, $3::uuid, $4, '["basic"]'::jsonb) RETURNING id::text`,
		"wallet-reservation-user-"+suffix,
		"wallet_reservation_"+suffix,
		tenantID,
		"wallet-reservation-"+suffix).Scan(&userID); err != nil {
		t.Fatalf("insert test user: %v", err)
	}
	t.Cleanup(func() {
		if _, err := db.pool.Exec(context.Background(), `DELETE FROM gateway_users WHERE id = $1::uuid`, userID); err != nil {
			t.Logf("cleanup test user %s: %v", userID, err)
		}
	})

	return tenantID, userID
}

// newWalletReservationTestUUID asks the database for gen_random_uuid() and
// returns it as text. The test is failed immediately if the query errors.
func newWalletReservationTestUUID(t *testing.T, ctx context.Context, db *Store) string {
	t.Helper()

	var id string
	if err := db.pool.QueryRow(ctx, `SELECT gen_random_uuid()::text`).Scan(&id); err != nil {
		t.Fatalf("generate uuid: %v", err)
	}
	return id
}

// readWalletReservationAccount returns (balance, frozen_balance, total_spent)
// of the user's 'resource' currency row in gateway_wallet_accounts, each cast
// to float8. The test is failed immediately if the row cannot be read.
func readWalletReservationAccount(t *testing.T, ctx context.Context, db *Store, userID string) (float64, float64, float64) {
	t.Helper()

	var balance float64
	var frozen float64
	var spent float64
	if err := db.pool.QueryRow(ctx, ` SELECT balance::float8, frozen_balance::float8, total_spent::float8 FROM gateway_wallet_accounts WHERE gateway_user_id = $1::uuid AND currency = 'resource'`, userID).Scan(&balance, &frozen, &spent); err != nil {
		t.Fatalf("read wallet account: %v", err)
	}
	return balance, frozen, spent
}

// walletFloatNear reports whether a and b differ by less than 0.000001 in
// absolute value.
func walletFloatNear(a float64, b float64) bool {
	delta := a - b
	if delta < 0 {
		delta = -delta
	}
	return delta < 0.000001
}