feat(api): 添加多媒体内容支持并优化钱包计费系统
- 在 API 接口定义中为 video_url 和 audio_url 类型添加 mime_type 字段 - 实现 Google Gemini 客户端对视频和音频内容的支持,包括媒体类型检测和数据传输 - 添加 Gemini 客户端测试用例验证多媒体内容转换功能 - 重构 Playground 页面的媒体上传逻辑以支持 MIME 类型传递 - 实现钱包计费预留机制,确保任务执行前余额充足 - 添加钱包冻结余额管理,防止并发操作导致的超扣问题 - 实现计费预留释放逻辑,处理任务失败或取消情况下的资金返还 - 优化数据库事务处理,确保计费操作的原子性和一致性 - 添加数据库集成测试验证迁移脚本执行流程 - 统一 Google Gemini 相关模型提供商标识符映射
This commit is contained in:
parent
af9b281d34
commit
8ad5b06c18
@ -569,6 +569,70 @@ func TestGeminiClientChatContract(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiClientChatConvertsMediaContentParts(t *testing.T) {
|
||||
var captured map[string]any
|
||||
var gotPath string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
if err := json.NewDecoder(r.Body).Decode(&captured); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"candidates": []any{map[string]any{
|
||||
"content": map[string]any{"parts": []any{map[string]any{"text": "video ok"}}},
|
||||
}},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := (GeminiClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
||||
Kind: "chat.completions",
|
||||
Model: "gemini:gemini-2.5-flash",
|
||||
Body: map[string]any{
|
||||
"model": "gemini:gemini-2.5-flash",
|
||||
"messages": []any{map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "analyze this video"},
|
||||
map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://cdn.example.com/input.mov", "mime_type": "video/quicktime"}},
|
||||
map[string]any{"type": "audio_url", "audio_url": map[string]any{"url": "data:audio/wav;base64,UklGRg=="}},
|
||||
},
|
||||
}},
|
||||
},
|
||||
Candidate: store.RuntimeModelCandidate{
|
||||
BaseURL: server.URL + "/v1beta/openai",
|
||||
ProviderModelName: "gemini-2.5-flash",
|
||||
ModelType: "chat",
|
||||
Credentials: map[string]any{"apiKey": "gemini-key"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("run gemini client: %v", err)
|
||||
}
|
||||
if gotPath != "/v1beta/models/gemini-2.5-flash:generateContent" {
|
||||
t.Fatalf("Gemini OpenAI-compatible base URL should normalize to native endpoint, got %s", gotPath)
|
||||
}
|
||||
contents, _ := captured["contents"].([]any)
|
||||
if len(contents) != 1 {
|
||||
t.Fatalf("unexpected Gemini contents: %+v", captured)
|
||||
}
|
||||
turn, _ := contents[0].(map[string]any)
|
||||
parts, _ := turn["parts"].([]any)
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("expected text, video, and audio parts, got %+v", turn)
|
||||
}
|
||||
video, _ := parts[1].(map[string]any)
|
||||
videoFile, _ := video["fileData"].(map[string]any)
|
||||
if videoFile["fileUri"] != "https://cdn.example.com/input.mov" || videoFile["mimeType"] != "video/quicktime" {
|
||||
t.Fatalf("video_url should become Gemini fileData, got %+v", video)
|
||||
}
|
||||
audio, _ := parts[2].(map[string]any)
|
||||
audioInline, _ := audio["inlineData"].(map[string]any)
|
||||
if audioInline["mimeType"] != "audio/wav" || audioInline["data"] != "UklGRg==" {
|
||||
t.Fatalf("audio data URL should become Gemini inlineData, got %+v", audio)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiClientChatRestoresToolContext(t *testing.T) {
|
||||
var captured map[string]any
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@ -5,8 +5,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@ -58,6 +60,7 @@ func geminiURL(baseURL string, model string, apiKey string) string {
|
||||
if base == "" {
|
||||
base = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
base = strings.TrimSuffix(base, "/openai")
|
||||
if strings.HasSuffix(base, "/v1beta") {
|
||||
base = strings.TrimSuffix(base, "/v1beta")
|
||||
}
|
||||
@ -121,7 +124,7 @@ func geminiContentsFromMessages(body map[string]any) []any {
|
||||
})
|
||||
continue
|
||||
}
|
||||
parts := geminiTextParts(message["content"])
|
||||
parts := geminiContentParts(message["content"])
|
||||
if role == "assistant" {
|
||||
for _, rawToolCall := range toolCallsSlice(message["tool_calls"]) {
|
||||
toolCall, _ := rawToolCall.(map[string]any)
|
||||
@ -157,7 +160,7 @@ func geminiRole(role string) string {
|
||||
return "user"
|
||||
}
|
||||
|
||||
func geminiTextParts(content any) []any {
|
||||
func geminiContentParts(content any) []any {
|
||||
parts := make([]any, 0)
|
||||
switch typed := content.(type) {
|
||||
case string:
|
||||
@ -167,14 +170,146 @@ func geminiTextParts(content any) []any {
|
||||
case []any:
|
||||
for _, rawPart := range typed {
|
||||
part, _ := rawPart.(map[string]any)
|
||||
if text := stringFromAny(firstPresent(part["text"], part["content"])); strings.TrimSpace(text) != "" {
|
||||
parts = append(parts, map[string]any{"text": text})
|
||||
if len(part) == 0 {
|
||||
continue
|
||||
}
|
||||
switch stringFromAny(part["type"]) {
|
||||
case "text":
|
||||
if text := strings.TrimSpace(stringFromAny(firstPresent(part["text"], part["content"]))); text != "" {
|
||||
parts = append(parts, map[string]any{"text": text})
|
||||
}
|
||||
case "image_url":
|
||||
if media := geminiMediaPart(part, "image_url", "image"); media != nil {
|
||||
parts = append(parts, media)
|
||||
}
|
||||
case "video_url":
|
||||
if media := geminiMediaPart(part, "video_url", "video"); media != nil {
|
||||
parts = append(parts, media)
|
||||
}
|
||||
case "audio_url":
|
||||
if media := geminiMediaPart(part, "audio_url", "audio"); media != nil {
|
||||
parts = append(parts, media)
|
||||
}
|
||||
case "input_audio":
|
||||
if media := geminiInputAudioPart(part); media != nil {
|
||||
parts = append(parts, media)
|
||||
}
|
||||
default:
|
||||
if text := strings.TrimSpace(stringFromAny(firstPresent(part["text"], part["content"]))); text != "" {
|
||||
parts = append(parts, map[string]any{"text": text})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func geminiMediaPart(part map[string]any, key string, mediaType string) map[string]any {
|
||||
nested := mapFromAny(part[key])
|
||||
uri := firstNonEmptyString(nested["url"], part["url"], part[key])
|
||||
if uri == "" {
|
||||
return nil
|
||||
}
|
||||
mimeType := firstNonEmptyString(nested["mime_type"], nested["mimeType"], part["mime_type"], part["mimeType"])
|
||||
return geminiMediaURLPart(uri, mimeType, mediaType)
|
||||
}
|
||||
|
||||
func geminiInputAudioPart(part map[string]any) map[string]any {
|
||||
audio := mapFromAny(part["input_audio"])
|
||||
uri := firstNonEmptyString(audio["data"], audio["url"])
|
||||
if uri == "" {
|
||||
return nil
|
||||
}
|
||||
mimeType := firstNonEmptyString(audio["mime_type"], audio["mimeType"])
|
||||
if mimeType == "" {
|
||||
format := strings.ToLower(strings.TrimPrefix(stringFromAny(audio["format"]), "."))
|
||||
if strings.Contains(format, "/") {
|
||||
mimeType = format
|
||||
} else if format == "mp3" {
|
||||
mimeType = "audio/mpeg"
|
||||
} else if format != "" {
|
||||
mimeType = "audio/" + format
|
||||
}
|
||||
}
|
||||
return geminiMediaURLPart(uri, mimeType, "audio")
|
||||
}
|
||||
|
||||
func geminiMediaURLPart(uri string, explicitMimeType string, mediaType string) map[string]any {
|
||||
if parsed := geminiDataURL(uri); parsed != nil {
|
||||
return map[string]any{"inlineData": map[string]any{
|
||||
"mimeType": geminiMediaMime(firstNonEmptyString(explicitMimeType, parsed.mimeType), mediaType),
|
||||
"data": parsed.data,
|
||||
}}
|
||||
}
|
||||
return map[string]any{"fileData": map[string]any{
|
||||
"fileUri": uri,
|
||||
"mimeType": geminiMediaMime(firstNonEmptyString(explicitMimeType, mimeFromURI(uri)), mediaType),
|
||||
}}
|
||||
}
|
||||
|
||||
type geminiParsedDataURL struct {
|
||||
mimeType string
|
||||
data string
|
||||
}
|
||||
|
||||
func geminiDataURL(value string) *geminiParsedDataURL {
|
||||
if !strings.HasPrefix(value, "data:") {
|
||||
return nil
|
||||
}
|
||||
prefix, data, ok := strings.Cut(value, ",")
|
||||
if !ok || !strings.Contains(prefix, ";base64") {
|
||||
return nil
|
||||
}
|
||||
mimeType := strings.TrimPrefix(strings.Split(prefix, ";")[0], "data:")
|
||||
if mimeType == "" {
|
||||
mimeType = "application/octet-stream"
|
||||
}
|
||||
return &geminiParsedDataURL{mimeType: mimeType, data: data}
|
||||
}
|
||||
|
||||
func mimeFromURI(value string) string {
|
||||
pathValue := value
|
||||
if parsed, err := url.Parse(value); err == nil && parsed.Path != "" {
|
||||
pathValue = parsed.Path
|
||||
}
|
||||
extension := strings.ToLower(path.Ext(pathValue))
|
||||
if extension == "" {
|
||||
return ""
|
||||
}
|
||||
return mime.TypeByExtension(extension)
|
||||
}
|
||||
|
||||
func geminiMediaMime(mimeType string, mediaType string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(strings.Split(mimeType, ";")[0]))
|
||||
switch mediaType {
|
||||
case "image":
|
||||
if strings.HasPrefix(normalized, "image/") && normalized != "image/svg+xml" {
|
||||
return normalized
|
||||
}
|
||||
return "image/png"
|
||||
case "video":
|
||||
switch normalized {
|
||||
case "video/x-msvideo":
|
||||
return "video/avi"
|
||||
case "video/quicktime", "video/mpeg", "video/mp4", "video/avi", "video/x-flv", "video/mpg", "video/webm", "video/wmv", "video/3gpp":
|
||||
return normalized
|
||||
default:
|
||||
return "video/mp4"
|
||||
}
|
||||
case "audio":
|
||||
switch normalized {
|
||||
case "audio/x-wav", "audio/wave":
|
||||
return "audio/wav"
|
||||
case "audio/mpeg", "audio/mp3", "audio/wav", "audio/aiff", "audio/aac", "audio/ogg", "audio/flac", "audio/mp4", "audio/webm":
|
||||
return normalized
|
||||
default:
|
||||
return "audio/mpeg"
|
||||
}
|
||||
default:
|
||||
return "application/octet-stream"
|
||||
}
|
||||
}
|
||||
|
||||
func toolCallsSlice(value any) []any {
|
||||
switch typed := value.(type) {
|
||||
case []any:
|
||||
|
||||
@ -129,6 +129,13 @@ func TestCoreLocalFlow(t *testing.T) {
|
||||
if _, err := testPool.Exec(ctx, `UPDATE gateway_users SET roles = '["admin"]'::jsonb WHERE username = $1`, username); err != nil {
|
||||
t.Fatalf("promote smoke user: %v", err)
|
||||
}
|
||||
doJSON(t, server.URL, http.MethodPost, "/api/v1/auth/login", "", map[string]any{
|
||||
"account": username,
|
||||
"password": password,
|
||||
}, http.StatusOK, &loginResponse)
|
||||
if loginResponse.AccessToken == "" {
|
||||
t.Fatal("admin login did not return access token")
|
||||
}
|
||||
var smokeGatewayUserID string
|
||||
if err := testPool.QueryRow(ctx, `SELECT id::text FROM gateway_users WHERE username = $1`, username).Scan(&smokeGatewayUserID); err != nil {
|
||||
t.Fatalf("read smoke gateway user id: %v", err)
|
||||
@ -1402,14 +1409,41 @@ func applyMigration(t *testing.T, ctx context.Context, databaseURL string) {
|
||||
t.Fatalf("connect migration db: %v", err)
|
||||
}
|
||||
defer pool.Close()
|
||||
if _, err := pool.Exec(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version text PRIMARY KEY,
|
||||
applied_at timestamptz NOT NULL DEFAULT now()
|
||||
);`); err != nil {
|
||||
t.Fatalf("ensure schema migrations: %v", err)
|
||||
}
|
||||
for _, migrationPath := range migrationFiles {
|
||||
version := strings.TrimSuffix(filepath.Base(migrationPath), filepath.Ext(migrationPath))
|
||||
var exists bool
|
||||
if err := pool.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM schema_migrations WHERE version = $1)", version).Scan(&exists); err != nil {
|
||||
t.Fatalf("check migration %s: %v", filepath.Base(migrationPath), err)
|
||||
}
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
migration, err := os.ReadFile(migrationPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read migration %s: %v", filepath.Base(migrationPath), err)
|
||||
}
|
||||
if _, err := pool.Exec(ctx, string(migration)); err != nil {
|
||||
tx, err := pool.Begin(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("begin migration %s: %v", filepath.Base(migrationPath), err)
|
||||
}
|
||||
if _, err := tx.Exec(ctx, string(migration)); err != nil {
|
||||
_ = tx.Rollback(ctx)
|
||||
t.Fatalf("apply migration %s: %v", filepath.Base(migrationPath), err)
|
||||
}
|
||||
if _, err := tx.Exec(ctx, "INSERT INTO schema_migrations(version) VALUES($1)", version); err != nil {
|
||||
_ = tx.Rollback(ctx)
|
||||
t.Fatalf("record migration %s: %v", filepath.Base(migrationPath), err)
|
||||
}
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
t.Fatalf("commit migration %s: %v", filepath.Base(migrationPath), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -160,6 +160,13 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
||||
normalizedModelType := modelType
|
||||
attemptNo := task.AttemptCount
|
||||
var firstPreprocessing parameterPreprocessingLog
|
||||
var walletReservations []store.WalletBillingReservation
|
||||
walletReservationFinalized := false
|
||||
defer func() {
|
||||
if !walletReservationFinalized && len(walletReservations) > 0 {
|
||||
_ = s.store.ReleaseTaskBillingReservations(context.WithoutCancel(ctx), walletReservations, "task_not_settled")
|
||||
}
|
||||
}()
|
||||
if len(candidates) > 0 {
|
||||
preprocessing := s.preprocessRequestWithScripts(ctx, task.Kind, body, candidates[0])
|
||||
firstCandidateBody = preprocessing.Body
|
||||
@ -191,15 +198,17 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
||||
return Result{}, err
|
||||
}
|
||||
estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, firstCandidateBody, candidates[0])
|
||||
if err := s.ensureWalletBalance(ctx, user, estimatedBillings); err != nil {
|
||||
if errors.Is(err, store.ErrInsufficientWalletBalance) {
|
||||
var reserveErr error
|
||||
walletReservations, reserveErr = s.store.ReserveTaskBilling(ctx, task, user, estimatedBillings)
|
||||
if reserveErr != nil {
|
||||
if errors.Is(reserveErr, store.ErrInsufficientWalletBalance) {
|
||||
attemptNo = s.recordFailedAttempt(ctx, failedAttemptRecord{
|
||||
Task: task,
|
||||
Body: firstCandidateBody,
|
||||
Candidate: &candidates[0],
|
||||
AttemptNo: attemptNo + 1,
|
||||
Code: "insufficient_balance",
|
||||
Cause: err,
|
||||
Cause: reserveErr,
|
||||
Simulated: task.RunMode == "simulation",
|
||||
Scope: "wallet_balance",
|
||||
Reason: "wallet_balance_check_failed",
|
||||
@ -207,13 +216,13 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
||||
Preprocessing: &firstPreprocessing,
|
||||
ModelType: normalizedModelType,
|
||||
})
|
||||
failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", err.Error(), task.RunMode == "simulation", err, parameterPreprocessingMetrics(firstPreprocessing))
|
||||
failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", reserveErr.Error(), task.RunMode == "simulation", reserveErr, parameterPreprocessingMetrics(firstPreprocessing))
|
||||
if finishErr != nil {
|
||||
return Result{}, finishErr
|
||||
}
|
||||
return Result{Task: failed, Output: failed.Result}, err
|
||||
return Result{Task: failed, Output: failed.Result}, reserveErr
|
||||
}
|
||||
return Result{}, err
|
||||
return Result{}, reserveErr
|
||||
}
|
||||
}
|
||||
if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": normalizedModelType}, task.RunMode == "simulation"); err != nil {
|
||||
@ -286,9 +295,18 @@ candidatesLoop:
|
||||
if finishErr != nil {
|
||||
return Result{}, finishErr
|
||||
}
|
||||
if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil {
|
||||
return Result{}, settleErr
|
||||
if finished.FinalChargeAmount > 0 {
|
||||
walletReservationFinalized = true
|
||||
if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil {
|
||||
return Result{}, settleErr
|
||||
}
|
||||
} else if len(walletReservations) > 0 {
|
||||
if releaseErr := s.store.ReleaseTaskBillingReservations(ctx, walletReservations, "task_billing_zero"); releaseErr != nil {
|
||||
return Result{}, releaseErr
|
||||
}
|
||||
walletReservationFinalized = true
|
||||
}
|
||||
walletReservationFinalized = true
|
||||
if finished.FinalChargeAmount > 0 {
|
||||
if err := s.emit(ctx, task.ID, "task.billing.settled", "succeeded", "billing", 0.98, "task billing settled", map[string]any{
|
||||
"amount": finished.FinalChargeAmount,
|
||||
@ -695,6 +713,11 @@ func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated boo
|
||||
if key == "" {
|
||||
key = strings.ToLower(strings.TrimSpace(candidate.Provider))
|
||||
}
|
||||
provider := strings.ToLower(strings.TrimSpace(candidate.Provider))
|
||||
baseURL := strings.ToLower(strings.TrimSpace(candidate.BaseURL))
|
||||
if key == "google-gemini" || provider == "gemini" || provider == "google-gemini" || provider == "gemini-openai" || strings.Contains(baseURL, "generativelanguage.googleapis.com") {
|
||||
key = "gemini"
|
||||
}
|
||||
if client, ok := s.clients[key]; ok {
|
||||
return client
|
||||
}
|
||||
|
||||
34
apps/api/internal/runner/service_test.go
Normal file
34
apps/api/internal/runner/service_test.go
Normal file
@ -0,0 +1,34 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
type namedClient string
|
||||
|
||||
func (namedClient) Run(context.Context, clients.Request) (clients.Response, error) {
|
||||
return clients.Response{}, nil
|
||||
}
|
||||
|
||||
func TestClientForMapsGoogleGeminiSpecToGeminiClient(t *testing.T) {
|
||||
service := &Service{clients: map[string]clients.Client{
|
||||
"gemini": namedClient("gemini"),
|
||||
"openai": namedClient("openai"),
|
||||
}}
|
||||
|
||||
candidates := []store.RuntimeModelCandidate{
|
||||
{SpecType: "google-gemini"},
|
||||
{SpecType: "openai", Provider: "gemini-openai"},
|
||||
{SpecType: "openai", BaseURL: "https://generativelanguage.googleapis.com/v1beta/openai"},
|
||||
}
|
||||
for _, candidate := range candidates {
|
||||
client := service.clientFor(candidate, false)
|
||||
if client != namedClient("gemini") {
|
||||
t.Fatalf("Gemini candidate should use gemini client, candidate=%+v got %T %[2]v", candidate, client)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,38 +0,0 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
func (s *Service) ensureWalletBalance(ctx context.Context, user *auth.User, billings []any) error {
|
||||
amounts := map[string]float64{}
|
||||
for _, raw := range billings {
|
||||
line, _ := raw.(map[string]any)
|
||||
if line == nil {
|
||||
continue
|
||||
}
|
||||
currency := strings.TrimSpace(stringFromAny(line["currency"]))
|
||||
if currency == "" {
|
||||
currency = "resource"
|
||||
}
|
||||
amounts[currency] = roundPrice(amounts[currency] + floatFromAny(line["amount"]))
|
||||
}
|
||||
for currency, amount := range amounts {
|
||||
if amount <= 0 {
|
||||
continue
|
||||
}
|
||||
availability, err := s.store.WalletAvailability(ctx, user, currency, amount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !availability.Enough {
|
||||
return fmt.Errorf("%w: required %.6f %s, available %.6f", store.ErrInsufficientWalletBalance, amount, currency, availability.AvailableAmount)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
202
apps/api/internal/runner/wallet_execute_test.go
Normal file
202
apps/api/internal/runner/wallet_execute_test.go
Normal file
@ -0,0 +1,202 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
type walletExecuteMockClient struct {
|
||||
calls atomic.Int32
|
||||
}
|
||||
|
||||
func (client *walletExecuteMockClient) Run(context.Context, clients.Request) (clients.Response, error) {
|
||||
client.calls.Add(1)
|
||||
return clients.Response{
|
||||
Result: map[string]any{"mock": true},
|
||||
RequestID: "mock-wallet-execute",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestExecuteWithMockClientRejectsConcurrentTasksBeyondWalletBalance(t *testing.T) {
|
||||
databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL"))
|
||||
if databaseURL == "" {
|
||||
t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet execute integration test")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
db, err := store.Connect(ctx, databaseURL)
|
||||
if err != nil {
|
||||
t.Fatalf("connect store: %v", err)
|
||||
}
|
||||
t.Cleanup(db.Close)
|
||||
|
||||
suffix := strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
tenant, err := db.CreateTenant(ctx, store.GatewayTenantInput{
|
||||
TenantKey: "wallet-execute-" + suffix,
|
||||
Name: "Wallet Execute Test " + suffix,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create tenant: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = db.DeleteTenant(context.Background(), tenant.ID)
|
||||
})
|
||||
|
||||
gatewayUser, err := db.CreateGatewayUser(ctx, store.GatewayUserInput{
|
||||
UserKey: "wallet-execute-user-" + suffix,
|
||||
Username: "wallet_execute_" + suffix,
|
||||
GatewayTenantID: tenant.ID,
|
||||
TenantKey: tenant.TenantKey,
|
||||
Roles: []string{"user"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create gateway user: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = db.DeleteGatewayUser(context.Background(), gatewayUser.ID)
|
||||
})
|
||||
|
||||
platform, err := db.CreatePlatform(ctx, store.CreatePlatformInput{
|
||||
Provider: "mock",
|
||||
PlatformKey: "wallet-execute-mock-" + suffix,
|
||||
Name: "Wallet Execute Mock " + suffix,
|
||||
AuthType: "none",
|
||||
Config: map[string]any{"specType": "mock"},
|
||||
Status: "enabled",
|
||||
Priority: 1,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create mock platform: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = db.DeletePlatform(context.Background(), platform.ID)
|
||||
})
|
||||
|
||||
if _, err := db.CreatePlatformModel(ctx, store.CreatePlatformModelInput{
|
||||
PlatformID: platform.ID,
|
||||
ModelName: "mock-wallet-image",
|
||||
ProviderModelName: "mock-wallet-image",
|
||||
ModelType: store.StringList{"image_generate"},
|
||||
DisplayName: "Mock Wallet Image",
|
||||
BillingConfig: map[string]any{
|
||||
"image": map[string]any{"basePrice": 10},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("create mock platform model: %v", err)
|
||||
}
|
||||
|
||||
user := &auth.User{
|
||||
ID: gatewayUser.ID,
|
||||
Source: "gateway",
|
||||
GatewayUserID: gatewayUser.ID,
|
||||
GatewayTenantID: tenant.ID,
|
||||
TenantKey: tenant.TenantKey,
|
||||
Roles: gatewayUser.Roles,
|
||||
}
|
||||
if _, err := db.SetUserWalletBalance(ctx, store.WalletBalanceAdjustmentInput{
|
||||
GatewayUserID: gatewayUser.ID,
|
||||
Currency: "resource",
|
||||
Balance: 10,
|
||||
Reason: "seed wallet execute test",
|
||||
}); err != nil {
|
||||
t.Fatalf("seed wallet balance: %v", err)
|
||||
}
|
||||
|
||||
tasks := make([]store.GatewayTask, 0, 2)
|
||||
for i := 0; i < 2; i++ {
|
||||
task, err := db.CreateTask(ctx, store.CreateTaskInput{
|
||||
Kind: "images.generations",
|
||||
Model: "mock-wallet-image",
|
||||
Request: map[string]any{
|
||||
"count": 1,
|
||||
"prompt": "wallet execute test",
|
||||
},
|
||||
}, user)
|
||||
if err != nil {
|
||||
t.Fatalf("create task: %v", err)
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
|
||||
mockClient := &walletExecuteMockClient{}
|
||||
service := New(config.Config{}, db, slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||
service.clients["mock"] = mockClient
|
||||
|
||||
type executeResult struct {
|
||||
result Result
|
||||
err error
|
||||
}
|
||||
results := make(chan executeResult, len(tasks))
|
||||
var wg sync.WaitGroup
|
||||
for _, task := range tasks {
|
||||
task := task
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
result, err := service.Execute(ctx, task, user)
|
||||
results <- executeResult{result: result, err: err}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
successCount := 0
|
||||
insufficientCount := 0
|
||||
for item := range results {
|
||||
if item.err == nil {
|
||||
successCount++
|
||||
if item.result.Task.Status != "succeeded" {
|
||||
t.Fatalf("successful execution status = %s, want succeeded", item.result.Task.Status)
|
||||
}
|
||||
if !walletExecuteFloatNear(item.result.Task.FinalChargeAmount, 10) {
|
||||
t.Fatalf("successful execution final charge = %f, want 10", item.result.Task.FinalChargeAmount)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if errors.Is(item.err, store.ErrInsufficientWalletBalance) {
|
||||
insufficientCount++
|
||||
if item.result.Task.Status != "failed" || item.result.Task.ErrorCode != "insufficient_balance" {
|
||||
t.Fatalf("insufficient execution task = %+v", item.result.Task)
|
||||
}
|
||||
continue
|
||||
}
|
||||
t.Fatalf("unexpected execute error: %v", item.err)
|
||||
}
|
||||
if successCount != 1 || insufficientCount != 1 {
|
||||
t.Fatalf("expected one successful mock execution and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount)
|
||||
}
|
||||
if got := mockClient.calls.Load(); got != 1 {
|
||||
t.Fatalf("mock client calls = %d, want 1", got)
|
||||
}
|
||||
|
||||
summary, err := db.GetWalletSummary(ctx, user, "resource")
|
||||
if err != nil {
|
||||
t.Fatalf("get wallet summary: %v", err)
|
||||
}
|
||||
account := summary.PrimaryAccount
|
||||
if !walletExecuteFloatNear(account.Balance, 0) || !walletExecuteFloatNear(account.FrozenBalance, 0) || !walletExecuteFloatNear(account.TotalSpent, 10) {
|
||||
t.Fatalf("wallet after concurrent mock execution balance=%f frozen=%f spent=%f, want 0/0/10", account.Balance, account.FrozenBalance, account.TotalSpent)
|
||||
}
|
||||
}
|
||||
|
||||
func walletExecuteFloatNear(a float64, b float64) bool {
|
||||
delta := a - b
|
||||
if delta < 0 {
|
||||
delta = -delta
|
||||
}
|
||||
return delta < 0.000001
|
||||
}
|
||||
@ -3,13 +3,13 @@ package store
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
type TaskListFilter struct {
|
||||
@ -687,14 +687,15 @@ func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error {
|
||||
if currency == "" || currency == "mixed" {
|
||||
currency = "resource"
|
||||
}
|
||||
metadata, _ := json.Marshal(map[string]any{
|
||||
metadataMap := map[string]any{
|
||||
"taskId": task.ID,
|
||||
"kind": task.Kind,
|
||||
"model": task.Model,
|
||||
"resolvedModel": task.ResolvedModel,
|
||||
"billings": task.Billings,
|
||||
"billingSummary": task.BillingSummary,
|
||||
})
|
||||
}
|
||||
metadata, _ := json.Marshal(metadataMap)
|
||||
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO gateway_wallet_accounts (
|
||||
@ -706,42 +707,85 @@ ON CONFLICT (gateway_user_id, currency) DO NOTHING`,
|
||||
return err
|
||||
}
|
||||
var exists bool
|
||||
var accountID string
|
||||
var balanceBefore float64
|
||||
var frozenBefore float64
|
||||
var gatewayTenantID string
|
||||
if err := tx.QueryRow(ctx, `
|
||||
SELECT id::text, balance::float8, frozen_balance::float8, COALESCE(gateway_tenant_id::text, '')
|
||||
FROM gateway_wallet_accounts
|
||||
WHERE gateway_user_id = $1::uuid
|
||||
AND currency = $2
|
||||
FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &frozenBefore, &gatewayTenantID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.QueryRow(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM gateway_wallet_transactions t
|
||||
JOIN gateway_wallet_accounts a ON a.id = t.account_id
|
||||
WHERE a.gateway_user_id = $1::uuid
|
||||
AND a.currency = $2
|
||||
AND t.idempotency_key = $3
|
||||
)`, task.GatewayUserID, currency, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil {
|
||||
FROM gateway_wallet_transactions
|
||||
WHERE account_id = $1::uuid
|
||||
AND idempotency_key = $2
|
||||
)`, accountID, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
var accountID string
|
||||
var balanceBefore float64
|
||||
var gatewayTenantID string
|
||||
if err := tx.QueryRow(ctx, `
|
||||
SELECT id::text, balance::float8, COALESCE(gateway_tenant_id::text, '')
|
||||
FROM gateway_wallet_accounts
|
||||
WHERE gateway_user_id = $1::uuid
|
||||
AND currency = $2
|
||||
FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &gatewayTenantID); err != nil {
|
||||
|
||||
amount := roundMoney(task.FinalChargeAmount)
|
||||
reservationKey, reservedAmount, err := activeWalletReservation(ctx, tx, accountID, task.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
amount := roundMoney(task.FinalChargeAmount)
|
||||
reservedAmount = roundMoney(reservedAmount)
|
||||
spendableForTask := roundMoney(balanceBefore - frozenBefore + reservedAmount)
|
||||
if spendableForTask+0.000001 < amount {
|
||||
return fmt.Errorf("%w: required %.6f %s, available %.6f", ErrInsufficientWalletBalance, amount, currency, spendableForTask)
|
||||
}
|
||||
|
||||
balanceAfter := roundMoney(balanceBefore - amount)
|
||||
frozenAfter := roundMoney(frozenBefore - reservedAmount)
|
||||
if frozenAfter < 0 {
|
||||
frozenAfter = 0
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `
|
||||
UPDATE gateway_wallet_accounts
|
||||
SET balance = $2,
|
||||
total_spent = total_spent + $3,
|
||||
frozen_balance = $4,
|
||||
updated_at = now()
|
||||
WHERE id = $1::uuid`, accountID, balanceAfter, amount); err != nil {
|
||||
WHERE id = $1::uuid`, accountID, balanceAfter, amount, frozenAfter); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := tx.Exec(ctx, `
|
||||
if reservedAmount > 0 {
|
||||
releaseMetadata, _ := json.Marshal(map[string]any{
|
||||
"taskId": task.ID,
|
||||
"reason": "task_billing_settled",
|
||||
"reserved": reservedAmount,
|
||||
"frozenBefore": roundMoney(frozenBefore),
|
||||
"frozenAfter": frozenAfter,
|
||||
})
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO gateway_wallet_transactions (
|
||||
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
||||
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
||||
)
|
||||
VALUES (
|
||||
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'credit', 'release',
|
||||
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
||||
)
|
||||
ON CONFLICT (account_id, idempotency_key) WHERE idempotency_key IS NOT NULL DO NOTHING`,
|
||||
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, reservedAmount, roundMoney(balanceBefore), roundMoney(balanceBefore), billingReservationReleaseIdempotencyKey(reservationKey), task.ID, string(releaseMetadata)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
billingMetadata := mergeObjects(metadataMap, map[string]any{
|
||||
"reservedAmount": reservedAmount,
|
||||
"frozenBefore": roundMoney(frozenBefore),
|
||||
"frozenAfter": frozenAfter,
|
||||
})
|
||||
metadata, _ = json.Marshal(billingMetadata)
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO gateway_wallet_transactions (
|
||||
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
||||
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
||||
@ -750,11 +794,10 @@ VALUES (
|
||||
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing',
|
||||
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
||||
)`,
|
||||
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata))
|
||||
if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" {
|
||||
return nil
|
||||
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata)); err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -92,6 +93,16 @@ type WalletAdjustmentResult struct {
|
||||
Transaction GatewayWalletTransaction `json:"transaction"`
|
||||
}
|
||||
|
||||
type WalletBillingReservation struct {
|
||||
TaskID string `json:"taskId"`
|
||||
AccountID string `json:"accountId"`
|
||||
GatewayUserID string `json:"gatewayUserId"`
|
||||
GatewayTenantID string `json:"gatewayTenantId,omitempty"`
|
||||
Currency string `json:"currency"`
|
||||
Amount float64 `json:"amount"`
|
||||
IdempotencyKey string `json:"idempotencyKey"`
|
||||
}
|
||||
|
||||
func (s *Store) WalletAvailability(ctx context.Context, user *auth.User, currency string, requiredAmount float64) (WalletAvailability, error) {
|
||||
gatewayUserID := localGatewayUserID(user)
|
||||
if gatewayUserID == "" {
|
||||
@ -115,6 +126,223 @@ func (s *Store) WalletAvailability(ctx context.Context, user *auth.User, currenc
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Store) ReserveTaskBilling(ctx context.Context, task GatewayTask, user *auth.User, billings []any) ([]WalletBillingReservation, error) {
|
||||
gatewayUserID := taskGatewayUserID(task, user)
|
||||
if gatewayUserID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
taskID := strings.TrimSpace(task.ID)
|
||||
if taskID == "" {
|
||||
return nil, fmt.Errorf("task id is required for wallet reservation")
|
||||
}
|
||||
|
||||
amounts := walletBillingAmounts(billings)
|
||||
if len(amounts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
reservations := make([]WalletBillingReservation, 0, len(amounts))
|
||||
err := pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
||||
for currency, rawAmount := range amounts {
|
||||
amount := roundMoney(rawAmount)
|
||||
if amount <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
account, err := s.ensureWalletAccount(ctx, tx, gatewayUserID, currency)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
locked, err := lockWalletAccount(ctx, tx, account.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
activeKey, activeAmount, err := activeWalletReservation(ctx, tx, locked.ID, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if activeAmount > 0 {
|
||||
reservation := WalletBillingReservation{
|
||||
TaskID: taskID,
|
||||
AccountID: locked.ID,
|
||||
GatewayUserID: gatewayUserID,
|
||||
GatewayTenantID: firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID),
|
||||
Currency: locked.Currency,
|
||||
Amount: activeAmount,
|
||||
IdempotencyKey: activeKey,
|
||||
}
|
||||
reservations = append(reservations, reservation)
|
||||
continue
|
||||
}
|
||||
|
||||
sequence, err := nextWalletReservationSequence(ctx, tx, locked.ID, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := billingReservationIdempotencyKey(taskID, locked.Currency, sequence)
|
||||
reservation := WalletBillingReservation{
|
||||
TaskID: taskID,
|
||||
AccountID: locked.ID,
|
||||
GatewayUserID: gatewayUserID,
|
||||
GatewayTenantID: firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID),
|
||||
Currency: locked.Currency,
|
||||
Amount: amount,
|
||||
IdempotencyKey: key,
|
||||
}
|
||||
available := roundMoney(locked.Balance - locked.FrozenBalance)
|
||||
if available+0.000001 < amount {
|
||||
return fmt.Errorf("%w: required %.6f %s, available %.6f", ErrInsufficientWalletBalance, amount, locked.Currency, available)
|
||||
}
|
||||
|
||||
frozenAfter := roundMoney(locked.FrozenBalance + amount)
|
||||
if _, err := tx.Exec(ctx, `
|
||||
UPDATE gateway_wallet_accounts
|
||||
SET frozen_balance = $2,
|
||||
updated_at = now()
|
||||
WHERE id = $1::uuid`, locked.ID, frozenAfter); err != nil {
|
||||
return err
|
||||
}
|
||||
metadata, _ := json.Marshal(map[string]any{
|
||||
"taskId": taskID,
|
||||
"kind": task.Kind,
|
||||
"model": task.Model,
|
||||
"reserved": amount,
|
||||
"balance": roundMoney(locked.Balance),
|
||||
"frozenBefore": roundMoney(locked.FrozenBalance),
|
||||
"frozenAfter": frozenAfter,
|
||||
})
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO gateway_wallet_transactions (
|
||||
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
||||
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
||||
)
|
||||
VALUES (
|
||||
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'reserve',
|
||||
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
||||
)`,
|
||||
locked.ID,
|
||||
firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID),
|
||||
gatewayUserID,
|
||||
amount,
|
||||
roundMoney(locked.Balance),
|
||||
roundMoney(locked.Balance),
|
||||
key,
|
||||
taskID,
|
||||
string(metadata),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
reservations = append(reservations, reservation)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reservations, err
|
||||
}
|
||||
|
||||
func (s *Store) ReleaseTaskBillingReservations(ctx context.Context, reservations []WalletBillingReservation, reason string) error {
|
||||
if len(reservations) == 0 {
|
||||
return nil
|
||||
}
|
||||
reason = strings.TrimSpace(reason)
|
||||
if reason == "" {
|
||||
reason = "task_not_settled"
|
||||
}
|
||||
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
||||
for _, reservation := range reservations {
|
||||
if reservation.Amount <= 0 || strings.TrimSpace(reservation.AccountID) == "" {
|
||||
continue
|
||||
}
|
||||
reserveKey := strings.TrimSpace(reservation.IdempotencyKey)
|
||||
if reserveKey == "" {
|
||||
reserveKey = billingReservationIdempotencyKey(reservation.TaskID, reservation.Currency, 1)
|
||||
}
|
||||
releaseKey := billingReservationReleaseIdempotencyKey(reserveKey)
|
||||
locked, err := lockWalletAccount(ctx, tx, reservation.AccountID)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
var alreadyReleased bool
|
||||
if err := tx.QueryRow(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM gateway_wallet_transactions
|
||||
WHERE account_id = $1::uuid
|
||||
AND idempotency_key = $2
|
||||
)`, reservation.AccountID, releaseKey).Scan(&alreadyReleased); err != nil {
|
||||
return err
|
||||
}
|
||||
if alreadyReleased {
|
||||
continue
|
||||
}
|
||||
var storedReservedAmount float64
|
||||
if err := tx.QueryRow(ctx, `
|
||||
SELECT COALESCE((
|
||||
SELECT amount::float8
|
||||
FROM gateway_wallet_transactions
|
||||
WHERE account_id = $1::uuid
|
||||
AND idempotency_key = $2
|
||||
AND transaction_type = 'reserve'
|
||||
LIMIT 1
|
||||
), 0)::float8`, reservation.AccountID, reserveKey).Scan(&storedReservedAmount); err != nil {
|
||||
return err
|
||||
}
|
||||
if storedReservedAmount <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
amount := roundMoney(storedReservedAmount)
|
||||
frozenAfter := roundMoney(locked.FrozenBalance - amount)
|
||||
if frozenAfter < 0 {
|
||||
frozenAfter = 0
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `
|
||||
UPDATE gateway_wallet_accounts
|
||||
SET frozen_balance = $2,
|
||||
updated_at = now()
|
||||
WHERE id = $1::uuid`, locked.ID, frozenAfter); err != nil {
|
||||
return err
|
||||
}
|
||||
metadata, _ := json.Marshal(map[string]any{
|
||||
"taskId": reservation.TaskID,
|
||||
"reason": reason,
|
||||
"reserved": amount,
|
||||
"frozenBefore": roundMoney(locked.FrozenBalance),
|
||||
"frozenAfter": frozenAfter,
|
||||
})
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO gateway_wallet_transactions (
|
||||
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
||||
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
||||
)
|
||||
VALUES (
|
||||
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'credit', 'release',
|
||||
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
||||
)
|
||||
ON CONFLICT (account_id, idempotency_key) WHERE idempotency_key IS NOT NULL DO NOTHING`,
|
||||
locked.ID,
|
||||
locked.GatewayTenantID,
|
||||
locked.GatewayUserID,
|
||||
amount,
|
||||
roundMoney(locked.Balance),
|
||||
roundMoney(locked.Balance),
|
||||
releaseKey,
|
||||
reservation.TaskID,
|
||||
string(metadata),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Store) GetWalletSummary(ctx context.Context, user *auth.User, currency string) (WalletSummary, error) {
|
||||
gatewayUserID := localGatewayUserID(user)
|
||||
if gatewayUserID == "" {
|
||||
@ -465,6 +693,124 @@ WHERE gateway_user_id = $1::uuid
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func lockWalletAccount(ctx context.Context, tx pgx.Tx, accountID string) (GatewayWalletAccount, error) {
|
||||
return scanWalletAccount(tx.QueryRow(ctx, `
|
||||
SELECT id::text, COALESCE(gateway_tenant_id::text, ''), gateway_user_id::text,
|
||||
COALESCE(tenant_id, ''), COALESCE(tenant_key, ''), COALESCE(user_id, ''),
|
||||
currency, balance::float8, frozen_balance::float8, total_recharged::float8,
|
||||
total_spent::float8, status, metadata, created_at, updated_at
|
||||
FROM gateway_wallet_accounts
|
||||
WHERE id = $1::uuid
|
||||
FOR UPDATE`, accountID))
|
||||
}
|
||||
|
||||
func activeWalletReservation(ctx context.Context, tx pgx.Tx, accountID string, taskID string) (string, float64, error) {
|
||||
var key string
|
||||
var amount float64
|
||||
err := tx.QueryRow(ctx, `
|
||||
SELECT COALESCE(t.idempotency_key, ''), t.amount::float8
|
||||
FROM gateway_wallet_transactions t
|
||||
WHERE t.account_id = $1::uuid
|
||||
AND t.reference_type = 'gateway_task'
|
||||
AND t.reference_id = $2
|
||||
AND t.transaction_type = 'reserve'
|
||||
AND COALESCE(t.idempotency_key, '') <> ''
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM gateway_wallet_transactions r
|
||||
WHERE r.account_id = t.account_id
|
||||
AND r.transaction_type = 'release'
|
||||
AND r.idempotency_key = t.idempotency_key || ':release'
|
||||
)
|
||||
ORDER BY t.created_at DESC
|
||||
LIMIT 1`, accountID, taskID).Scan(&key, &amount)
|
||||
if err == pgx.ErrNoRows {
|
||||
return "", 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
return key, roundMoney(amount), nil
|
||||
}
|
||||
|
||||
func nextWalletReservationSequence(ctx context.Context, tx pgx.Tx, accountID string, taskID string) (int, error) {
|
||||
var count int
|
||||
if err := tx.QueryRow(ctx, `
|
||||
SELECT COUNT(*)::int
|
||||
FROM gateway_wallet_transactions
|
||||
WHERE account_id = $1::uuid
|
||||
AND reference_type = 'gateway_task'
|
||||
AND reference_id = $2
|
||||
AND transaction_type = 'reserve'`, accountID, taskID).Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count + 1, nil
|
||||
}
|
||||
|
||||
func walletBillingAmounts(billings []any) map[string]float64 {
|
||||
amounts := map[string]float64{}
|
||||
for _, raw := range billings {
|
||||
line, _ := raw.(map[string]any)
|
||||
if line == nil {
|
||||
continue
|
||||
}
|
||||
amount := roundMoney(walletFloat(line["amount"]))
|
||||
if amount <= 0 {
|
||||
continue
|
||||
}
|
||||
currency := normalizeWalletCurrency(walletString(line["currency"]))
|
||||
amounts[currency] = roundMoney(amounts[currency] + amount)
|
||||
}
|
||||
return amounts
|
||||
}
|
||||
|
||||
func taskGatewayUserID(task GatewayTask, user *auth.User) string {
|
||||
return firstNonEmpty(strings.TrimSpace(task.GatewayUserID), localGatewayUserID(user))
|
||||
}
|
||||
|
||||
func billingReservationIdempotencyKey(taskID string, currency string, sequence int) string {
|
||||
if sequence <= 0 {
|
||||
sequence = 1
|
||||
}
|
||||
return "task:" + strings.TrimSpace(taskID) + ":wallet-reservation:" + normalizeWalletCurrency(currency) + ":" + strconv.Itoa(sequence)
|
||||
}
|
||||
|
||||
func billingReservationReleaseIdempotencyKey(reservationKey string) string {
|
||||
return strings.TrimSpace(reservationKey) + ":release"
|
||||
}
|
||||
|
||||
func walletString(value any) string {
|
||||
if text, ok := value.(string); ok {
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func walletFloat(value any) float64 {
|
||||
switch typed := value.(type) {
|
||||
case float64:
|
||||
return typed
|
||||
case float32:
|
||||
return float64(typed)
|
||||
case int:
|
||||
return float64(typed)
|
||||
case int64:
|
||||
return float64(typed)
|
||||
case json.Number:
|
||||
next, _ := typed.Float64()
|
||||
return next
|
||||
case string:
|
||||
next := strings.TrimSpace(typed)
|
||||
if next == "" {
|
||||
return 0
|
||||
}
|
||||
parsed, _ := strconv.ParseFloat(next, 64)
|
||||
return parsed
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeWalletCurrency(currency string) string {
|
||||
currency = strings.TrimSpace(currency)
|
||||
if currency == "" {
|
||||
|
||||
171
apps/api/internal/store/wallet_reservation_test.go
Normal file
171
apps/api/internal/store/wallet_reservation_test.go
Normal file
@ -0,0 +1,171 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||
)
|
||||
|
||||
func TestReserveTaskBillingSerializesConcurrentWalletReservations(t *testing.T) {
|
||||
databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL"))
|
||||
if databaseURL == "" {
|
||||
t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet reservation integration test")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
db, err := Connect(ctx, databaseURL)
|
||||
if err != nil {
|
||||
t.Fatalf("connect store: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
tenantID, userID := seedWalletReservationUser(t, ctx, db)
|
||||
if _, err := db.SetUserWalletBalance(ctx, WalletBalanceAdjustmentInput{
|
||||
GatewayUserID: userID,
|
||||
Currency: "resource",
|
||||
Balance: 10,
|
||||
Reason: "seed wallet reservation test",
|
||||
}); err != nil {
|
||||
t.Fatalf("seed wallet balance: %v", err)
|
||||
}
|
||||
|
||||
firstTaskID := newWalletReservationTestUUID(t, ctx, db)
|
||||
secondTaskID := newWalletReservationTestUUID(t, ctx, db)
|
||||
billings := []any{map[string]any{"currency": "resource", "amount": float64(10)}}
|
||||
user := &auth.User{GatewayUserID: userID, GatewayTenantID: tenantID}
|
||||
tasks := []GatewayTask{
|
||||
{ID: firstTaskID, GatewayUserID: userID, GatewayTenantID: tenantID, Kind: "images.generations", Model: "mock-image"},
|
||||
{ID: secondTaskID, GatewayUserID: userID, GatewayTenantID: tenantID, Kind: "videos.generations", Model: "mock-video"},
|
||||
}
|
||||
|
||||
type reserveResult struct {
|
||||
reservations []WalletBillingReservation
|
||||
err error
|
||||
}
|
||||
results := make(chan reserveResult, len(tasks))
|
||||
var wg sync.WaitGroup
|
||||
for _, task := range tasks {
|
||||
task := task
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
reservations, err := db.ReserveTaskBilling(ctx, task, user, billings)
|
||||
results <- reserveResult{reservations: reservations, err: err}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
var successReservations []WalletBillingReservation
|
||||
successCount := 0
|
||||
insufficientCount := 0
|
||||
for result := range results {
|
||||
if result.err == nil {
|
||||
successCount++
|
||||
successReservations = result.reservations
|
||||
continue
|
||||
}
|
||||
if errors.Is(result.err, ErrInsufficientWalletBalance) {
|
||||
insufficientCount++
|
||||
continue
|
||||
}
|
||||
t.Fatalf("unexpected reservation error: %v", result.err)
|
||||
}
|
||||
if successCount != 1 || insufficientCount != 1 {
|
||||
t.Fatalf("expected one successful reservation and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount)
|
||||
}
|
||||
if len(successReservations) != 1 || !walletFloatNear(successReservations[0].Amount, 10) {
|
||||
t.Fatalf("unexpected successful reservations: %+v", successReservations)
|
||||
}
|
||||
|
||||
balance, frozen, spent := readWalletReservationAccount(t, ctx, db, userID)
|
||||
if !walletFloatNear(balance, 10) || !walletFloatNear(frozen, 10) || !walletFloatNear(spent, 0) {
|
||||
t.Fatalf("reservation should freeze balance without spending it, balance=%f frozen=%f spent=%f", balance, frozen, spent)
|
||||
}
|
||||
|
||||
settleTask := GatewayTask{
|
||||
ID: successReservations[0].TaskID,
|
||||
GatewayUserID: userID,
|
||||
GatewayTenantID: tenantID,
|
||||
Kind: "images.generations",
|
||||
Model: "mock-image",
|
||||
ResolvedModel: "mock-image",
|
||||
Billings: billings,
|
||||
BillingSummary: map[string]any{"currency": "resource", "totalAmount": float64(10)},
|
||||
FinalChargeAmount: 10,
|
||||
}
|
||||
if err := db.SettleTaskBilling(ctx, settleTask); err != nil {
|
||||
t.Fatalf("settle reserved task billing: %v", err)
|
||||
}
|
||||
if err := db.SettleTaskBilling(ctx, settleTask); err != nil {
|
||||
t.Fatalf("settle reserved task billing should be idempotent: %v", err)
|
||||
}
|
||||
balance, frozen, spent = readWalletReservationAccount(t, ctx, db, userID)
|
||||
if !walletFloatNear(balance, 0) || !walletFloatNear(frozen, 0) || !walletFloatNear(spent, 10) {
|
||||
t.Fatalf("settlement should release reservation and debit once, balance=%f frozen=%f spent=%f", balance, frozen, spent)
|
||||
}
|
||||
}
|
||||
|
||||
func seedWalletReservationUser(t *testing.T, ctx context.Context, db *Store) (string, string) {
|
||||
t.Helper()
|
||||
suffix := strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
var tenantID string
|
||||
if err := db.pool.QueryRow(ctx, `
|
||||
INSERT INTO gateway_tenants (tenant_key, name)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id::text`, "wallet-reservation-"+suffix, "Wallet Reservation Test "+suffix).Scan(&tenantID); err != nil {
|
||||
t.Fatalf("insert test tenant: %v", err)
|
||||
}
|
||||
var userID string
|
||||
if err := db.pool.QueryRow(ctx, `
|
||||
INSERT INTO gateway_users (user_key, username, gateway_tenant_id, tenant_key, roles)
|
||||
VALUES ($1, $2, $3::uuid, $4, '["basic"]'::jsonb)
|
||||
RETURNING id::text`, "wallet-reservation-user-"+suffix, "wallet_reservation_"+suffix, tenantID, "wallet-reservation-"+suffix).Scan(&userID); err != nil {
|
||||
t.Fatalf("insert test user: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cleanupCtx := context.Background()
|
||||
_, _ = db.pool.Exec(cleanupCtx, `DELETE FROM gateway_users WHERE id = $1::uuid`, userID)
|
||||
_, _ = db.pool.Exec(cleanupCtx, `DELETE FROM gateway_tenants WHERE id = $1::uuid`, tenantID)
|
||||
})
|
||||
return tenantID, userID
|
||||
}
|
||||
|
||||
func newWalletReservationTestUUID(t *testing.T, ctx context.Context, db *Store) string {
|
||||
t.Helper()
|
||||
var id string
|
||||
if err := db.pool.QueryRow(ctx, `SELECT gen_random_uuid()::text`).Scan(&id); err != nil {
|
||||
t.Fatalf("generate uuid: %v", err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func readWalletReservationAccount(t *testing.T, ctx context.Context, db *Store, userID string) (float64, float64, float64) {
|
||||
t.Helper()
|
||||
var balance float64
|
||||
var frozen float64
|
||||
var spent float64
|
||||
if err := db.pool.QueryRow(ctx, `
|
||||
SELECT balance::float8, frozen_balance::float8, total_spent::float8
|
||||
FROM gateway_wallet_accounts
|
||||
WHERE gateway_user_id = $1::uuid
|
||||
AND currency = 'resource'`, userID).Scan(&balance, &frozen, &spent); err != nil {
|
||||
t.Fatalf("read wallet account: %v", err)
|
||||
}
|
||||
return balance, frozen, spent
|
||||
}
|
||||
|
||||
func walletFloatNear(a float64, b float64) bool {
|
||||
delta := a - b
|
||||
if delta < 0 {
|
||||
delta = -delta
|
||||
}
|
||||
return delta < 0.000001
|
||||
}
|
||||
@ -684,11 +684,13 @@ export interface VideoGenerationContent {
|
||||
};
|
||||
video_url?: {
|
||||
url: string;
|
||||
mime_type?: string;
|
||||
refer_type?: 'feature' | 'base';
|
||||
keep_original_sound?: 'yes' | 'no';
|
||||
};
|
||||
audio_url?: {
|
||||
url: string;
|
||||
mime_type?: string;
|
||||
};
|
||||
role?: VideoGenerationContentRole;
|
||||
shot_index?: number;
|
||||
|
||||
@ -32,8 +32,8 @@ export interface PlaygroundUpload {
|
||||
export type OpenAIChatContentPart =
|
||||
| { type: 'text'; text: string }
|
||||
| { type: 'image_url'; image_url: { url: string } }
|
||||
| { type: 'video_url'; video_url: { url: string } }
|
||||
| { type: 'audio_url'; audio_url: { url: string } }
|
||||
| { type: 'video_url'; video_url: { mime_type?: string; url: string } }
|
||||
| { type: 'audio_url'; audio_url: { mime_type?: string; url: string } }
|
||||
| { type: 'file_url'; file_url: { filename: string; url: string } };
|
||||
|
||||
export const mediaUploadAccept = 'image/*,video/*,audio/*';
|
||||
@ -518,11 +518,17 @@ export function openAIContentFromPromptAndUploads(prompt: string, uploads: Playg
|
||||
function openAIContentPartFromUpload(item: PlaygroundUpload): OpenAIChatContentPart | undefined {
|
||||
if (!item.url) return undefined;
|
||||
if (item.kind === 'image') return { type: 'image_url', image_url: { url: item.url } };
|
||||
if (item.kind === 'video') return { type: 'video_url', video_url: { url: item.url } };
|
||||
if (item.kind === 'audio') return { type: 'audio_url', audio_url: { url: item.url } };
|
||||
if (item.kind === 'video') return { type: 'video_url', video_url: mediaURLPayload(item) };
|
||||
if (item.kind === 'audio') return { type: 'audio_url', audio_url: mediaURLPayload(item) };
|
||||
return { type: 'file_url', file_url: { filename: item.name, url: item.url } };
|
||||
}
|
||||
|
||||
function mediaURLPayload(item: PlaygroundUpload) {
|
||||
const payload: { mime_type?: string; url: string } = { url: item.url };
|
||||
if (item.contentType) payload.mime_type = item.contentType;
|
||||
return payload;
|
||||
}
|
||||
|
||||
export function mediaUploadRequestPayload(uploads: PlaygroundUpload[], mode: Exclude<PlaygroundMode, 'chat'>) {
|
||||
const images = uploads.filter((item) => item.kind === 'image').map((item) => item.url);
|
||||
const payload: Record<string, string | string[]> = {};
|
||||
@ -570,10 +576,10 @@ function videoGenerationContentFromUpload(item: PlaygroundUpload): VideoGenerati
|
||||
return { type: 'image_url', role: 'reference_image', image_url: { url: item.url } };
|
||||
}
|
||||
if (item.kind === 'video') {
|
||||
return { type: 'video_url', role: 'reference_video', video_url: { url: item.url, refer_type: 'feature' } };
|
||||
return { type: 'video_url', role: 'reference_video', video_url: { ...mediaURLPayload(item), refer_type: 'feature' } };
|
||||
}
|
||||
if (item.kind === 'audio') {
|
||||
return { type: 'audio_url', role: 'reference_audio', audio_url: { url: item.url } };
|
||||
return { type: 'audio_url', role: 'reference_audio', audio_url: mediaURLPayload(item) };
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user