package runner import ( "context" "errors" "io" "log/slog" "os" "strconv" "strings" "sync" "sync/atomic" "testing" "time" "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" "github.com/easyai/easyai-ai-gateway/apps/api/internal/config" "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" ) type walletExecuteMockClient struct { calls atomic.Int32 } func (client *walletExecuteMockClient) Run(context.Context, clients.Request) (clients.Response, error) { client.calls.Add(1) return clients.Response{ Result: map[string]any{"mock": true}, RequestID: "mock-wallet-execute", }, nil } func TestExecuteWithMockClientRejectsConcurrentTasksBeyondWalletBalance(t *testing.T) { databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL")) if databaseURL == "" { t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet execute integration test") } ctx := context.Background() db, err := store.Connect(ctx, databaseURL) if err != nil { t.Fatalf("connect store: %v", err) } t.Cleanup(db.Close) suffix := strconv.FormatInt(time.Now().UnixNano(), 10) tenant, err := db.CreateTenant(ctx, store.GatewayTenantInput{ TenantKey: "wallet-execute-" + suffix, Name: "Wallet Execute Test " + suffix, }) if err != nil { t.Fatalf("create tenant: %v", err) } t.Cleanup(func() { _ = db.DeleteTenant(context.Background(), tenant.ID) }) gatewayUser, err := db.CreateGatewayUser(ctx, store.GatewayUserInput{ UserKey: "wallet-execute-user-" + suffix, Username: "wallet_execute_" + suffix, GatewayTenantID: tenant.ID, TenantKey: tenant.TenantKey, Roles: []string{"user"}, }) if err != nil { t.Fatalf("create gateway user: %v", err) } t.Cleanup(func() { _ = db.DeleteGatewayUser(context.Background(), gatewayUser.ID) }) platform, err := db.CreatePlatform(ctx, store.CreatePlatformInput{ Provider: "mock", PlatformKey: "wallet-execute-mock-" + suffix, Name: "Wallet Execute Mock " + suffix, AuthType: "none", Config: map[string]any{"specType": "mock"}, Status: "enabled", Priority: 1, }) if err != nil { t.Fatalf("create mock platform: %v", err) } t.Cleanup(func() { _ = db.DeletePlatform(context.Background(), platform.ID) }) if _, err := db.CreatePlatformModel(ctx, store.CreatePlatformModelInput{ PlatformID: platform.ID, ModelName: "mock-wallet-image", ProviderModelName: "mock-wallet-image", ModelType: store.StringList{"image_generate"}, DisplayName: "Mock Wallet Image", BillingConfig: map[string]any{ "image": map[string]any{"basePrice": 10}, }, }); err != nil { t.Fatalf("create mock platform model: %v", err) } user := &auth.User{ ID: gatewayUser.ID, Source: "gateway", GatewayUserID: gatewayUser.ID, GatewayTenantID: tenant.ID, TenantKey: tenant.TenantKey, Roles: gatewayUser.Roles, } if _, err := db.SetUserWalletBalance(ctx, store.WalletBalanceAdjustmentInput{ GatewayUserID: gatewayUser.ID, Currency: "resource", Balance: 10, Reason: "seed wallet execute test", }); err != nil { t.Fatalf("seed wallet balance: %v", err) } tasks := make([]store.GatewayTask, 0, 2) for i := 0; i < 2; i++ { task, err := db.CreateTask(ctx, store.CreateTaskInput{ Kind: "images.generations", Model: "mock-wallet-image", Request: map[string]any{ "count": 1, "prompt": "wallet execute test", }, }, user) if err != nil { t.Fatalf("create task: %v", err) } tasks = append(tasks, task) } mockClient := &walletExecuteMockClient{} service := New(config.Config{}, db, slog.New(slog.NewTextHandler(io.Discard, nil))) service.clients["mock"] = mockClient type executeResult struct { result Result err error } results := make(chan executeResult, len(tasks)) var wg sync.WaitGroup for _, task := range tasks { task := task wg.Add(1) go func() { defer wg.Done() result, err := service.Execute(ctx, task, user) results <- executeResult{result: result, err: err} }() } wg.Wait() close(results) successCount := 0 insufficientCount := 0 for item := range results { if item.err == nil { successCount++ if item.result.Task.Status != "succeeded" { t.Fatalf("successful execution status = %s, want succeeded", item.result.Task.Status) } if !walletExecuteFloatNear(item.result.Task.FinalChargeAmount, 10) { t.Fatalf("successful execution final charge = %f, want 10", item.result.Task.FinalChargeAmount) } continue } if errors.Is(item.err, store.ErrInsufficientWalletBalance) { insufficientCount++ if item.result.Task.Status != "failed" || item.result.Task.ErrorCode != "insufficient_balance" { t.Fatalf("insufficient execution task = %+v", item.result.Task) } continue } t.Fatalf("unexpected execute error: %v", item.err) } if successCount != 1 || insufficientCount != 1 { t.Fatalf("expected one successful mock execution and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount) } if got := mockClient.calls.Load(); got != 1 { t.Fatalf("mock client calls = %d, want 1", got) } summary, err := db.GetWalletSummary(ctx, user, "resource") if err != nil { t.Fatalf("get wallet summary: %v", err) } account := summary.PrimaryAccount if !walletExecuteFloatNear(account.Balance, 0) || !walletExecuteFloatNear(account.FrozenBalance, 0) || !walletExecuteFloatNear(account.TotalSpent, 10) { t.Fatalf("wallet after concurrent mock execution balance=%f frozen=%f spent=%f, want 0/0/10", account.Balance, account.FrozenBalance, account.TotalSpent) } } func walletExecuteFloatNear(a float64, b float64) bool { delta := a - b if delta < 0 { delta = -delta } return delta < 0.000001 }