- 在 API 接口定义中为 video_url 和 audio_url 类型添加 mime_type 字段 - 实现 Google Gemini 客户端对视频和音频内容的支持,包括媒体类型检测和数据传输 - 添加 Gemini 客户端测试用例验证多媒体内容转换功能 - 重构 Playground 页面的媒体上传逻辑以支持 MIME 类型传递 - 实现钱包计费预留机制,确保任务执行前余额充足 - 添加钱包冻结余额管理,防止并发操作导致的超扣问题 - 实现计费预留释放逻辑,处理任务失败或取消情况下的资金返还 - 优化数据库事务处理,确保计费操作的原子性和一致性 - 添加数据库集成测试验证迁移脚本执行流程 - 统一 Google Gemini 相关模型提供商标识符映射
203 lines
5.8 KiB
Go
203 lines
5.8 KiB
Go
package runner
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"log/slog"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
)
|
|
|
|
type walletExecuteMockClient struct {
|
|
calls atomic.Int32
|
|
}
|
|
|
|
func (client *walletExecuteMockClient) Run(context.Context, clients.Request) (clients.Response, error) {
|
|
client.calls.Add(1)
|
|
return clients.Response{
|
|
Result: map[string]any{"mock": true},
|
|
RequestID: "mock-wallet-execute",
|
|
}, nil
|
|
}
|
|
|
|
func TestExecuteWithMockClientRejectsConcurrentTasksBeyondWalletBalance(t *testing.T) {
|
|
databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL"))
|
|
if databaseURL == "" {
|
|
t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet execute integration test")
|
|
}
|
|
|
|
ctx := context.Background()
|
|
db, err := store.Connect(ctx, databaseURL)
|
|
if err != nil {
|
|
t.Fatalf("connect store: %v", err)
|
|
}
|
|
t.Cleanup(db.Close)
|
|
|
|
suffix := strconv.FormatInt(time.Now().UnixNano(), 10)
|
|
tenant, err := db.CreateTenant(ctx, store.GatewayTenantInput{
|
|
TenantKey: "wallet-execute-" + suffix,
|
|
Name: "Wallet Execute Test " + suffix,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("create tenant: %v", err)
|
|
}
|
|
t.Cleanup(func() {
|
|
_ = db.DeleteTenant(context.Background(), tenant.ID)
|
|
})
|
|
|
|
gatewayUser, err := db.CreateGatewayUser(ctx, store.GatewayUserInput{
|
|
UserKey: "wallet-execute-user-" + suffix,
|
|
Username: "wallet_execute_" + suffix,
|
|
GatewayTenantID: tenant.ID,
|
|
TenantKey: tenant.TenantKey,
|
|
Roles: []string{"user"},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("create gateway user: %v", err)
|
|
}
|
|
t.Cleanup(func() {
|
|
_ = db.DeleteGatewayUser(context.Background(), gatewayUser.ID)
|
|
})
|
|
|
|
platform, err := db.CreatePlatform(ctx, store.CreatePlatformInput{
|
|
Provider: "mock",
|
|
PlatformKey: "wallet-execute-mock-" + suffix,
|
|
Name: "Wallet Execute Mock " + suffix,
|
|
AuthType: "none",
|
|
Config: map[string]any{"specType": "mock"},
|
|
Status: "enabled",
|
|
Priority: 1,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("create mock platform: %v", err)
|
|
}
|
|
t.Cleanup(func() {
|
|
_ = db.DeletePlatform(context.Background(), platform.ID)
|
|
})
|
|
|
|
if _, err := db.CreatePlatformModel(ctx, store.CreatePlatformModelInput{
|
|
PlatformID: platform.ID,
|
|
ModelName: "mock-wallet-image",
|
|
ProviderModelName: "mock-wallet-image",
|
|
ModelType: store.StringList{"image_generate"},
|
|
DisplayName: "Mock Wallet Image",
|
|
BillingConfig: map[string]any{
|
|
"image": map[string]any{"basePrice": 10},
|
|
},
|
|
}); err != nil {
|
|
t.Fatalf("create mock platform model: %v", err)
|
|
}
|
|
|
|
user := &auth.User{
|
|
ID: gatewayUser.ID,
|
|
Source: "gateway",
|
|
GatewayUserID: gatewayUser.ID,
|
|
GatewayTenantID: tenant.ID,
|
|
TenantKey: tenant.TenantKey,
|
|
Roles: gatewayUser.Roles,
|
|
}
|
|
if _, err := db.SetUserWalletBalance(ctx, store.WalletBalanceAdjustmentInput{
|
|
GatewayUserID: gatewayUser.ID,
|
|
Currency: "resource",
|
|
Balance: 10,
|
|
Reason: "seed wallet execute test",
|
|
}); err != nil {
|
|
t.Fatalf("seed wallet balance: %v", err)
|
|
}
|
|
|
|
tasks := make([]store.GatewayTask, 0, 2)
|
|
for i := 0; i < 2; i++ {
|
|
task, err := db.CreateTask(ctx, store.CreateTaskInput{
|
|
Kind: "images.generations",
|
|
Model: "mock-wallet-image",
|
|
Request: map[string]any{
|
|
"count": 1,
|
|
"prompt": "wallet execute test",
|
|
},
|
|
}, user)
|
|
if err != nil {
|
|
t.Fatalf("create task: %v", err)
|
|
}
|
|
tasks = append(tasks, task)
|
|
}
|
|
|
|
mockClient := &walletExecuteMockClient{}
|
|
service := New(config.Config{}, db, slog.New(slog.NewTextHandler(io.Discard, nil)))
|
|
service.clients["mock"] = mockClient
|
|
|
|
type executeResult struct {
|
|
result Result
|
|
err error
|
|
}
|
|
results := make(chan executeResult, len(tasks))
|
|
var wg sync.WaitGroup
|
|
for _, task := range tasks {
|
|
task := task
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
result, err := service.Execute(ctx, task, user)
|
|
results <- executeResult{result: result, err: err}
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
close(results)
|
|
|
|
successCount := 0
|
|
insufficientCount := 0
|
|
for item := range results {
|
|
if item.err == nil {
|
|
successCount++
|
|
if item.result.Task.Status != "succeeded" {
|
|
t.Fatalf("successful execution status = %s, want succeeded", item.result.Task.Status)
|
|
}
|
|
if !walletExecuteFloatNear(item.result.Task.FinalChargeAmount, 10) {
|
|
t.Fatalf("successful execution final charge = %f, want 10", item.result.Task.FinalChargeAmount)
|
|
}
|
|
continue
|
|
}
|
|
if errors.Is(item.err, store.ErrInsufficientWalletBalance) {
|
|
insufficientCount++
|
|
if item.result.Task.Status != "failed" || item.result.Task.ErrorCode != "insufficient_balance" {
|
|
t.Fatalf("insufficient execution task = %+v", item.result.Task)
|
|
}
|
|
continue
|
|
}
|
|
t.Fatalf("unexpected execute error: %v", item.err)
|
|
}
|
|
if successCount != 1 || insufficientCount != 1 {
|
|
t.Fatalf("expected one successful mock execution and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount)
|
|
}
|
|
if got := mockClient.calls.Load(); got != 1 {
|
|
t.Fatalf("mock client calls = %d, want 1", got)
|
|
}
|
|
|
|
summary, err := db.GetWalletSummary(ctx, user, "resource")
|
|
if err != nil {
|
|
t.Fatalf("get wallet summary: %v", err)
|
|
}
|
|
account := summary.PrimaryAccount
|
|
if !walletExecuteFloatNear(account.Balance, 0) || !walletExecuteFloatNear(account.FrozenBalance, 0) || !walletExecuteFloatNear(account.TotalSpent, 10) {
|
|
t.Fatalf("wallet after concurrent mock execution balance=%f frozen=%f spent=%f, want 0/0/10", account.Balance, account.FrozenBalance, account.TotalSpent)
|
|
}
|
|
}
|
|
|
|
// walletExecuteFloatNear reports whether a and b differ by strictly less than
// 1e-6 in absolute terms. Because every comparison against NaN is false, a
// NaN operand (or the NaN produced by Inf-Inf) never counts as near.
func walletExecuteFloatNear(a float64, b float64) bool {
	const tolerance = 0.000001
	diff := a - b
	// Two-sided bound instead of computing |diff| explicitly.
	return diff < tolerance && diff > -tolerance
}
|