diff --git a/apps/api/internal/clients/clients_test.go b/apps/api/internal/clients/clients_test.go index 630a573..c521227 100644 --- a/apps/api/internal/clients/clients_test.go +++ b/apps/api/internal/clients/clients_test.go @@ -569,6 +569,70 @@ func TestGeminiClientChatContract(t *testing.T) { } } +func TestGeminiClientChatConvertsMediaContentParts(t *testing.T) { + var captured map[string]any + var gotPath string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + if err := json.NewDecoder(r.Body).Decode(&captured); err != nil { + t.Fatalf("decode request: %v", err) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "candidates": []any{map[string]any{ + "content": map[string]any{"parts": []any{map[string]any{"text": "video ok"}}}, + }}, + }) + })) + defer server.Close() + + _, err := (GeminiClient{HTTPClient: server.Client()}).Run(context.Background(), Request{ + Kind: "chat.completions", + Model: "gemini:gemini-2.5-flash", + Body: map[string]any{ + "model": "gemini:gemini-2.5-flash", + "messages": []any{map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "analyze this video"}, + map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://cdn.example.com/input.mov", "mime_type": "video/quicktime"}}, + map[string]any{"type": "audio_url", "audio_url": map[string]any{"url": "data:audio/wav;base64,UklGRg=="}}, + }, + }}, + }, + Candidate: store.RuntimeModelCandidate{ + BaseURL: server.URL + "/v1beta/openai", + ProviderModelName: "gemini-2.5-flash", + ModelType: "chat", + Credentials: map[string]any{"apiKey": "gemini-key"}, + }, + }) + if err != nil { + t.Fatalf("run gemini client: %v", err) + } + if gotPath != "/v1beta/models/gemini-2.5-flash:generateContent" { + t.Fatalf("Gemini OpenAI-compatible base URL should normalize to native endpoint, got %s", gotPath) + } + contents, _ := captured["contents"].([]any) + if len(contents) != 1 { + t.Fatalf("unexpected Gemini contents: %+v", captured) + } + turn, _ := contents[0].(map[string]any) + parts, _ := turn["parts"].([]any) + if len(parts) != 3 { + t.Fatalf("expected text, video, and audio parts, got %+v", turn) + } + video, _ := parts[1].(map[string]any) + videoFile, _ := video["fileData"].(map[string]any) + if videoFile["fileUri"] != "https://cdn.example.com/input.mov" || videoFile["mimeType"] != "video/quicktime" { + t.Fatalf("video_url should become Gemini fileData, got %+v", video) + } + audio, _ := parts[2].(map[string]any) + audioInline, _ := audio["inlineData"].(map[string]any) + if audioInline["mimeType"] != "audio/wav" || audioInline["data"] != "UklGRg==" { + t.Fatalf("audio data URL should become Gemini inlineData, got %+v", audio) + } +} + func TestGeminiClientChatRestoresToolContext(t *testing.T) { var captured map[string]any server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/apps/api/internal/clients/gemini.go b/apps/api/internal/clients/gemini.go index 802797a..2649bde 100644 --- a/apps/api/internal/clients/gemini.go +++ b/apps/api/internal/clients/gemini.go @@ -5,8 +5,10 @@ import ( "context" "encoding/json" "fmt" + "mime" "net/http" "net/url" + "path" "strings" "time" ) @@ -58,6 +60,7 @@ func geminiURL(baseURL string, model string, apiKey string) string { if base == "" { base = "https://generativelanguage.googleapis.com" } + base = strings.TrimSuffix(base, "/openai") if strings.HasSuffix(base, "/v1beta") { base = strings.TrimSuffix(base, "/v1beta") } @@ -121,7 +124,7 @@ func geminiContentsFromMessages(body map[string]any) []any { }) continue } - parts := geminiTextParts(message["content"]) + parts := geminiContentParts(message["content"]) if role == "assistant" { for _, rawToolCall := range toolCallsSlice(message["tool_calls"]) { toolCall, _ := rawToolCall.(map[string]any) @@ -157,7 +160,7 @@ func geminiRole(role string) string { return "user" } -func geminiTextParts(content any) []any { +func geminiContentParts(content any) []any { parts := make([]any, 0) switch typed := content.(type) { case string: @@ -167,14 +170,146 @@ func geminiTextParts(content any) []any { case []any: for _, rawPart := range typed { part, _ := rawPart.(map[string]any) - if text := stringFromAny(firstPresent(part["text"], part["content"])); strings.TrimSpace(text) != "" { - parts = append(parts, map[string]any{"text": text}) + if len(part) == 0 { + continue + } + switch stringFromAny(part["type"]) { + case "text": + if text := strings.TrimSpace(stringFromAny(firstPresent(part["text"], part["content"]))); text != "" { + parts = append(parts, map[string]any{"text": text}) + } + case "image_url": + if media := geminiMediaPart(part, "image_url", "image"); media != nil { + parts = append(parts, media) + } + case "video_url": + if media := geminiMediaPart(part, "video_url", "video"); media != nil { + parts = append(parts, media) + } + case "audio_url": + if media := geminiMediaPart(part, "audio_url", "audio"); media != nil { + parts = append(parts, media) + } + case "input_audio": + if media := geminiInputAudioPart(part); media != nil { + parts = append(parts, media) + } + default: + if text := strings.TrimSpace(stringFromAny(firstPresent(part["text"], part["content"]))); text != "" { + parts = append(parts, map[string]any{"text": text}) + } } } } return parts } +func geminiMediaPart(part map[string]any, key string, mediaType string) map[string]any { + nested := mapFromAny(part[key]) + uri := firstNonEmptyString(nested["url"], part["url"], part[key]) + if uri == "" { + return nil + } + mimeType := firstNonEmptyString(nested["mime_type"], nested["mimeType"], part["mime_type"], part["mimeType"]) + return geminiMediaURLPart(uri, mimeType, mediaType) +} + +func geminiInputAudioPart(part map[string]any) map[string]any { + audio := mapFromAny(part["input_audio"]) + uri := firstNonEmptyString(audio["data"], audio["url"]) + if uri == "" { + return nil + } + mimeType := firstNonEmptyString(audio["mime_type"], audio["mimeType"]) + if mimeType == "" { + format := strings.ToLower(strings.TrimPrefix(stringFromAny(audio["format"]), ".")) + if strings.Contains(format, "/") { + mimeType = format + } else if format == "mp3" { + mimeType = "audio/mpeg" + } else if format != "" { + mimeType = "audio/" + format + } + } + return geminiMediaURLPart(uri, mimeType, "audio") +} + +func geminiMediaURLPart(uri string, explicitMimeType string, mediaType string) map[string]any { + if parsed := geminiDataURL(uri); parsed != nil { + return map[string]any{"inlineData": map[string]any{ + "mimeType": geminiMediaMime(firstNonEmptyString(explicitMimeType, parsed.mimeType), mediaType), + "data": parsed.data, + }} + } + return map[string]any{"fileData": map[string]any{ + "fileUri": uri, + "mimeType": geminiMediaMime(firstNonEmptyString(explicitMimeType, mimeFromURI(uri)), mediaType), + }} +} + +type geminiParsedDataURL struct { + mimeType string + data string +} + +func geminiDataURL(value string) *geminiParsedDataURL { + if !strings.HasPrefix(value, "data:") { + return nil + } + prefix, data, ok := strings.Cut(value, ",") + if !ok || !strings.Contains(prefix, ";base64") { + return nil + } + mimeType := strings.TrimPrefix(strings.Split(prefix, ";")[0], "data:") + if mimeType == "" { + mimeType = "application/octet-stream" + } + return &geminiParsedDataURL{mimeType: mimeType, data: data} +} + +func mimeFromURI(value string) string { + pathValue := value + if parsed, err := url.Parse(value); err == nil && parsed.Path != "" { + pathValue = parsed.Path + } + extension := strings.ToLower(path.Ext(pathValue)) + if extension == "" { + return "" + } + return mime.TypeByExtension(extension) +} + +func geminiMediaMime(mimeType string, mediaType string) string { + normalized := strings.ToLower(strings.TrimSpace(strings.Split(mimeType, ";")[0])) + switch mediaType { + case "image": + if strings.HasPrefix(normalized, "image/") && normalized != "image/svg+xml" { + return normalized + } + return "image/png" + case "video": + switch normalized { + case "video/x-msvideo": + return "video/avi" + case "video/quicktime", "video/mpeg", "video/mp4", "video/avi", "video/x-flv", "video/mpg", "video/webm", "video/wmv", "video/3gpp": + return normalized + default: + return "video/mp4" + } + case "audio": + switch normalized { + case "audio/x-wav", "audio/wave": + return "audio/wav" + case "audio/mpeg", "audio/mp3", "audio/wav", "audio/aiff", "audio/aac", "audio/ogg", "audio/flac", "audio/mp4", "audio/webm": + return normalized + default: + return "audio/mpeg" + } + default: + return "application/octet-stream" + } +} + func toolCallsSlice(value any) []any { switch typed := value.(type) { case []any: diff --git a/apps/api/internal/httpapi/core_flow_integration_test.go b/apps/api/internal/httpapi/core_flow_integration_test.go index e3a9995..914599e 100644 --- a/apps/api/internal/httpapi/core_flow_integration_test.go +++ b/apps/api/internal/httpapi/core_flow_integration_test.go @@ -129,6 +129,13 @@ func TestCoreLocalFlow(t *testing.T) { if _, err := testPool.Exec(ctx, `UPDATE gateway_users SET roles = '["admin"]'::jsonb WHERE username = $1`, username); err != nil { t.Fatalf("promote smoke user: %v", err) } + doJSON(t, server.URL, http.MethodPost, "/api/v1/auth/login", "", map[string]any{ + "account": username, + "password": password, + }, http.StatusOK, &loginResponse) + if loginResponse.AccessToken == "" { + t.Fatal("admin login did not return access token") + } var smokeGatewayUserID string if err := testPool.QueryRow(ctx, `SELECT id::text FROM gateway_users WHERE username = $1`, username).Scan(&smokeGatewayUserID); err != nil { t.Fatalf("read smoke gateway user id: %v", err) @@ -1402,14 +1409,41 @@ func applyMigration(t *testing.T, ctx context.Context, databaseURL string) { t.Fatalf("connect migration db: %v", err) } defer pool.Close() + if _, err := pool.Exec(ctx, ` +CREATE TABLE IF NOT EXISTS schema_migrations ( + version text PRIMARY KEY, + applied_at timestamptz NOT NULL DEFAULT now() +);`); err != nil { + t.Fatalf("ensure schema migrations: %v", err) + } for _, migrationPath := range migrationFiles { + version := strings.TrimSuffix(filepath.Base(migrationPath), filepath.Ext(migrationPath)) + var exists bool + if err := pool.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM schema_migrations WHERE version = $1)", version).Scan(&exists); err != nil { + t.Fatalf("check migration %s: %v", filepath.Base(migrationPath), err) + } + if exists { + continue + } migration, err := os.ReadFile(migrationPath) if err != nil { t.Fatalf("read migration %s: %v", filepath.Base(migrationPath), err) } - if _, err := pool.Exec(ctx, string(migration)); err != nil { + tx, err := pool.Begin(ctx) + if err != nil { + t.Fatalf("begin migration %s: %v", filepath.Base(migrationPath), err) + } + if _, err := tx.Exec(ctx, string(migration)); err != nil { + _ = tx.Rollback(ctx) t.Fatalf("apply migration %s: %v", filepath.Base(migrationPath), err) } + if _, err := tx.Exec(ctx, "INSERT INTO schema_migrations(version) VALUES($1)", version); err != nil { + _ = tx.Rollback(ctx) + t.Fatalf("record migration %s: %v", filepath.Base(migrationPath), err) + } + if err := tx.Commit(ctx); err != nil { + t.Fatalf("commit migration %s: %v", filepath.Base(migrationPath), err) + } } } diff --git a/apps/api/internal/runner/service.go b/apps/api/internal/runner/service.go index c626ca7..13e6b75 100644 --- a/apps/api/internal/runner/service.go +++ b/apps/api/internal/runner/service.go @@ -160,6 +160,13 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut normalizedModelType := modelType attemptNo := task.AttemptCount var firstPreprocessing parameterPreprocessingLog + var walletReservations []store.WalletBillingReservation + walletReservationFinalized := false + defer func() { + if !walletReservationFinalized && len(walletReservations) > 0 { + _ = s.store.ReleaseTaskBillingReservations(context.WithoutCancel(ctx), walletReservations, "task_not_settled") + } + }() if len(candidates) > 0 { preprocessing := s.preprocessRequestWithScripts(ctx, task.Kind, body, candidates[0]) firstCandidateBody = preprocessing.Body @@ -191,15 +198,17 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut return Result{}, err } estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, firstCandidateBody, candidates[0]) - if err := s.ensureWalletBalance(ctx, user, estimatedBillings); err != nil { - if errors.Is(err, store.ErrInsufficientWalletBalance) { + var reserveErr error + walletReservations, reserveErr = s.store.ReserveTaskBilling(ctx, task, user, estimatedBillings) + if reserveErr != nil { + if errors.Is(reserveErr, store.ErrInsufficientWalletBalance) { attemptNo = s.recordFailedAttempt(ctx, failedAttemptRecord{ Task: task, Body: firstCandidateBody, Candidate: &candidates[0], AttemptNo: attemptNo + 1, Code: "insufficient_balance", - Cause: err, + Cause: reserveErr, Simulated: task.RunMode == "simulation", Scope: "wallet_balance", Reason: "wallet_balance_check_failed", @@ -207,13 +216,13 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut Preprocessing: &firstPreprocessing, ModelType: normalizedModelType, }) - failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", err.Error(), task.RunMode == "simulation", err, parameterPreprocessingMetrics(firstPreprocessing)) + failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", reserveErr.Error(), task.RunMode == "simulation", reserveErr, parameterPreprocessingMetrics(firstPreprocessing)) if finishErr != nil { return Result{}, finishErr } - return Result{Task: failed, Output: failed.Result}, err + return Result{Task: failed, Output: failed.Result}, reserveErr } - return Result{}, err + return Result{}, reserveErr } } if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": normalizedModelType}, task.RunMode == "simulation"); err != nil { @@ -286,9 +295,18 @@ candidatesLoop: if finishErr != nil { return Result{}, finishErr } - if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil { - return Result{}, settleErr + if finished.FinalChargeAmount > 0 { + walletReservationFinalized = true + if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil { + return Result{}, settleErr + } + } else if len(walletReservations) > 0 { + if releaseErr := s.store.ReleaseTaskBillingReservations(ctx, walletReservations, "task_billing_zero"); releaseErr != nil { + return Result{}, releaseErr + } + walletReservationFinalized = true } + walletReservationFinalized = true if finished.FinalChargeAmount > 0 { if err := s.emit(ctx, task.ID, "task.billing.settled", "succeeded", "billing", 0.98, "task billing settled", map[string]any{ "amount": finished.FinalChargeAmount, @@ -695,6 +713,11 @@ func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated boo if key == "" { key = strings.ToLower(strings.TrimSpace(candidate.Provider)) } + provider := strings.ToLower(strings.TrimSpace(candidate.Provider)) + baseURL := strings.ToLower(strings.TrimSpace(candidate.BaseURL)) + if key == "google-gemini" || provider == "gemini" || provider == "google-gemini" || provider == "gemini-openai" || strings.Contains(baseURL, "generativelanguage.googleapis.com") { + key = "gemini" + } if client, ok := s.clients[key]; ok { return client } diff --git a/apps/api/internal/runner/service_test.go b/apps/api/internal/runner/service_test.go new file mode 100644 index 0000000..b61e1b3 --- /dev/null +++ b/apps/api/internal/runner/service_test.go @@ -0,0 +1,34 @@ +package runner + +import ( + "context" + "testing" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +type namedClient string + +func (namedClient) Run(context.Context, clients.Request) (clients.Response, error) { + return clients.Response{}, nil +} + +func TestClientForMapsGoogleGeminiSpecToGeminiClient(t *testing.T) { + service := &Service{clients: map[string]clients.Client{ + "gemini": namedClient("gemini"), + "openai": namedClient("openai"), + }} + + candidates := []store.RuntimeModelCandidate{ + {SpecType: "google-gemini"}, + {SpecType: "openai", Provider: "gemini-openai"}, + {SpecType: "openai", BaseURL: "https://generativelanguage.googleapis.com/v1beta/openai"}, + } + for _, candidate := range candidates { + client := service.clientFor(candidate, false) + if client != namedClient("gemini") { + t.Fatalf("Gemini candidate should use gemini client, candidate=%+v got %T %[2]v", candidate, client) + } + } +} diff --git a/apps/api/internal/runner/wallet.go b/apps/api/internal/runner/wallet.go deleted file mode 100644 index f8b5970..0000000 --- a/apps/api/internal/runner/wallet.go +++ /dev/null @@ -1,38 +0,0 @@ -package runner - -import ( - "context" - "fmt" - "strings" - - "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" - "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" -) - -func (s *Service) ensureWalletBalance(ctx context.Context, user *auth.User, billings []any) error { - amounts := map[string]float64{} - for _, raw := range billings { - line, _ := raw.(map[string]any) - if line == nil { - continue - } - currency := strings.TrimSpace(stringFromAny(line["currency"])) - if currency == "" { - currency = "resource" - } - amounts[currency] = roundPrice(amounts[currency] + floatFromAny(line["amount"])) - } - for currency, amount := range amounts { - if amount <= 0 { - continue - } - availability, err := s.store.WalletAvailability(ctx, user, currency, amount) - if err != nil { - return err - } - if !availability.Enough { - return fmt.Errorf("%w: required %.6f %s, available %.6f", store.ErrInsufficientWalletBalance, amount, currency, availability.AvailableAmount) - } - } - return nil -} diff --git a/apps/api/internal/runner/wallet_execute_test.go b/apps/api/internal/runner/wallet_execute_test.go new file mode 100644 index 0000000..d4d9460 --- /dev/null +++ b/apps/api/internal/runner/wallet_execute_test.go @@ -0,0 +1,202 @@ +package runner + +import ( + "context" + "errors" + "io" + "log/slog" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/config" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +type walletExecuteMockClient struct { + calls atomic.Int32 +} + +func (client *walletExecuteMockClient) Run(context.Context, clients.Request) (clients.Response, error) { + client.calls.Add(1) + return clients.Response{ + Result: map[string]any{"mock": true}, + RequestID: "mock-wallet-execute", + }, nil +} + +func TestExecuteWithMockClientRejectsConcurrentTasksBeyondWalletBalance(t *testing.T) { + databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL")) + if databaseURL == "" { + t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet execute integration test") + } + + ctx := context.Background() + db, err := store.Connect(ctx, databaseURL) + if err != nil { + t.Fatalf("connect store: %v", err) + } + t.Cleanup(db.Close) + + suffix := strconv.FormatInt(time.Now().UnixNano(), 10) + tenant, err := db.CreateTenant(ctx, store.GatewayTenantInput{ + TenantKey: "wallet-execute-" + suffix, + Name: "Wallet Execute Test " + suffix, + }) + if err != nil { + t.Fatalf("create tenant: %v", err) + } + t.Cleanup(func() { + _ = db.DeleteTenant(context.Background(), tenant.ID) + }) + + gatewayUser, err := db.CreateGatewayUser(ctx, store.GatewayUserInput{ + UserKey: "wallet-execute-user-" + suffix, + Username: "wallet_execute_" + suffix, + GatewayTenantID: tenant.ID, + TenantKey: tenant.TenantKey, + Roles: []string{"user"}, + }) + if err != nil { + t.Fatalf("create gateway user: %v", err) + } + t.Cleanup(func() { + _ = db.DeleteGatewayUser(context.Background(), gatewayUser.ID) + }) + + platform, err := db.CreatePlatform(ctx, store.CreatePlatformInput{ + Provider: "mock", + PlatformKey: "wallet-execute-mock-" + suffix, + Name: "Wallet Execute Mock " + suffix, + AuthType: "none", + Config: map[string]any{"specType": "mock"}, + Status: "enabled", + Priority: 1, + }) + if err != nil { + t.Fatalf("create mock platform: %v", err) + } + t.Cleanup(func() { + _ = db.DeletePlatform(context.Background(), platform.ID) + }) + + if _, err := db.CreatePlatformModel(ctx, store.CreatePlatformModelInput{ + PlatformID: platform.ID, + ModelName: "mock-wallet-image", + ProviderModelName: "mock-wallet-image", + ModelType: store.StringList{"image_generate"}, + DisplayName: "Mock Wallet Image", + BillingConfig: map[string]any{ + "image": map[string]any{"basePrice": 10}, + }, + }); err != nil { + t.Fatalf("create mock platform model: %v", err) + } + + user := &auth.User{ + ID: gatewayUser.ID, + Source: "gateway", + GatewayUserID: gatewayUser.ID, + GatewayTenantID: tenant.ID, + TenantKey: tenant.TenantKey, + Roles: gatewayUser.Roles, + } + if _, err := db.SetUserWalletBalance(ctx, store.WalletBalanceAdjustmentInput{ + GatewayUserID: gatewayUser.ID, + Currency: "resource", + Balance: 10, + Reason: "seed wallet execute test", + }); err != nil { + t.Fatalf("seed wallet balance: %v", err) + } + + tasks := make([]store.GatewayTask, 0, 2) + for i := 0; i < 2; i++ { + task, err := db.CreateTask(ctx, store.CreateTaskInput{ + Kind: "images.generations", + Model: "mock-wallet-image", + Request: map[string]any{ + "count": 1, + "prompt": "wallet execute test", + }, + }, user) + if err != nil { + t.Fatalf("create task: %v", err) + } + tasks = append(tasks, task) + } + + mockClient := &walletExecuteMockClient{} + service := New(config.Config{}, db, slog.New(slog.NewTextHandler(io.Discard, nil))) + service.clients["mock"] = mockClient + + type executeResult struct { + result Result + err error + } + results := make(chan executeResult, len(tasks)) + var wg sync.WaitGroup + for _, task := range tasks { + task := task + wg.Add(1) + go func() { + defer wg.Done() + result, err := service.Execute(ctx, task, user) + results <- executeResult{result: result, err: err} + }() + } + wg.Wait() + close(results) + + successCount := 0 + insufficientCount := 0 + for item := range results { + if item.err == nil { + successCount++ + if item.result.Task.Status != "succeeded" { + t.Fatalf("successful execution status = %s, want succeeded", item.result.Task.Status) + } + if !walletExecuteFloatNear(item.result.Task.FinalChargeAmount, 10) { + t.Fatalf("successful execution final charge = %f, want 10", item.result.Task.FinalChargeAmount) + } + continue + } + if errors.Is(item.err, store.ErrInsufficientWalletBalance) { + insufficientCount++ + if item.result.Task.Status != "failed" || item.result.Task.ErrorCode != "insufficient_balance" { + t.Fatalf("insufficient execution task = %+v", item.result.Task) + } + continue + } + t.Fatalf("unexpected execute error: %v", item.err) + } + if successCount != 1 || insufficientCount != 1 { + t.Fatalf("expected one successful mock execution and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount) + } + if got := mockClient.calls.Load(); got != 1 { + t.Fatalf("mock client calls = %d, want 1", got) + } + + summary, err := db.GetWalletSummary(ctx, user, "resource") + if err != nil { + t.Fatalf("get wallet summary: %v", err) + } + account := summary.PrimaryAccount + if !walletExecuteFloatNear(account.Balance, 0) || !walletExecuteFloatNear(account.FrozenBalance, 0) || !walletExecuteFloatNear(account.TotalSpent, 10) { + t.Fatalf("wallet after concurrent mock execution balance=%f frozen=%f spent=%f, want 0/0/10", account.Balance, account.FrozenBalance, account.TotalSpent) + } +} + +func walletExecuteFloatNear(a float64, b float64) bool { + delta := a - b + if delta < 0 { + delta = -delta + } + return delta < 0.000001 +} diff --git a/apps/api/internal/store/tasks_runtime.go b/apps/api/internal/store/tasks_runtime.go index 866d029..16bcaad 100644 --- a/apps/api/internal/store/tasks_runtime.go +++ b/apps/api/internal/store/tasks_runtime.go @@ -3,13 +3,13 @@ package store import ( "context" "encoding/json" + "fmt" "strconv" "strings" "time" "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" ) type TaskListFilter struct { @@ -687,14 +687,15 @@ func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error { if currency == "" || currency == "mixed" { currency = "resource" } - metadata, _ := json.Marshal(map[string]any{ + metadataMap := map[string]any{ "taskId": task.ID, "kind": task.Kind, "model": task.Model, "resolvedModel": task.ResolvedModel, "billings": task.Billings, "billingSummary": task.BillingSummary, - }) + } + metadata, _ := json.Marshal(metadataMap) return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error { if _, err := tx.Exec(ctx, ` INSERT INTO gateway_wallet_accounts ( @@ -706,42 +707,85 @@ ON CONFLICT (gateway_user_id, currency) DO NOTHING`, return err } var exists bool + var accountID string + var balanceBefore float64 + var frozenBefore float64 + var gatewayTenantID string + if err := tx.QueryRow(ctx, ` +SELECT id::text, balance::float8, frozen_balance::float8, COALESCE(gateway_tenant_id::text, '') +FROM gateway_wallet_accounts +WHERE gateway_user_id = $1::uuid + AND currency = $2 +FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &frozenBefore, &gatewayTenantID); err != nil { + return err + } if err := tx.QueryRow(ctx, ` SELECT EXISTS ( SELECT 1 - FROM gateway_wallet_transactions t - JOIN gateway_wallet_accounts a ON a.id = t.account_id - WHERE a.gateway_user_id = $1::uuid - AND a.currency = $2 - AND t.idempotency_key = $3 -)`, task.GatewayUserID, currency, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil { + FROM gateway_wallet_transactions + WHERE account_id = $1::uuid + AND idempotency_key = $2 +)`, accountID, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil { return err } if exists { return nil } - var accountID string - var balanceBefore float64 - var gatewayTenantID string - if err := tx.QueryRow(ctx, ` -SELECT id::text, balance::float8, COALESCE(gateway_tenant_id::text, '') -FROM gateway_wallet_accounts -WHERE gateway_user_id = $1::uuid - AND currency = $2 -FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &gatewayTenantID); err != nil { + + amount := roundMoney(task.FinalChargeAmount) + reservationKey, reservedAmount, err := activeWalletReservation(ctx, tx, accountID, task.ID) + if err != nil { return err } - amount := roundMoney(task.FinalChargeAmount) + reservedAmount = roundMoney(reservedAmount) + spendableForTask := roundMoney(balanceBefore - frozenBefore + reservedAmount) + if spendableForTask+0.000001 < amount { + return fmt.Errorf("%w: required %.6f %s, available %.6f", ErrInsufficientWalletBalance, amount, currency, spendableForTask) + } + balanceAfter := roundMoney(balanceBefore - amount) + frozenAfter := roundMoney(frozenBefore - reservedAmount) + if frozenAfter < 0 { + frozenAfter = 0 + } if _, err := tx.Exec(ctx, ` UPDATE gateway_wallet_accounts SET balance = $2, total_spent = total_spent + $3, + frozen_balance = $4, updated_at = now() -WHERE id = $1::uuid`, accountID, balanceAfter, amount); err != nil { +WHERE id = $1::uuid`, accountID, balanceAfter, amount, frozenAfter); err != nil { return err } - _, err := tx.Exec(ctx, ` + if reservedAmount > 0 { + releaseMetadata, _ := json.Marshal(map[string]any{ + "taskId": task.ID, + "reason": "task_billing_settled", + "reserved": reservedAmount, + "frozenBefore": roundMoney(frozenBefore), + "frozenAfter": frozenAfter, + }) + if _, err := tx.Exec(ctx, ` +INSERT INTO gateway_wallet_transactions ( + account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type, + amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata +) +VALUES ( + $1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'credit', 'release', + $4, $5, $6, $7, 'gateway_task', $8, $9::jsonb +) +ON CONFLICT (account_id, idempotency_key) WHERE idempotency_key IS NOT NULL DO NOTHING`, + accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, reservedAmount, roundMoney(balanceBefore), roundMoney(balanceBefore), billingReservationReleaseIdempotencyKey(reservationKey), task.ID, string(releaseMetadata)); err != nil { + return err + } + } + billingMetadata := mergeObjects(metadataMap, map[string]any{ + "reservedAmount": reservedAmount, + "frozenBefore": roundMoney(frozenBefore), + "frozenAfter": frozenAfter, + }) + metadata, _ = json.Marshal(billingMetadata) + if _, err := tx.Exec(ctx, ` INSERT INTO gateway_wallet_transactions ( account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type, amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata @@ -750,11 +794,10 @@ VALUES ( $1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing', $4, $5, $6, $7, 'gateway_task', $8, $9::jsonb )`, - accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata)) - if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" { - return nil + accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata)); err != nil { + return err } - return err + return nil }) } diff --git a/apps/api/internal/store/wallet.go b/apps/api/internal/store/wallet.go index f5de0fe..6feb148 100644 --- a/apps/api/internal/store/wallet.go +++ b/apps/api/internal/store/wallet.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strconv" "strings" "time" @@ -92,6 +93,16 @@ type WalletAdjustmentResult struct { Transaction GatewayWalletTransaction `json:"transaction"` } +type WalletBillingReservation struct { + TaskID string `json:"taskId"` + AccountID string `json:"accountId"` + GatewayUserID string `json:"gatewayUserId"` + GatewayTenantID string `json:"gatewayTenantId,omitempty"` + Currency string `json:"currency"` + Amount float64 `json:"amount"` + IdempotencyKey string `json:"idempotencyKey"` +} + func (s *Store) WalletAvailability(ctx context.Context, user *auth.User, currency string, requiredAmount float64) (WalletAvailability, error) { gatewayUserID := localGatewayUserID(user) if gatewayUserID == "" { @@ -115,6 +126,223 @@ func (s *Store) WalletAvailability(ctx context.Context, user *auth.User, currenc return result, nil } +func (s *Store) ReserveTaskBilling(ctx context.Context, task GatewayTask, user *auth.User, billings []any) ([]WalletBillingReservation, error) { + gatewayUserID := taskGatewayUserID(task, user) + if gatewayUserID == "" { + return nil, nil + } + taskID := strings.TrimSpace(task.ID) + if taskID == "" { + return nil, fmt.Errorf("task id is required for wallet reservation") + } + + amounts := walletBillingAmounts(billings) + if len(amounts) == 0 { + return nil, nil + } + + reservations := make([]WalletBillingReservation, 0, len(amounts)) + err := pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error { + for currency, rawAmount := range amounts { + amount := roundMoney(rawAmount) + if amount <= 0 { + continue + } + + account, err := s.ensureWalletAccount(ctx, tx, gatewayUserID, currency) + if err != nil { + return err + } + + locked, err := lockWalletAccount(ctx, tx, account.ID) + if err != nil { + return err + } + activeKey, activeAmount, err := activeWalletReservation(ctx, tx, locked.ID, taskID) + if err != nil { + return err + } + if activeAmount > 0 { + reservation := WalletBillingReservation{ + TaskID: taskID, + AccountID: locked.ID, + GatewayUserID: gatewayUserID, + GatewayTenantID: firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID), + Currency: locked.Currency, + Amount: activeAmount, + IdempotencyKey: activeKey, + } + reservations = append(reservations, reservation) + continue + } + + sequence, err := nextWalletReservationSequence(ctx, tx, locked.ID, taskID) + if err != nil { + return err + } + key := billingReservationIdempotencyKey(taskID, locked.Currency, sequence) + reservation := WalletBillingReservation{ + TaskID: taskID, + AccountID: locked.ID, + GatewayUserID: gatewayUserID, + GatewayTenantID: firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID), + Currency: locked.Currency, + Amount: amount, + IdempotencyKey: key, + } + available := roundMoney(locked.Balance - locked.FrozenBalance) + if available+0.000001 < amount { + return fmt.Errorf("%w: required %.6f %s, available %.6f", ErrInsufficientWalletBalance, amount, locked.Currency, available) + } + + frozenAfter := roundMoney(locked.FrozenBalance + amount) + if _, err := tx.Exec(ctx, ` +UPDATE gateway_wallet_accounts +SET frozen_balance = $2, + updated_at = now() +WHERE id = $1::uuid`, locked.ID, frozenAfter); err != nil { + return err + } + metadata, _ := json.Marshal(map[string]any{ + "taskId": taskID, + "kind": task.Kind, + "model": task.Model, + "reserved": amount, + "balance": roundMoney(locked.Balance), + "frozenBefore": roundMoney(locked.FrozenBalance), + "frozenAfter": frozenAfter, + }) + if _, err := tx.Exec(ctx, ` +INSERT INTO gateway_wallet_transactions ( + account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type, + amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata +) +VALUES ( + $1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'reserve', + $4, $5, $6, $7, 'gateway_task', $8, $9::jsonb +)`, + locked.ID, + firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID), + gatewayUserID, + amount, + roundMoney(locked.Balance), + roundMoney(locked.Balance), + key, + taskID, + string(metadata), + ); err != nil { + return err + } + reservations = append(reservations, reservation) + } + return nil + }) + if err != nil { + return nil, err + } + return reservations, err +} + +func (s *Store) ReleaseTaskBillingReservations(ctx context.Context, reservations []WalletBillingReservation, reason string) error { + if len(reservations) == 0 { + return nil + } + reason = strings.TrimSpace(reason) + if reason == "" { + reason = "task_not_settled" + } + return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error { + for _, reservation := range reservations { + if reservation.Amount <= 0 || strings.TrimSpace(reservation.AccountID) == "" { + continue + } + reserveKey := strings.TrimSpace(reservation.IdempotencyKey) + if reserveKey == "" { + reserveKey = billingReservationIdempotencyKey(reservation.TaskID, reservation.Currency, 1) + } + releaseKey := billingReservationReleaseIdempotencyKey(reserveKey) + locked, err := lockWalletAccount(ctx, tx, reservation.AccountID) + if err != nil { + if err == pgx.ErrNoRows { + continue + } + return err + } + var alreadyReleased bool + if err := tx.QueryRow(ctx, ` +SELECT EXISTS ( + SELECT 1 + FROM gateway_wallet_transactions + WHERE account_id = $1::uuid + AND idempotency_key = $2 +)`, reservation.AccountID, releaseKey).Scan(&alreadyReleased); err != nil { + return err + } + if alreadyReleased { + continue + } + var storedReservedAmount float64 + if err := tx.QueryRow(ctx, ` +SELECT COALESCE(( + SELECT amount::float8 + FROM gateway_wallet_transactions + WHERE account_id = $1::uuid + AND idempotency_key = $2 + AND transaction_type = 'reserve' + LIMIT 1 +), 0)::float8`, reservation.AccountID, reserveKey).Scan(&storedReservedAmount); err != nil { + return err + } + if storedReservedAmount <= 0 { + continue + } + + amount := roundMoney(storedReservedAmount) + frozenAfter := roundMoney(locked.FrozenBalance - amount) + if frozenAfter < 0 { + frozenAfter = 0 + } + if _, err := tx.Exec(ctx, ` +UPDATE gateway_wallet_accounts +SET frozen_balance = $2, + updated_at = now() +WHERE id = $1::uuid`, locked.ID, frozenAfter); err != nil { + return err + } + metadata, _ := json.Marshal(map[string]any{ + "taskId": reservation.TaskID, + "reason": reason, + "reserved": amount, + "frozenBefore": roundMoney(locked.FrozenBalance), + "frozenAfter": frozenAfter, + }) + if _, err := tx.Exec(ctx, ` +INSERT INTO gateway_wallet_transactions ( + account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type, + amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata +) +VALUES ( + $1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'credit', 'release', + $4, $5, $6, $7, 'gateway_task', $8, $9::jsonb +) +ON CONFLICT (account_id, idempotency_key) WHERE idempotency_key IS NOT NULL DO NOTHING`, + locked.ID, + locked.GatewayTenantID, + locked.GatewayUserID, + amount, + roundMoney(locked.Balance), + roundMoney(locked.Balance), + releaseKey, + reservation.TaskID, + string(metadata), + ); err != nil { + return err + } + } + return nil + }) +} + func (s *Store) GetWalletSummary(ctx context.Context, user *auth.User, currency string) (WalletSummary, error) { gatewayUserID := localGatewayUserID(user) if gatewayUserID == "" { @@ -465,6 +693,124 @@ WHERE gateway_user_id = $1::uuid return account, nil } +func lockWalletAccount(ctx context.Context, tx pgx.Tx, accountID string) (GatewayWalletAccount, error) { + return scanWalletAccount(tx.QueryRow(ctx, ` +SELECT id::text, COALESCE(gateway_tenant_id::text, ''), gateway_user_id::text, + COALESCE(tenant_id, ''), COALESCE(tenant_key, ''), COALESCE(user_id, ''), + currency, balance::float8, frozen_balance::float8, total_recharged::float8, + total_spent::float8, status, metadata, created_at, updated_at +FROM gateway_wallet_accounts +WHERE id = $1::uuid +FOR UPDATE`, accountID)) +} + +func activeWalletReservation(ctx context.Context, tx pgx.Tx, accountID string, taskID string) (string, float64, error) { + var key string + var amount float64 + err := tx.QueryRow(ctx, ` +SELECT COALESCE(t.idempotency_key, ''), t.amount::float8 +FROM gateway_wallet_transactions t +WHERE t.account_id = $1::uuid + AND t.reference_type = 'gateway_task' + AND t.reference_id = $2 + AND t.transaction_type = 'reserve' + AND COALESCE(t.idempotency_key, '') <> '' + AND NOT EXISTS ( + SELECT 1 + FROM gateway_wallet_transactions r + WHERE r.account_id = t.account_id + AND r.transaction_type = 'release' + AND r.idempotency_key = t.idempotency_key || ':release' + ) +ORDER BY t.created_at DESC +LIMIT 1`, accountID, taskID).Scan(&key, &amount) + if err == pgx.ErrNoRows { + return "", 0, nil + } + if err != nil { + return "", 0, err + } + return key, roundMoney(amount), nil +} + +func nextWalletReservationSequence(ctx context.Context, tx pgx.Tx, accountID string, taskID string) (int, error) { + var count int + if err := tx.QueryRow(ctx, ` +SELECT COUNT(*)::int +FROM gateway_wallet_transactions +WHERE account_id = $1::uuid + AND reference_type = 'gateway_task' + AND reference_id = $2 + AND transaction_type = 'reserve'`, accountID, taskID).Scan(&count); err != nil { + return 0, err + } + return count + 1, nil +} + +func walletBillingAmounts(billings []any) map[string]float64 { + amounts := map[string]float64{} + for _, raw := range billings { + line, _ := raw.(map[string]any) + if line == nil { + continue + } + amount := roundMoney(walletFloat(line["amount"])) + if amount <= 0 { + continue + } + currency := normalizeWalletCurrency(walletString(line["currency"])) + amounts[currency] = roundMoney(amounts[currency] + amount) + } + return amounts +} + +func taskGatewayUserID(task GatewayTask, user *auth.User) string { + return firstNonEmpty(strings.TrimSpace(task.GatewayUserID), localGatewayUserID(user)) +} + +func billingReservationIdempotencyKey(taskID string, currency string, sequence int) string { + if sequence <= 0 { + sequence = 1 + } + return "task:" + strings.TrimSpace(taskID) + ":wallet-reservation:" + normalizeWalletCurrency(currency) + ":" + strconv.Itoa(sequence) +} + +func billingReservationReleaseIdempotencyKey(reservationKey string) string { + return strings.TrimSpace(reservationKey) + ":release" +} + +func walletString(value any) string { + if text, ok := value.(string); ok { + return strings.TrimSpace(text) + } + return "" +} + +func walletFloat(value any) float64 { + switch typed := value.(type) { + case float64: + return typed + case float32: + return float64(typed) + case int: + return float64(typed) + case int64: + return float64(typed) + case json.Number: + next, _ := typed.Float64() + return next + case string: + next := strings.TrimSpace(typed) + if next == "" { + return 0 + } + parsed, _ := strconv.ParseFloat(next, 64) + return parsed + default: + return 0 + } +} + func normalizeWalletCurrency(currency string) string { currency = strings.TrimSpace(currency) if currency == "" { diff --git a/apps/api/internal/store/wallet_reservation_test.go b/apps/api/internal/store/wallet_reservation_test.go new file mode 100644 index 0000000..79497ef --- /dev/null +++ b/apps/api/internal/store/wallet_reservation_test.go @@ -0,0 +1,171 @@ +package store + +import ( + "context" + "errors" + "os" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" +) + +func TestReserveTaskBillingSerializesConcurrentWalletReservations(t *testing.T) { + databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL")) + if databaseURL == "" { + t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet reservation integration test") + } + + ctx := context.Background() + db, err := Connect(ctx, databaseURL) + if err != nil { + t.Fatalf("connect store: %v", err) + } + defer db.Close() + + tenantID, userID := seedWalletReservationUser(t, ctx, db) + if _, err := db.SetUserWalletBalance(ctx, WalletBalanceAdjustmentInput{ + GatewayUserID: userID, + Currency: "resource", + Balance: 10, + Reason: "seed wallet reservation test", + }); err != nil { + t.Fatalf("seed wallet balance: %v", err) + } + + firstTaskID := newWalletReservationTestUUID(t, ctx, db) + secondTaskID := newWalletReservationTestUUID(t, ctx, db) + billings := []any{map[string]any{"currency": "resource", "amount": float64(10)}} + user := &auth.User{GatewayUserID: userID, GatewayTenantID: tenantID} + tasks := []GatewayTask{ + {ID: firstTaskID, GatewayUserID: userID, GatewayTenantID: tenantID, Kind: "images.generations", Model: "mock-image"}, + {ID: secondTaskID, GatewayUserID: userID, GatewayTenantID: tenantID, Kind: "videos.generations", Model: "mock-video"}, + } + + type reserveResult struct { + reservations []WalletBillingReservation + err error + } + results := make(chan reserveResult, len(tasks)) + var wg sync.WaitGroup + for _, task := range tasks { + task := task + wg.Add(1) + go func() { + defer wg.Done() + reservations, err := db.ReserveTaskBilling(ctx, task, user, billings) + results <- reserveResult{reservations: reservations, err: err} + }() + } + wg.Wait() + close(results) + + var successReservations []WalletBillingReservation + successCount := 0 + insufficientCount := 0 + for result := range results { + if result.err == nil { + successCount++ + successReservations = result.reservations + continue + } + if errors.Is(result.err, ErrInsufficientWalletBalance) { + insufficientCount++ + continue + } + t.Fatalf("unexpected reservation error: %v", result.err) + } + if successCount != 1 || insufficientCount != 1 { + t.Fatalf("expected one successful reservation and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount) + } + if len(successReservations) != 1 || !walletFloatNear(successReservations[0].Amount, 10) { + t.Fatalf("unexpected successful reservations: %+v", successReservations) + } + + balance, frozen, spent := readWalletReservationAccount(t, ctx, db, userID) + if !walletFloatNear(balance, 10) || !walletFloatNear(frozen, 10) || !walletFloatNear(spent, 0) { + t.Fatalf("reservation should freeze balance without spending it, balance=%f frozen=%f spent=%f", balance, frozen, spent) + } + + settleTask := GatewayTask{ + ID: successReservations[0].TaskID, + GatewayUserID: userID, + GatewayTenantID: tenantID, + Kind: "images.generations", + Model: "mock-image", + ResolvedModel: "mock-image", + Billings: billings, + BillingSummary: map[string]any{"currency": "resource", "totalAmount": float64(10)}, + FinalChargeAmount: 10, + } + if err := db.SettleTaskBilling(ctx, settleTask); err != nil { + t.Fatalf("settle reserved task billing: %v", err) + } + if err := db.SettleTaskBilling(ctx, settleTask); err != nil { + t.Fatalf("settle reserved task billing should be idempotent: %v", err) + } + balance, frozen, spent = readWalletReservationAccount(t, ctx, db, userID) + if !walletFloatNear(balance, 0) || !walletFloatNear(frozen, 0) || !walletFloatNear(spent, 10) { + t.Fatalf("settlement should release reservation and debit once, balance=%f frozen=%f spent=%f", balance, frozen, spent) + } +} + +func seedWalletReservationUser(t *testing.T, ctx context.Context, db *Store) (string, string) { + t.Helper() + suffix := strconv.FormatInt(time.Now().UnixNano(), 10) + var tenantID string + if err := db.pool.QueryRow(ctx, ` +INSERT INTO gateway_tenants (tenant_key, name) +VALUES ($1, $2) +RETURNING id::text`, "wallet-reservation-"+suffix, "Wallet Reservation Test "+suffix).Scan(&tenantID); err != nil { + t.Fatalf("insert test tenant: %v", err) + } + var userID string + if err := db.pool.QueryRow(ctx, ` +INSERT INTO gateway_users (user_key, username, gateway_tenant_id, tenant_key, roles) +VALUES ($1, $2, $3::uuid, $4, '["basic"]'::jsonb) +RETURNING id::text`, "wallet-reservation-user-"+suffix, "wallet_reservation_"+suffix, tenantID, "wallet-reservation-"+suffix).Scan(&userID); err != nil { + t.Fatalf("insert test user: %v", err) + } + t.Cleanup(func() { + cleanupCtx := context.Background() + _, _ = db.pool.Exec(cleanupCtx, `DELETE FROM gateway_users WHERE id = $1::uuid`, userID) + _, _ = db.pool.Exec(cleanupCtx, `DELETE FROM gateway_tenants WHERE id = $1::uuid`, tenantID) + }) + return tenantID, userID +} + +func newWalletReservationTestUUID(t *testing.T, ctx context.Context, db *Store) string { + t.Helper() + var id string + if err := db.pool.QueryRow(ctx, `SELECT gen_random_uuid()::text`).Scan(&id); err != nil { + t.Fatalf("generate uuid: %v", err) + } + return id +} + +func readWalletReservationAccount(t *testing.T, ctx context.Context, db *Store, userID string) (float64, float64, float64) { + t.Helper() + var balance float64 + var frozen float64 + var spent float64 + if err := db.pool.QueryRow(ctx, ` +SELECT balance::float8, frozen_balance::float8, total_spent::float8 +FROM gateway_wallet_accounts +WHERE gateway_user_id = $1::uuid + AND currency = 'resource'`, userID).Scan(&balance, &frozen, &spent); err != nil { + t.Fatalf("read wallet account: %v", err) + } + return balance, frozen, spent +} + +func walletFloatNear(a float64, b float64) bool { + delta := a - b + if delta < 0 { + delta = -delta + } + return delta < 0.000001 +} diff --git a/apps/web/src/api.ts b/apps/web/src/api.ts index 22abd77..a9fdcea 100644 --- a/apps/web/src/api.ts +++ b/apps/web/src/api.ts @@ -684,11 +684,13 @@ export interface VideoGenerationContent { }; video_url?: { url: string; + mime_type?: string; refer_type?: 'feature' | 'base'; keep_original_sound?: 'yes' | 'no'; }; audio_url?: { url: string; + mime_type?: string; }; role?: VideoGenerationContentRole; shot_index?: number; diff --git a/apps/web/src/pages/playground-upload.tsx b/apps/web/src/pages/playground-upload.tsx index a8d413d..09cde19 100644 --- a/apps/web/src/pages/playground-upload.tsx +++ b/apps/web/src/pages/playground-upload.tsx @@ -32,8 +32,8 @@ export interface PlaygroundUpload { export type OpenAIChatContentPart = | { type: 'text'; text: string } | { type: 'image_url'; image_url: { url: string } } - | { type: 'video_url'; video_url: { url: string } } - | { type: 'audio_url'; audio_url: { url: string } } + | { type: 'video_url'; video_url: { mime_type?: string; url: string } } + | { type: 'audio_url'; audio_url: { mime_type?: string; url: string } } | { type: 'file_url'; file_url: { filename: string; url: string } }; export const mediaUploadAccept = 'image/*,video/*,audio/*'; @@ -518,11 +518,17 @@ export function openAIContentFromPromptAndUploads(prompt: string, uploads: Playg function openAIContentPartFromUpload(item: PlaygroundUpload): OpenAIChatContentPart | undefined { if (!item.url) return undefined; if (item.kind === 'image') return { type: 'image_url', image_url: { url: item.url } }; - if (item.kind === 'video') return { type: 'video_url', video_url: { url: item.url } }; - if (item.kind === 'audio') return { type: 'audio_url', audio_url: { url: item.url } }; + if (item.kind === 'video') return { type: 'video_url', video_url: mediaURLPayload(item) }; + if (item.kind === 'audio') return { type: 'audio_url', audio_url: mediaURLPayload(item) }; return { type: 'file_url', file_url: { filename: item.name, url: item.url } }; } +function mediaURLPayload(item: PlaygroundUpload) { + const payload: { mime_type?: string; url: string } = { url: item.url }; + if (item.contentType) payload.mime_type = item.contentType; + return payload; +} + export function mediaUploadRequestPayload(uploads: PlaygroundUpload[], mode: Exclude) { const images = uploads.filter((item) => item.kind === 'image').map((item) => item.url); const payload: Record = {}; @@ -570,10 +576,10 @@ function videoGenerationContentFromUpload(item: PlaygroundUpload): VideoGenerati return { type: 'image_url', role: 'reference_image', image_url: { url: item.url } }; } if (item.kind === 'video') { - return { type: 'video_url', role: 'reference_video', video_url: { url: item.url, refer_type: 'feature' } }; + return { type: 'video_url', role: 'reference_video', video_url: { ...mediaURLPayload(item), refer_type: 'feature' } }; } if (item.kind === 'audio') { - return { type: 'audio_url', role: 'reference_audio', audio_url: { url: item.url } }; + return { type: 'audio_url', role: 'reference_audio', audio_url: mediaURLPayload(item) }; } return undefined; }