easyai-ai-gateway/apps/api/internal/runner/wallet_execute_test.go
wangbo 8ad5b06c18 feat(api): 添加多媒体内容支持并优化钱包计费系统
- 在 API 接口定义中为 video_url 和 audio_url 类型添加 mime_type 字段
- 实现 Google Gemini 客户端对视频和音频内容的支持,包括媒体类型检测和数据传输
- 添加 Gemini 客户端测试用例验证多媒体内容转换功能
- 重构 Playground 页面的媒体上传逻辑以支持 MIME 类型传递
- 实现钱包计费预留机制,确保任务执行前余额充足
- 添加钱包冻结余额管理,防止并发操作导致的超扣问题
- 实现计费预留释放逻辑,处理任务失败或取消情况下的资金返还
- 优化数据库事务处理,确保计费操作的原子性和一致性
- 添加数据库集成测试验证迁移脚本执行流程
- 统一 Google Gemini 相关模型提供商标识符映射
2026-05-22 23:46:08 +08:00

203 lines
5.8 KiB
Go

package runner
import (
"context"
"errors"
"io"
"log/slog"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// walletExecuteMockClient is a stand-in provider client for the wallet
// execute test. It never contacts an upstream service: each Run call is
// counted and answered with a fixed successful response.
type walletExecuteMockClient struct {
	// calls records how many times Run has been invoked. It is atomic
	// because the test drives Execute from several goroutines, so Run may
	// be entered concurrently.
	calls atomic.Int32
}

// Run increments the call counter and returns a canned response with
// Result {"mock": true} and RequestID "mock-wallet-execute". The context
// and request arguments are ignored, and the error is always nil.
func (c *walletExecuteMockClient) Run(_ context.Context, _ clients.Request) (clients.Response, error) {
	c.calls.Add(1)
	resp := clients.Response{RequestID: "mock-wallet-execute"}
	resp.Result = map[string]any{"mock": true}
	return resp, nil
}
// TestExecuteWithMockClientRejectsConcurrentTasksBeyondWalletBalance is a
// database-backed integration test of wallet reservation during Execute.
//
// It creates a fresh tenant and gateway user, seeds the user's "resource"
// wallet with a balance of 10, and registers a mock platform whose image
// model has basePrice 10. Two images.generations tasks are then executed
// concurrently. Exactly one must succeed with a final charge of 10. The
// other must fail with store.ErrInsufficientWalletBalance (task status
// "failed", error code "insufficient_balance"). The mock upstream must be
// called exactly once, and the wallet must end with balance 0, frozen
// balance 0 and total spent 10.
//
// The test is skipped unless AI_GATEWAY_TEST_DATABASE_URL is set.
func TestExecuteWithMockClientRejectsConcurrentTasksBeyondWalletBalance(t *testing.T) {
	databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL"))
	if databaseURL == "" {
		t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet execute integration test")
	}

	// Bound the whole run so a hung Execute or database call (assuming they
	// honor ctx) fails the test instead of blocking until the go test
	// binary timeout.
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	db, err := store.Connect(ctx, databaseURL)
	if err != nil {
		t.Fatalf("connect store: %v", err)
	}
	t.Cleanup(db.Close)

	// Cleanups below deliberately use context.Background(): t.Cleanup
	// functions run after this function returns, i.e. after ctx has been
	// cancelled by the deferred cancel above.
	suffix := strconv.FormatInt(time.Now().UnixNano(), 10)
	tenant, err := db.CreateTenant(ctx, store.GatewayTenantInput{
		TenantKey: "wallet-execute-" + suffix,
		Name:      "Wallet Execute Test " + suffix,
	})
	if err != nil {
		t.Fatalf("create tenant: %v", err)
	}
	t.Cleanup(func() {
		_ = db.DeleteTenant(context.Background(), tenant.ID)
	})

	gatewayUser, err := db.CreateGatewayUser(ctx, store.GatewayUserInput{
		UserKey:         "wallet-execute-user-" + suffix,
		Username:        "wallet_execute_" + suffix,
		GatewayTenantID: tenant.ID,
		TenantKey:       tenant.TenantKey,
		Roles:           []string{"user"},
	})
	if err != nil {
		t.Fatalf("create gateway user: %v", err)
	}
	t.Cleanup(func() {
		_ = db.DeleteGatewayUser(context.Background(), gatewayUser.ID)
	})

	platform, err := db.CreatePlatform(ctx, store.CreatePlatformInput{
		Provider:    "mock",
		PlatformKey: "wallet-execute-mock-" + suffix,
		Name:        "Wallet Execute Mock " + suffix,
		AuthType:    "none",
		Config:      map[string]any{"specType": "mock"},
		Status:      "enabled",
		Priority:    1,
	})
	if err != nil {
		t.Fatalf("create mock platform: %v", err)
	}
	t.Cleanup(func() {
		_ = db.DeletePlatform(context.Background(), platform.ID)
	})

	// One image costs 10, exactly the seeded wallet balance, so only one of
	// the two tasks can be paid for.
	if _, err := db.CreatePlatformModel(ctx, store.CreatePlatformModelInput{
		PlatformID:        platform.ID,
		ModelName:         "mock-wallet-image",
		ProviderModelName: "mock-wallet-image",
		ModelType:         store.StringList{"image_generate"},
		DisplayName:       "Mock Wallet Image",
		BillingConfig: map[string]any{
			"image": map[string]any{"basePrice": 10},
		},
	}); err != nil {
		t.Fatalf("create mock platform model: %v", err)
	}

	user := &auth.User{
		ID:              gatewayUser.ID,
		Source:          "gateway",
		GatewayUserID:   gatewayUser.ID,
		GatewayTenantID: tenant.ID,
		TenantKey:       tenant.TenantKey,
		Roles:           gatewayUser.Roles,
	}
	if _, err := db.SetUserWalletBalance(ctx, store.WalletBalanceAdjustmentInput{
		GatewayUserID: gatewayUser.ID,
		Currency:      "resource",
		Balance:       10,
		Reason:        "seed wallet execute test",
	}); err != nil {
		t.Fatalf("seed wallet balance: %v", err)
	}

	tasks := make([]store.GatewayTask, 0, 2)
	for i := 0; i < 2; i++ {
		task, err := db.CreateTask(ctx, store.CreateTaskInput{
			Kind:  "images.generations",
			Model: "mock-wallet-image",
			Request: map[string]any{
				"count":  1,
				"prompt": "wallet execute test",
			},
		}, user)
		if err != nil {
			t.Fatalf("create task: %v", err)
		}
		tasks = append(tasks, task)
	}

	mockClient := &walletExecuteMockClient{}
	service := New(config.Config{}, db, slog.New(slog.NewTextHandler(io.Discard, nil)))
	service.clients["mock"] = mockClient

	type executeResult struct {
		result Result
		err    error
	}
	// Buffered to len(tasks) so every goroutine can deliver its result
	// without a concurrent reader.
	results := make(chan executeResult, len(tasks))
	// start is closed once all goroutines exist, so their Execute calls
	// overlap and the wallet reservation path is actually raced rather
	// than run back to back.
	start := make(chan struct{})
	var wg sync.WaitGroup
	for _, task := range tasks {
		task := task
		wg.Add(1)
		go func() {
			defer wg.Done()
			<-start
			result, err := service.Execute(ctx, task, user)
			results <- executeResult{result: result, err: err}
		}()
	}
	close(start)
	wg.Wait()
	close(results)

	successCount := 0
	insufficientCount := 0
	for item := range results {
		if item.err == nil {
			successCount++
			if item.result.Task.Status != "succeeded" {
				t.Fatalf("successful execution status = %s, want succeeded", item.result.Task.Status)
			}
			if !walletExecuteFloatNear(item.result.Task.FinalChargeAmount, 10) {
				t.Fatalf("successful execution final charge = %f, want 10", item.result.Task.FinalChargeAmount)
			}
			continue
		}
		if errors.Is(item.err, store.ErrInsufficientWalletBalance) {
			insufficientCount++
			if item.result.Task.Status != "failed" || item.result.Task.ErrorCode != "insufficient_balance" {
				t.Fatalf("insufficient execution task = %+v", item.result.Task)
			}
			continue
		}
		t.Fatalf("unexpected execute error: %v", item.err)
	}
	if successCount != 1 || insufficientCount != 1 {
		t.Fatalf("expected one successful mock execution and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount)
	}
	// The rejected task must be stopped before reaching the upstream client.
	if got := mockClient.calls.Load(); got != 1 {
		t.Fatalf("mock client calls = %d, want 1", got)
	}

	summary, err := db.GetWalletSummary(ctx, user, "resource")
	if err != nil {
		t.Fatalf("get wallet summary: %v", err)
	}
	// No reservation may remain frozen, and only the single successful
	// charge may have been spent.
	account := summary.PrimaryAccount
	if !walletExecuteFloatNear(account.Balance, 0) || !walletExecuteFloatNear(account.FrozenBalance, 0) || !walletExecuteFloatNear(account.TotalSpent, 10) {
		t.Fatalf("wallet after concurrent mock execution balance=%f frozen=%f spent=%f, want 0/0/10", account.Balance, account.FrozenBalance, account.TotalSpent)
	}
}
// walletExecuteFloatNear reports whether a and b differ by strictly less
// than 1e-6. It reports false whenever the difference is NaN or infinite,
// for example when either input is NaN or both are the same infinity.
func walletExecuteFloatNear(a float64, b float64) bool {
	const tolerance = 1e-6
	diff := a - b
	return diff < tolerance && diff > -tolerance
}