feat(api): 添加多媒体内容支持并优化钱包计费系统

- 在 API 接口定义中为 video_url 和 audio_url 类型添加 mime_type 字段
- 实现 Google Gemini 客户端对视频和音频内容的支持,包括媒体类型检测和数据传输
- 添加 Gemini 客户端测试用例验证多媒体内容转换功能
- 重构 Playground 页面的媒体上传逻辑以支持 MIME 类型传递
- 实现钱包计费预留机制,确保任务执行前余额充足
- 添加钱包冻结余额管理,防止并发操作导致的超扣问题
- 实现计费预留释放逻辑,处理任务失败或取消情况下的资金返还
- 优化数据库事务处理,确保计费操作的原子性和一致性
- 添加数据库集成测试验证迁移脚本执行流程
- 统一 Google Gemini 相关模型提供商标识符映射
This commit is contained in:
wangbo 2026-05-22 23:46:08 +08:00
parent af9b281d34
commit 8ad5b06c18
12 changed files with 1104 additions and 82 deletions

View File

@ -569,6 +569,70 @@ func TestGeminiClientChatContract(t *testing.T) {
}
}
func TestGeminiClientChatConvertsMediaContentParts(t *testing.T) {
var captured map[string]any
var gotPath string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
if err := json.NewDecoder(r.Body).Decode(&captured); err != nil {
t.Fatalf("decode request: %v", err)
}
_ = json.NewEncoder(w).Encode(map[string]any{
"candidates": []any{map[string]any{
"content": map[string]any{"parts": []any{map[string]any{"text": "video ok"}}},
}},
})
}))
defer server.Close()
_, err := (GeminiClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
Kind: "chat.completions",
Model: "gemini:gemini-2.5-flash",
Body: map[string]any{
"model": "gemini:gemini-2.5-flash",
"messages": []any{map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "analyze this video"},
map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://cdn.example.com/input.mov", "mime_type": "video/quicktime"}},
map[string]any{"type": "audio_url", "audio_url": map[string]any{"url": "data:audio/wav;base64,UklGRg=="}},
},
}},
},
Candidate: store.RuntimeModelCandidate{
BaseURL: server.URL + "/v1beta/openai",
ProviderModelName: "gemini-2.5-flash",
ModelType: "chat",
Credentials: map[string]any{"apiKey": "gemini-key"},
},
})
if err != nil {
t.Fatalf("run gemini client: %v", err)
}
if gotPath != "/v1beta/models/gemini-2.5-flash:generateContent" {
t.Fatalf("Gemini OpenAI-compatible base URL should normalize to native endpoint, got %s", gotPath)
}
contents, _ := captured["contents"].([]any)
if len(contents) != 1 {
t.Fatalf("unexpected Gemini contents: %+v", captured)
}
turn, _ := contents[0].(map[string]any)
parts, _ := turn["parts"].([]any)
if len(parts) != 3 {
t.Fatalf("expected text, video, and audio parts, got %+v", turn)
}
video, _ := parts[1].(map[string]any)
videoFile, _ := video["fileData"].(map[string]any)
if videoFile["fileUri"] != "https://cdn.example.com/input.mov" || videoFile["mimeType"] != "video/quicktime" {
t.Fatalf("video_url should become Gemini fileData, got %+v", video)
}
audio, _ := parts[2].(map[string]any)
audioInline, _ := audio["inlineData"].(map[string]any)
if audioInline["mimeType"] != "audio/wav" || audioInline["data"] != "UklGRg==" {
t.Fatalf("audio data URL should become Gemini inlineData, got %+v", audio)
}
}
func TestGeminiClientChatRestoresToolContext(t *testing.T) {
var captured map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

View File

@ -5,8 +5,10 @@ import (
"context"
"encoding/json"
"fmt"
"mime"
"net/http"
"net/url"
"path"
"strings"
"time"
)
@ -58,6 +60,7 @@ func geminiURL(baseURL string, model string, apiKey string) string {
if base == "" {
base = "https://generativelanguage.googleapis.com"
}
base = strings.TrimSuffix(base, "/openai")
if strings.HasSuffix(base, "/v1beta") {
base = strings.TrimSuffix(base, "/v1beta")
}
@ -121,7 +124,7 @@ func geminiContentsFromMessages(body map[string]any) []any {
})
continue
}
parts := geminiTextParts(message["content"])
parts := geminiContentParts(message["content"])
if role == "assistant" {
for _, rawToolCall := range toolCallsSlice(message["tool_calls"]) {
toolCall, _ := rawToolCall.(map[string]any)
@ -157,7 +160,7 @@ func geminiRole(role string) string {
return "user"
}
func geminiTextParts(content any) []any {
func geminiContentParts(content any) []any {
parts := make([]any, 0)
switch typed := content.(type) {
case string:
@ -167,14 +170,146 @@ func geminiTextParts(content any) []any {
case []any:
for _, rawPart := range typed {
part, _ := rawPart.(map[string]any)
if text := stringFromAny(firstPresent(part["text"], part["content"])); strings.TrimSpace(text) != "" {
parts = append(parts, map[string]any{"text": text})
if len(part) == 0 {
continue
}
switch stringFromAny(part["type"]) {
case "text":
if text := strings.TrimSpace(stringFromAny(firstPresent(part["text"], part["content"]))); text != "" {
parts = append(parts, map[string]any{"text": text})
}
case "image_url":
if media := geminiMediaPart(part, "image_url", "image"); media != nil {
parts = append(parts, media)
}
case "video_url":
if media := geminiMediaPart(part, "video_url", "video"); media != nil {
parts = append(parts, media)
}
case "audio_url":
if media := geminiMediaPart(part, "audio_url", "audio"); media != nil {
parts = append(parts, media)
}
case "input_audio":
if media := geminiInputAudioPart(part); media != nil {
parts = append(parts, media)
}
default:
if text := strings.TrimSpace(stringFromAny(firstPresent(part["text"], part["content"]))); text != "" {
parts = append(parts, map[string]any{"text": text})
}
}
}
}
return parts
}
func geminiMediaPart(part map[string]any, key string, mediaType string) map[string]any {
nested := mapFromAny(part[key])
uri := firstNonEmptyString(nested["url"], part["url"], part[key])
if uri == "" {
return nil
}
mimeType := firstNonEmptyString(nested["mime_type"], nested["mimeType"], part["mime_type"], part["mimeType"])
return geminiMediaURLPart(uri, mimeType, mediaType)
}
func geminiInputAudioPart(part map[string]any) map[string]any {
audio := mapFromAny(part["input_audio"])
uri := firstNonEmptyString(audio["data"], audio["url"])
if uri == "" {
return nil
}
mimeType := firstNonEmptyString(audio["mime_type"], audio["mimeType"])
if mimeType == "" {
format := strings.ToLower(strings.TrimPrefix(stringFromAny(audio["format"]), "."))
if strings.Contains(format, "/") {
mimeType = format
} else if format == "mp3" {
mimeType = "audio/mpeg"
} else if format != "" {
mimeType = "audio/" + format
}
}
return geminiMediaURLPart(uri, mimeType, "audio")
}
func geminiMediaURLPart(uri string, explicitMimeType string, mediaType string) map[string]any {
if parsed := geminiDataURL(uri); parsed != nil {
return map[string]any{"inlineData": map[string]any{
"mimeType": geminiMediaMime(firstNonEmptyString(explicitMimeType, parsed.mimeType), mediaType),
"data": parsed.data,
}}
}
return map[string]any{"fileData": map[string]any{
"fileUri": uri,
"mimeType": geminiMediaMime(firstNonEmptyString(explicitMimeType, mimeFromURI(uri)), mediaType),
}}
}
type geminiParsedDataURL struct {
mimeType string
data string
}
func geminiDataURL(value string) *geminiParsedDataURL {
if !strings.HasPrefix(value, "data:") {
return nil
}
prefix, data, ok := strings.Cut(value, ",")
if !ok || !strings.Contains(prefix, ";base64") {
return nil
}
mimeType := strings.TrimPrefix(strings.Split(prefix, ";")[0], "data:")
if mimeType == "" {
mimeType = "application/octet-stream"
}
return &geminiParsedDataURL{mimeType: mimeType, data: data}
}
func mimeFromURI(value string) string {
pathValue := value
if parsed, err := url.Parse(value); err == nil && parsed.Path != "" {
pathValue = parsed.Path
}
extension := strings.ToLower(path.Ext(pathValue))
if extension == "" {
return ""
}
return mime.TypeByExtension(extension)
}
func geminiMediaMime(mimeType string, mediaType string) string {
normalized := strings.ToLower(strings.TrimSpace(strings.Split(mimeType, ";")[0]))
switch mediaType {
case "image":
if strings.HasPrefix(normalized, "image/") && normalized != "image/svg+xml" {
return normalized
}
return "image/png"
case "video":
switch normalized {
case "video/x-msvideo":
return "video/avi"
case "video/quicktime", "video/mpeg", "video/mp4", "video/avi", "video/x-flv", "video/mpg", "video/webm", "video/wmv", "video/3gpp":
return normalized
default:
return "video/mp4"
}
case "audio":
switch normalized {
case "audio/x-wav", "audio/wave":
return "audio/wav"
case "audio/mpeg", "audio/mp3", "audio/wav", "audio/aiff", "audio/aac", "audio/ogg", "audio/flac", "audio/mp4", "audio/webm":
return normalized
default:
return "audio/mpeg"
}
default:
return "application/octet-stream"
}
}
func toolCallsSlice(value any) []any {
switch typed := value.(type) {
case []any:

View File

@ -129,6 +129,13 @@ func TestCoreLocalFlow(t *testing.T) {
if _, err := testPool.Exec(ctx, `UPDATE gateway_users SET roles = '["admin"]'::jsonb WHERE username = $1`, username); err != nil {
t.Fatalf("promote smoke user: %v", err)
}
doJSON(t, server.URL, http.MethodPost, "/api/v1/auth/login", "", map[string]any{
"account": username,
"password": password,
}, http.StatusOK, &loginResponse)
if loginResponse.AccessToken == "" {
t.Fatal("admin login did not return access token")
}
var smokeGatewayUserID string
if err := testPool.QueryRow(ctx, `SELECT id::text FROM gateway_users WHERE username = $1`, username).Scan(&smokeGatewayUserID); err != nil {
t.Fatalf("read smoke gateway user id: %v", err)
@ -1402,14 +1409,41 @@ func applyMigration(t *testing.T, ctx context.Context, databaseURL string) {
t.Fatalf("connect migration db: %v", err)
}
defer pool.Close()
if _, err := pool.Exec(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version text PRIMARY KEY,
applied_at timestamptz NOT NULL DEFAULT now()
);`); err != nil {
t.Fatalf("ensure schema migrations: %v", err)
}
for _, migrationPath := range migrationFiles {
version := strings.TrimSuffix(filepath.Base(migrationPath), filepath.Ext(migrationPath))
var exists bool
if err := pool.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM schema_migrations WHERE version = $1)", version).Scan(&exists); err != nil {
t.Fatalf("check migration %s: %v", filepath.Base(migrationPath), err)
}
if exists {
continue
}
migration, err := os.ReadFile(migrationPath)
if err != nil {
t.Fatalf("read migration %s: %v", filepath.Base(migrationPath), err)
}
if _, err := pool.Exec(ctx, string(migration)); err != nil {
tx, err := pool.Begin(ctx)
if err != nil {
t.Fatalf("begin migration %s: %v", filepath.Base(migrationPath), err)
}
if _, err := tx.Exec(ctx, string(migration)); err != nil {
_ = tx.Rollback(ctx)
t.Fatalf("apply migration %s: %v", filepath.Base(migrationPath), err)
}
if _, err := tx.Exec(ctx, "INSERT INTO schema_migrations(version) VALUES($1)", version); err != nil {
_ = tx.Rollback(ctx)
t.Fatalf("record migration %s: %v", filepath.Base(migrationPath), err)
}
if err := tx.Commit(ctx); err != nil {
t.Fatalf("commit migration %s: %v", filepath.Base(migrationPath), err)
}
}
}

View File

@ -160,6 +160,13 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
normalizedModelType := modelType
attemptNo := task.AttemptCount
var firstPreprocessing parameterPreprocessingLog
var walletReservations []store.WalletBillingReservation
walletReservationFinalized := false
defer func() {
if !walletReservationFinalized && len(walletReservations) > 0 {
_ = s.store.ReleaseTaskBillingReservations(context.WithoutCancel(ctx), walletReservations, "task_not_settled")
}
}()
if len(candidates) > 0 {
preprocessing := s.preprocessRequestWithScripts(ctx, task.Kind, body, candidates[0])
firstCandidateBody = preprocessing.Body
@ -191,15 +198,17 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
return Result{}, err
}
estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, firstCandidateBody, candidates[0])
if err := s.ensureWalletBalance(ctx, user, estimatedBillings); err != nil {
if errors.Is(err, store.ErrInsufficientWalletBalance) {
var reserveErr error
walletReservations, reserveErr = s.store.ReserveTaskBilling(ctx, task, user, estimatedBillings)
if reserveErr != nil {
if errors.Is(reserveErr, store.ErrInsufficientWalletBalance) {
attemptNo = s.recordFailedAttempt(ctx, failedAttemptRecord{
Task: task,
Body: firstCandidateBody,
Candidate: &candidates[0],
AttemptNo: attemptNo + 1,
Code: "insufficient_balance",
Cause: err,
Cause: reserveErr,
Simulated: task.RunMode == "simulation",
Scope: "wallet_balance",
Reason: "wallet_balance_check_failed",
@ -207,13 +216,13 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
Preprocessing: &firstPreprocessing,
ModelType: normalizedModelType,
})
failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", err.Error(), task.RunMode == "simulation", err, parameterPreprocessingMetrics(firstPreprocessing))
failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", reserveErr.Error(), task.RunMode == "simulation", reserveErr, parameterPreprocessingMetrics(firstPreprocessing))
if finishErr != nil {
return Result{}, finishErr
}
return Result{Task: failed, Output: failed.Result}, err
return Result{Task: failed, Output: failed.Result}, reserveErr
}
return Result{}, err
return Result{}, reserveErr
}
}
if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": normalizedModelType}, task.RunMode == "simulation"); err != nil {
@ -286,9 +295,18 @@ candidatesLoop:
if finishErr != nil {
return Result{}, finishErr
}
if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil {
return Result{}, settleErr
if finished.FinalChargeAmount > 0 {
walletReservationFinalized = true
if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil {
return Result{}, settleErr
}
} else if len(walletReservations) > 0 {
if releaseErr := s.store.ReleaseTaskBillingReservations(ctx, walletReservations, "task_billing_zero"); releaseErr != nil {
return Result{}, releaseErr
}
walletReservationFinalized = true
}
walletReservationFinalized = true
if finished.FinalChargeAmount > 0 {
if err := s.emit(ctx, task.ID, "task.billing.settled", "succeeded", "billing", 0.98, "task billing settled", map[string]any{
"amount": finished.FinalChargeAmount,
@ -695,6 +713,11 @@ func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated boo
if key == "" {
key = strings.ToLower(strings.TrimSpace(candidate.Provider))
}
provider := strings.ToLower(strings.TrimSpace(candidate.Provider))
baseURL := strings.ToLower(strings.TrimSpace(candidate.BaseURL))
if key == "google-gemini" || provider == "gemini" || provider == "google-gemini" || provider == "gemini-openai" || strings.Contains(baseURL, "generativelanguage.googleapis.com") {
key = "gemini"
}
if client, ok := s.clients[key]; ok {
return client
}

View File

@ -0,0 +1,34 @@
package runner
import (
"context"
"testing"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
type namedClient string
func (namedClient) Run(context.Context, clients.Request) (clients.Response, error) {
return clients.Response{}, nil
}
func TestClientForMapsGoogleGeminiSpecToGeminiClient(t *testing.T) {
service := &Service{clients: map[string]clients.Client{
"gemini": namedClient("gemini"),
"openai": namedClient("openai"),
}}
candidates := []store.RuntimeModelCandidate{
{SpecType: "google-gemini"},
{SpecType: "openai", Provider: "gemini-openai"},
{SpecType: "openai", BaseURL: "https://generativelanguage.googleapis.com/v1beta/openai"},
}
for _, candidate := range candidates {
client := service.clientFor(candidate, false)
if client != namedClient("gemini") {
t.Fatalf("Gemini candidate should use gemini client, candidate=%+v got %T %[2]v", candidate, client)
}
}
}

View File

@ -1,38 +0,0 @@
package runner
import (
"context"
"fmt"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
func (s *Service) ensureWalletBalance(ctx context.Context, user *auth.User, billings []any) error {
amounts := map[string]float64{}
for _, raw := range billings {
line, _ := raw.(map[string]any)
if line == nil {
continue
}
currency := strings.TrimSpace(stringFromAny(line["currency"]))
if currency == "" {
currency = "resource"
}
amounts[currency] = roundPrice(amounts[currency] + floatFromAny(line["amount"]))
}
for currency, amount := range amounts {
if amount <= 0 {
continue
}
availability, err := s.store.WalletAvailability(ctx, user, currency, amount)
if err != nil {
return err
}
if !availability.Enough {
return fmt.Errorf("%w: required %.6f %s, available %.6f", store.ErrInsufficientWalletBalance, amount, currency, availability.AvailableAmount)
}
}
return nil
}

View File

@ -0,0 +1,202 @@
package runner
import (
"context"
"errors"
"io"
"log/slog"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
type walletExecuteMockClient struct {
calls atomic.Int32
}
func (client *walletExecuteMockClient) Run(context.Context, clients.Request) (clients.Response, error) {
client.calls.Add(1)
return clients.Response{
Result: map[string]any{"mock": true},
RequestID: "mock-wallet-execute",
}, nil
}
func TestExecuteWithMockClientRejectsConcurrentTasksBeyondWalletBalance(t *testing.T) {
databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL"))
if databaseURL == "" {
t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet execute integration test")
}
ctx := context.Background()
db, err := store.Connect(ctx, databaseURL)
if err != nil {
t.Fatalf("connect store: %v", err)
}
t.Cleanup(db.Close)
suffix := strconv.FormatInt(time.Now().UnixNano(), 10)
tenant, err := db.CreateTenant(ctx, store.GatewayTenantInput{
TenantKey: "wallet-execute-" + suffix,
Name: "Wallet Execute Test " + suffix,
})
if err != nil {
t.Fatalf("create tenant: %v", err)
}
t.Cleanup(func() {
_ = db.DeleteTenant(context.Background(), tenant.ID)
})
gatewayUser, err := db.CreateGatewayUser(ctx, store.GatewayUserInput{
UserKey: "wallet-execute-user-" + suffix,
Username: "wallet_execute_" + suffix,
GatewayTenantID: tenant.ID,
TenantKey: tenant.TenantKey,
Roles: []string{"user"},
})
if err != nil {
t.Fatalf("create gateway user: %v", err)
}
t.Cleanup(func() {
_ = db.DeleteGatewayUser(context.Background(), gatewayUser.ID)
})
platform, err := db.CreatePlatform(ctx, store.CreatePlatformInput{
Provider: "mock",
PlatformKey: "wallet-execute-mock-" + suffix,
Name: "Wallet Execute Mock " + suffix,
AuthType: "none",
Config: map[string]any{"specType": "mock"},
Status: "enabled",
Priority: 1,
})
if err != nil {
t.Fatalf("create mock platform: %v", err)
}
t.Cleanup(func() {
_ = db.DeletePlatform(context.Background(), platform.ID)
})
if _, err := db.CreatePlatformModel(ctx, store.CreatePlatformModelInput{
PlatformID: platform.ID,
ModelName: "mock-wallet-image",
ProviderModelName: "mock-wallet-image",
ModelType: store.StringList{"image_generate"},
DisplayName: "Mock Wallet Image",
BillingConfig: map[string]any{
"image": map[string]any{"basePrice": 10},
},
}); err != nil {
t.Fatalf("create mock platform model: %v", err)
}
user := &auth.User{
ID: gatewayUser.ID,
Source: "gateway",
GatewayUserID: gatewayUser.ID,
GatewayTenantID: tenant.ID,
TenantKey: tenant.TenantKey,
Roles: gatewayUser.Roles,
}
if _, err := db.SetUserWalletBalance(ctx, store.WalletBalanceAdjustmentInput{
GatewayUserID: gatewayUser.ID,
Currency: "resource",
Balance: 10,
Reason: "seed wallet execute test",
}); err != nil {
t.Fatalf("seed wallet balance: %v", err)
}
tasks := make([]store.GatewayTask, 0, 2)
for i := 0; i < 2; i++ {
task, err := db.CreateTask(ctx, store.CreateTaskInput{
Kind: "images.generations",
Model: "mock-wallet-image",
Request: map[string]any{
"count": 1,
"prompt": "wallet execute test",
},
}, user)
if err != nil {
t.Fatalf("create task: %v", err)
}
tasks = append(tasks, task)
}
mockClient := &walletExecuteMockClient{}
service := New(config.Config{}, db, slog.New(slog.NewTextHandler(io.Discard, nil)))
service.clients["mock"] = mockClient
type executeResult struct {
result Result
err error
}
results := make(chan executeResult, len(tasks))
var wg sync.WaitGroup
for _, task := range tasks {
task := task
wg.Add(1)
go func() {
defer wg.Done()
result, err := service.Execute(ctx, task, user)
results <- executeResult{result: result, err: err}
}()
}
wg.Wait()
close(results)
successCount := 0
insufficientCount := 0
for item := range results {
if item.err == nil {
successCount++
if item.result.Task.Status != "succeeded" {
t.Fatalf("successful execution status = %s, want succeeded", item.result.Task.Status)
}
if !walletExecuteFloatNear(item.result.Task.FinalChargeAmount, 10) {
t.Fatalf("successful execution final charge = %f, want 10", item.result.Task.FinalChargeAmount)
}
continue
}
if errors.Is(item.err, store.ErrInsufficientWalletBalance) {
insufficientCount++
if item.result.Task.Status != "failed" || item.result.Task.ErrorCode != "insufficient_balance" {
t.Fatalf("insufficient execution task = %+v", item.result.Task)
}
continue
}
t.Fatalf("unexpected execute error: %v", item.err)
}
if successCount != 1 || insufficientCount != 1 {
t.Fatalf("expected one successful mock execution and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount)
}
if got := mockClient.calls.Load(); got != 1 {
t.Fatalf("mock client calls = %d, want 1", got)
}
summary, err := db.GetWalletSummary(ctx, user, "resource")
if err != nil {
t.Fatalf("get wallet summary: %v", err)
}
account := summary.PrimaryAccount
if !walletExecuteFloatNear(account.Balance, 0) || !walletExecuteFloatNear(account.FrozenBalance, 0) || !walletExecuteFloatNear(account.TotalSpent, 10) {
t.Fatalf("wallet after concurrent mock execution balance=%f frozen=%f spent=%f, want 0/0/10", account.Balance, account.FrozenBalance, account.TotalSpent)
}
}
func walletExecuteFloatNear(a float64, b float64) bool {
delta := a - b
if delta < 0 {
delta = -delta
}
return delta < 0.000001
}

View File

@ -3,13 +3,13 @@ package store
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
type TaskListFilter struct {
@ -687,14 +687,15 @@ func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error {
if currency == "" || currency == "mixed" {
currency = "resource"
}
metadata, _ := json.Marshal(map[string]any{
metadataMap := map[string]any{
"taskId": task.ID,
"kind": task.Kind,
"model": task.Model,
"resolvedModel": task.ResolvedModel,
"billings": task.Billings,
"billingSummary": task.BillingSummary,
})
}
metadata, _ := json.Marshal(metadataMap)
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
if _, err := tx.Exec(ctx, `
INSERT INTO gateway_wallet_accounts (
@ -706,42 +707,85 @@ ON CONFLICT (gateway_user_id, currency) DO NOTHING`,
return err
}
var exists bool
var accountID string
var balanceBefore float64
var frozenBefore float64
var gatewayTenantID string
if err := tx.QueryRow(ctx, `
SELECT id::text, balance::float8, frozen_balance::float8, COALESCE(gateway_tenant_id::text, '')
FROM gateway_wallet_accounts
WHERE gateway_user_id = $1::uuid
AND currency = $2
FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &frozenBefore, &gatewayTenantID); err != nil {
return err
}
if err := tx.QueryRow(ctx, `
SELECT EXISTS (
SELECT 1
FROM gateway_wallet_transactions t
JOIN gateway_wallet_accounts a ON a.id = t.account_id
WHERE a.gateway_user_id = $1::uuid
AND a.currency = $2
AND t.idempotency_key = $3
)`, task.GatewayUserID, currency, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil {
FROM gateway_wallet_transactions
WHERE account_id = $1::uuid
AND idempotency_key = $2
)`, accountID, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil {
return err
}
if exists {
return nil
}
var accountID string
var balanceBefore float64
var gatewayTenantID string
if err := tx.QueryRow(ctx, `
SELECT id::text, balance::float8, COALESCE(gateway_tenant_id::text, '')
FROM gateway_wallet_accounts
WHERE gateway_user_id = $1::uuid
AND currency = $2
FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &gatewayTenantID); err != nil {
amount := roundMoney(task.FinalChargeAmount)
reservationKey, reservedAmount, err := activeWalletReservation(ctx, tx, accountID, task.ID)
if err != nil {
return err
}
amount := roundMoney(task.FinalChargeAmount)
reservedAmount = roundMoney(reservedAmount)
spendableForTask := roundMoney(balanceBefore - frozenBefore + reservedAmount)
if spendableForTask+0.000001 < amount {
return fmt.Errorf("%w: required %.6f %s, available %.6f", ErrInsufficientWalletBalance, amount, currency, spendableForTask)
}
balanceAfter := roundMoney(balanceBefore - amount)
frozenAfter := roundMoney(frozenBefore - reservedAmount)
if frozenAfter < 0 {
frozenAfter = 0
}
if _, err := tx.Exec(ctx, `
UPDATE gateway_wallet_accounts
SET balance = $2,
total_spent = total_spent + $3,
frozen_balance = $4,
updated_at = now()
WHERE id = $1::uuid`, accountID, balanceAfter, amount); err != nil {
WHERE id = $1::uuid`, accountID, balanceAfter, amount, frozenAfter); err != nil {
return err
}
_, err := tx.Exec(ctx, `
if reservedAmount > 0 {
releaseMetadata, _ := json.Marshal(map[string]any{
"taskId": task.ID,
"reason": "task_billing_settled",
"reserved": reservedAmount,
"frozenBefore": roundMoney(frozenBefore),
"frozenAfter": frozenAfter,
})
if _, err := tx.Exec(ctx, `
INSERT INTO gateway_wallet_transactions (
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
)
VALUES (
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'credit', 'release',
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
)
ON CONFLICT (account_id, idempotency_key) WHERE idempotency_key IS NOT NULL DO NOTHING`,
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, reservedAmount, roundMoney(balanceBefore), roundMoney(balanceBefore), billingReservationReleaseIdempotencyKey(reservationKey), task.ID, string(releaseMetadata)); err != nil {
return err
}
}
billingMetadata := mergeObjects(metadataMap, map[string]any{
"reservedAmount": reservedAmount,
"frozenBefore": roundMoney(frozenBefore),
"frozenAfter": frozenAfter,
})
metadata, _ = json.Marshal(billingMetadata)
if _, err := tx.Exec(ctx, `
INSERT INTO gateway_wallet_transactions (
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
@ -750,11 +794,10 @@ VALUES (
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing',
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
)`,
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata))
if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" {
return nil
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata)); err != nil {
return err
}
return err
return nil
})
}

View File

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
@ -92,6 +93,16 @@ type WalletAdjustmentResult struct {
Transaction GatewayWalletTransaction `json:"transaction"`
}
type WalletBillingReservation struct {
TaskID string `json:"taskId"`
AccountID string `json:"accountId"`
GatewayUserID string `json:"gatewayUserId"`
GatewayTenantID string `json:"gatewayTenantId,omitempty"`
Currency string `json:"currency"`
Amount float64 `json:"amount"`
IdempotencyKey string `json:"idempotencyKey"`
}
func (s *Store) WalletAvailability(ctx context.Context, user *auth.User, currency string, requiredAmount float64) (WalletAvailability, error) {
gatewayUserID := localGatewayUserID(user)
if gatewayUserID == "" {
@ -115,6 +126,223 @@ func (s *Store) WalletAvailability(ctx context.Context, user *auth.User, currenc
return result, nil
}
func (s *Store) ReserveTaskBilling(ctx context.Context, task GatewayTask, user *auth.User, billings []any) ([]WalletBillingReservation, error) {
gatewayUserID := taskGatewayUserID(task, user)
if gatewayUserID == "" {
return nil, nil
}
taskID := strings.TrimSpace(task.ID)
if taskID == "" {
return nil, fmt.Errorf("task id is required for wallet reservation")
}
amounts := walletBillingAmounts(billings)
if len(amounts) == 0 {
return nil, nil
}
reservations := make([]WalletBillingReservation, 0, len(amounts))
err := pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
for currency, rawAmount := range amounts {
amount := roundMoney(rawAmount)
if amount <= 0 {
continue
}
account, err := s.ensureWalletAccount(ctx, tx, gatewayUserID, currency)
if err != nil {
return err
}
locked, err := lockWalletAccount(ctx, tx, account.ID)
if err != nil {
return err
}
activeKey, activeAmount, err := activeWalletReservation(ctx, tx, locked.ID, taskID)
if err != nil {
return err
}
if activeAmount > 0 {
reservation := WalletBillingReservation{
TaskID: taskID,
AccountID: locked.ID,
GatewayUserID: gatewayUserID,
GatewayTenantID: firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID),
Currency: locked.Currency,
Amount: activeAmount,
IdempotencyKey: activeKey,
}
reservations = append(reservations, reservation)
continue
}
sequence, err := nextWalletReservationSequence(ctx, tx, locked.ID, taskID)
if err != nil {
return err
}
key := billingReservationIdempotencyKey(taskID, locked.Currency, sequence)
reservation := WalletBillingReservation{
TaskID: taskID,
AccountID: locked.ID,
GatewayUserID: gatewayUserID,
GatewayTenantID: firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID),
Currency: locked.Currency,
Amount: amount,
IdempotencyKey: key,
}
available := roundMoney(locked.Balance - locked.FrozenBalance)
if available+0.000001 < amount {
return fmt.Errorf("%w: required %.6f %s, available %.6f", ErrInsufficientWalletBalance, amount, locked.Currency, available)
}
frozenAfter := roundMoney(locked.FrozenBalance + amount)
if _, err := tx.Exec(ctx, `
UPDATE gateway_wallet_accounts
SET frozen_balance = $2,
updated_at = now()
WHERE id = $1::uuid`, locked.ID, frozenAfter); err != nil {
return err
}
metadata, _ := json.Marshal(map[string]any{
"taskId": taskID,
"kind": task.Kind,
"model": task.Model,
"reserved": amount,
"balance": roundMoney(locked.Balance),
"frozenBefore": roundMoney(locked.FrozenBalance),
"frozenAfter": frozenAfter,
})
if _, err := tx.Exec(ctx, `
INSERT INTO gateway_wallet_transactions (
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
)
VALUES (
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'reserve',
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
)`,
locked.ID,
firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID),
gatewayUserID,
amount,
roundMoney(locked.Balance),
roundMoney(locked.Balance),
key,
taskID,
string(metadata),
); err != nil {
return err
}
reservations = append(reservations, reservation)
}
return nil
})
if err != nil {
return nil, err
}
return reservations, err
}
func (s *Store) ReleaseTaskBillingReservations(ctx context.Context, reservations []WalletBillingReservation, reason string) error {
if len(reservations) == 0 {
return nil
}
reason = strings.TrimSpace(reason)
if reason == "" {
reason = "task_not_settled"
}
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
for _, reservation := range reservations {
if reservation.Amount <= 0 || strings.TrimSpace(reservation.AccountID) == "" {
continue
}
reserveKey := strings.TrimSpace(reservation.IdempotencyKey)
if reserveKey == "" {
reserveKey = billingReservationIdempotencyKey(reservation.TaskID, reservation.Currency, 1)
}
releaseKey := billingReservationReleaseIdempotencyKey(reserveKey)
locked, err := lockWalletAccount(ctx, tx, reservation.AccountID)
if err != nil {
if err == pgx.ErrNoRows {
continue
}
return err
}
var alreadyReleased bool
if err := tx.QueryRow(ctx, `
SELECT EXISTS (
SELECT 1
FROM gateway_wallet_transactions
WHERE account_id = $1::uuid
AND idempotency_key = $2
)`, reservation.AccountID, releaseKey).Scan(&alreadyReleased); err != nil {
return err
}
if alreadyReleased {
continue
}
var storedReservedAmount float64
if err := tx.QueryRow(ctx, `
SELECT COALESCE((
SELECT amount::float8
FROM gateway_wallet_transactions
WHERE account_id = $1::uuid
AND idempotency_key = $2
AND transaction_type = 'reserve'
LIMIT 1
), 0)::float8`, reservation.AccountID, reserveKey).Scan(&storedReservedAmount); err != nil {
return err
}
if storedReservedAmount <= 0 {
continue
}
amount := roundMoney(storedReservedAmount)
frozenAfter := roundMoney(locked.FrozenBalance - amount)
if frozenAfter < 0 {
frozenAfter = 0
}
if _, err := tx.Exec(ctx, `
UPDATE gateway_wallet_accounts
SET frozen_balance = $2,
updated_at = now()
WHERE id = $1::uuid`, locked.ID, frozenAfter); err != nil {
return err
}
metadata, _ := json.Marshal(map[string]any{
"taskId": reservation.TaskID,
"reason": reason,
"reserved": amount,
"frozenBefore": roundMoney(locked.FrozenBalance),
"frozenAfter": frozenAfter,
})
if _, err := tx.Exec(ctx, `
INSERT INTO gateway_wallet_transactions (
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
)
VALUES (
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'credit', 'release',
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
)
ON CONFLICT (account_id, idempotency_key) WHERE idempotency_key IS NOT NULL DO NOTHING`,
locked.ID,
locked.GatewayTenantID,
locked.GatewayUserID,
amount,
roundMoney(locked.Balance),
roundMoney(locked.Balance),
releaseKey,
reservation.TaskID,
string(metadata),
); err != nil {
return err
}
}
return nil
})
}
func (s *Store) GetWalletSummary(ctx context.Context, user *auth.User, currency string) (WalletSummary, error) {
gatewayUserID := localGatewayUserID(user)
if gatewayUserID == "" {
@ -465,6 +693,124 @@ WHERE gateway_user_id = $1::uuid
return account, nil
}
func lockWalletAccount(ctx context.Context, tx pgx.Tx, accountID string) (GatewayWalletAccount, error) {
return scanWalletAccount(tx.QueryRow(ctx, `
SELECT id::text, COALESCE(gateway_tenant_id::text, ''), gateway_user_id::text,
COALESCE(tenant_id, ''), COALESCE(tenant_key, ''), COALESCE(user_id, ''),
currency, balance::float8, frozen_balance::float8, total_recharged::float8,
total_spent::float8, status, metadata, created_at, updated_at
FROM gateway_wallet_accounts
WHERE id = $1::uuid
FOR UPDATE`, accountID))
}
func activeWalletReservation(ctx context.Context, tx pgx.Tx, accountID string, taskID string) (string, float64, error) {
var key string
var amount float64
err := tx.QueryRow(ctx, `
SELECT COALESCE(t.idempotency_key, ''), t.amount::float8
FROM gateway_wallet_transactions t
WHERE t.account_id = $1::uuid
AND t.reference_type = 'gateway_task'
AND t.reference_id = $2
AND t.transaction_type = 'reserve'
AND COALESCE(t.idempotency_key, '') <> ''
AND NOT EXISTS (
SELECT 1
FROM gateway_wallet_transactions r
WHERE r.account_id = t.account_id
AND r.transaction_type = 'release'
AND r.idempotency_key = t.idempotency_key || ':release'
)
ORDER BY t.created_at DESC
LIMIT 1`, accountID, taskID).Scan(&key, &amount)
if err == pgx.ErrNoRows {
return "", 0, nil
}
if err != nil {
return "", 0, err
}
return key, roundMoney(amount), nil
}
func nextWalletReservationSequence(ctx context.Context, tx pgx.Tx, accountID string, taskID string) (int, error) {
var count int
if err := tx.QueryRow(ctx, `
SELECT COUNT(*)::int
FROM gateway_wallet_transactions
WHERE account_id = $1::uuid
AND reference_type = 'gateway_task'
AND reference_id = $2
AND transaction_type = 'reserve'`, accountID, taskID).Scan(&count); err != nil {
return 0, err
}
return count + 1, nil
}
func walletBillingAmounts(billings []any) map[string]float64 {
amounts := map[string]float64{}
for _, raw := range billings {
line, _ := raw.(map[string]any)
if line == nil {
continue
}
amount := roundMoney(walletFloat(line["amount"]))
if amount <= 0 {
continue
}
currency := normalizeWalletCurrency(walletString(line["currency"]))
amounts[currency] = roundMoney(amounts[currency] + amount)
}
return amounts
}
func taskGatewayUserID(task GatewayTask, user *auth.User) string {
return firstNonEmpty(strings.TrimSpace(task.GatewayUserID), localGatewayUserID(user))
}
func billingReservationIdempotencyKey(taskID string, currency string, sequence int) string {
if sequence <= 0 {
sequence = 1
}
return "task:" + strings.TrimSpace(taskID) + ":wallet-reservation:" + normalizeWalletCurrency(currency) + ":" + strconv.Itoa(sequence)
}
func billingReservationReleaseIdempotencyKey(reservationKey string) string {
return strings.TrimSpace(reservationKey) + ":release"
}
func walletString(value any) string {
if text, ok := value.(string); ok {
return strings.TrimSpace(text)
}
return ""
}
func walletFloat(value any) float64 {
switch typed := value.(type) {
case float64:
return typed
case float32:
return float64(typed)
case int:
return float64(typed)
case int64:
return float64(typed)
case json.Number:
next, _ := typed.Float64()
return next
case string:
next := strings.TrimSpace(typed)
if next == "" {
return 0
}
parsed, _ := strconv.ParseFloat(next, 64)
return parsed
default:
return 0
}
}
func normalizeWalletCurrency(currency string) string {
currency = strings.TrimSpace(currency)
if currency == "" {

View File

@ -0,0 +1,171 @@
package store
import (
"context"
"errors"
"os"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
)
func TestReserveTaskBillingSerializesConcurrentWalletReservations(t *testing.T) {
databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL"))
if databaseURL == "" {
t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet reservation integration test")
}
ctx := context.Background()
db, err := Connect(ctx, databaseURL)
if err != nil {
t.Fatalf("connect store: %v", err)
}
defer db.Close()
tenantID, userID := seedWalletReservationUser(t, ctx, db)
if _, err := db.SetUserWalletBalance(ctx, WalletBalanceAdjustmentInput{
GatewayUserID: userID,
Currency: "resource",
Balance: 10,
Reason: "seed wallet reservation test",
}); err != nil {
t.Fatalf("seed wallet balance: %v", err)
}
firstTaskID := newWalletReservationTestUUID(t, ctx, db)
secondTaskID := newWalletReservationTestUUID(t, ctx, db)
billings := []any{map[string]any{"currency": "resource", "amount": float64(10)}}
user := &auth.User{GatewayUserID: userID, GatewayTenantID: tenantID}
tasks := []GatewayTask{
{ID: firstTaskID, GatewayUserID: userID, GatewayTenantID: tenantID, Kind: "images.generations", Model: "mock-image"},
{ID: secondTaskID, GatewayUserID: userID, GatewayTenantID: tenantID, Kind: "videos.generations", Model: "mock-video"},
}
type reserveResult struct {
reservations []WalletBillingReservation
err error
}
results := make(chan reserveResult, len(tasks))
var wg sync.WaitGroup
for _, task := range tasks {
task := task
wg.Add(1)
go func() {
defer wg.Done()
reservations, err := db.ReserveTaskBilling(ctx, task, user, billings)
results <- reserveResult{reservations: reservations, err: err}
}()
}
wg.Wait()
close(results)
var successReservations []WalletBillingReservation
successCount := 0
insufficientCount := 0
for result := range results {
if result.err == nil {
successCount++
successReservations = result.reservations
continue
}
if errors.Is(result.err, ErrInsufficientWalletBalance) {
insufficientCount++
continue
}
t.Fatalf("unexpected reservation error: %v", result.err)
}
if successCount != 1 || insufficientCount != 1 {
t.Fatalf("expected one successful reservation and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount)
}
if len(successReservations) != 1 || !walletFloatNear(successReservations[0].Amount, 10) {
t.Fatalf("unexpected successful reservations: %+v", successReservations)
}
balance, frozen, spent := readWalletReservationAccount(t, ctx, db, userID)
if !walletFloatNear(balance, 10) || !walletFloatNear(frozen, 10) || !walletFloatNear(spent, 0) {
t.Fatalf("reservation should freeze balance without spending it, balance=%f frozen=%f spent=%f", balance, frozen, spent)
}
settleTask := GatewayTask{
ID: successReservations[0].TaskID,
GatewayUserID: userID,
GatewayTenantID: tenantID,
Kind: "images.generations",
Model: "mock-image",
ResolvedModel: "mock-image",
Billings: billings,
BillingSummary: map[string]any{"currency": "resource", "totalAmount": float64(10)},
FinalChargeAmount: 10,
}
if err := db.SettleTaskBilling(ctx, settleTask); err != nil {
t.Fatalf("settle reserved task billing: %v", err)
}
if err := db.SettleTaskBilling(ctx, settleTask); err != nil {
t.Fatalf("settle reserved task billing should be idempotent: %v", err)
}
balance, frozen, spent = readWalletReservationAccount(t, ctx, db, userID)
if !walletFloatNear(balance, 0) || !walletFloatNear(frozen, 0) || !walletFloatNear(spent, 10) {
t.Fatalf("settlement should release reservation and debit once, balance=%f frozen=%f spent=%f", balance, frozen, spent)
}
}
func seedWalletReservationUser(t *testing.T, ctx context.Context, db *Store) (string, string) {
t.Helper()
suffix := strconv.FormatInt(time.Now().UnixNano(), 10)
var tenantID string
if err := db.pool.QueryRow(ctx, `
INSERT INTO gateway_tenants (tenant_key, name)
VALUES ($1, $2)
RETURNING id::text`, "wallet-reservation-"+suffix, "Wallet Reservation Test "+suffix).Scan(&tenantID); err != nil {
t.Fatalf("insert test tenant: %v", err)
}
var userID string
if err := db.pool.QueryRow(ctx, `
INSERT INTO gateway_users (user_key, username, gateway_tenant_id, tenant_key, roles)
VALUES ($1, $2, $3::uuid, $4, '["basic"]'::jsonb)
RETURNING id::text`, "wallet-reservation-user-"+suffix, "wallet_reservation_"+suffix, tenantID, "wallet-reservation-"+suffix).Scan(&userID); err != nil {
t.Fatalf("insert test user: %v", err)
}
t.Cleanup(func() {
cleanupCtx := context.Background()
_, _ = db.pool.Exec(cleanupCtx, `DELETE FROM gateway_users WHERE id = $1::uuid`, userID)
_, _ = db.pool.Exec(cleanupCtx, `DELETE FROM gateway_tenants WHERE id = $1::uuid`, tenantID)
})
return tenantID, userID
}
func newWalletReservationTestUUID(t *testing.T, ctx context.Context, db *Store) string {
t.Helper()
var id string
if err := db.pool.QueryRow(ctx, `SELECT gen_random_uuid()::text`).Scan(&id); err != nil {
t.Fatalf("generate uuid: %v", err)
}
return id
}
func readWalletReservationAccount(t *testing.T, ctx context.Context, db *Store, userID string) (float64, float64, float64) {
t.Helper()
var balance float64
var frozen float64
var spent float64
if err := db.pool.QueryRow(ctx, `
SELECT balance::float8, frozen_balance::float8, total_spent::float8
FROM gateway_wallet_accounts
WHERE gateway_user_id = $1::uuid
AND currency = 'resource'`, userID).Scan(&balance, &frozen, &spent); err != nil {
t.Fatalf("read wallet account: %v", err)
}
return balance, frozen, spent
}
func walletFloatNear(a float64, b float64) bool {
delta := a - b
if delta < 0 {
delta = -delta
}
return delta < 0.000001
}

View File

@ -684,11 +684,13 @@ export interface VideoGenerationContent {
};
video_url?: {
url: string;
mime_type?: string;
refer_type?: 'feature' | 'base';
keep_original_sound?: 'yes' | 'no';
};
audio_url?: {
url: string;
mime_type?: string;
};
role?: VideoGenerationContentRole;
shot_index?: number;

View File

@ -32,8 +32,8 @@ export interface PlaygroundUpload {
export type OpenAIChatContentPart =
| { type: 'text'; text: string }
| { type: 'image_url'; image_url: { url: string } }
| { type: 'video_url'; video_url: { url: string } }
| { type: 'audio_url'; audio_url: { url: string } }
| { type: 'video_url'; video_url: { mime_type?: string; url: string } }
| { type: 'audio_url'; audio_url: { mime_type?: string; url: string } }
| { type: 'file_url'; file_url: { filename: string; url: string } };
export const mediaUploadAccept = 'image/*,video/*,audio/*';
@ -518,11 +518,17 @@ export function openAIContentFromPromptAndUploads(prompt: string, uploads: Playg
function openAIContentPartFromUpload(item: PlaygroundUpload): OpenAIChatContentPart | undefined {
if (!item.url) return undefined;
if (item.kind === 'image') return { type: 'image_url', image_url: { url: item.url } };
if (item.kind === 'video') return { type: 'video_url', video_url: { url: item.url } };
if (item.kind === 'audio') return { type: 'audio_url', audio_url: { url: item.url } };
if (item.kind === 'video') return { type: 'video_url', video_url: mediaURLPayload(item) };
if (item.kind === 'audio') return { type: 'audio_url', audio_url: mediaURLPayload(item) };
return { type: 'file_url', file_url: { filename: item.name, url: item.url } };
}
function mediaURLPayload(item: PlaygroundUpload) {
const payload: { mime_type?: string; url: string } = { url: item.url };
if (item.contentType) payload.mime_type = item.contentType;
return payload;
}
export function mediaUploadRequestPayload(uploads: PlaygroundUpload[], mode: Exclude<PlaygroundMode, 'chat'>) {
const images = uploads.filter((item) => item.kind === 'image').map((item) => item.url);
const payload: Record<string, string | string[]> = {};
@ -570,10 +576,10 @@ function videoGenerationContentFromUpload(item: PlaygroundUpload): VideoGenerati
return { type: 'image_url', role: 'reference_image', image_url: { url: item.url } };
}
if (item.kind === 'video') {
return { type: 'video_url', role: 'reference_video', video_url: { url: item.url, refer_type: 'feature' } };
return { type: 'video_url', role: 'reference_video', video_url: { ...mediaURLPayload(item), refer_type: 'feature' } };
}
if (item.kind === 'audio') {
return { type: 'audio_url', role: 'reference_audio', audio_url: { url: item.url } };
return { type: 'audio_url', role: 'reference_audio', audio_url: mediaURLPayload(item) };
}
return undefined;
}