feat(api): 添加多媒体内容支持并优化钱包计费系统
- 在 API 接口定义中为 video_url 和 audio_url 类型添加 mime_type 字段 - 实现 Google Gemini 客户端对视频和音频内容的支持,包括媒体类型检测和数据传输 - 添加 Gemini 客户端测试用例验证多媒体内容转换功能 - 重构 Playground 页面的媒体上传逻辑以支持 MIME 类型传递 - 实现钱包计费预留机制,确保任务执行前余额充足 - 添加钱包冻结余额管理,防止并发操作导致的超扣问题 - 实现计费预留释放逻辑,处理任务失败或取消情况下的资金返还 - 优化数据库事务处理,确保计费操作的原子性和一致性 - 添加数据库集成测试验证迁移脚本执行流程 - 统一 Google Gemini 相关模型提供商标识符映射
This commit is contained in:
parent
af9b281d34
commit
8ad5b06c18
@ -569,6 +569,70 @@ func TestGeminiClientChatContract(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGeminiClientChatConvertsMediaContentParts(t *testing.T) {
|
||||||
|
var captured map[string]any
|
||||||
|
var gotPath string
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotPath = r.URL.Path
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&captured); err != nil {
|
||||||
|
t.Fatalf("decode request: %v", err)
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"candidates": []any{map[string]any{
|
||||||
|
"content": map[string]any{"parts": []any{map[string]any{"text": "video ok"}}},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
_, err := (GeminiClient{HTTPClient: server.Client()}).Run(context.Background(), Request{
|
||||||
|
Kind: "chat.completions",
|
||||||
|
Model: "gemini:gemini-2.5-flash",
|
||||||
|
Body: map[string]any{
|
||||||
|
"model": "gemini:gemini-2.5-flash",
|
||||||
|
"messages": []any{map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []any{
|
||||||
|
map[string]any{"type": "text", "text": "analyze this video"},
|
||||||
|
map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://cdn.example.com/input.mov", "mime_type": "video/quicktime"}},
|
||||||
|
map[string]any{"type": "audio_url", "audio_url": map[string]any{"url": "data:audio/wav;base64,UklGRg=="}},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
Candidate: store.RuntimeModelCandidate{
|
||||||
|
BaseURL: server.URL + "/v1beta/openai",
|
||||||
|
ProviderModelName: "gemini-2.5-flash",
|
||||||
|
ModelType: "chat",
|
||||||
|
Credentials: map[string]any{"apiKey": "gemini-key"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("run gemini client: %v", err)
|
||||||
|
}
|
||||||
|
if gotPath != "/v1beta/models/gemini-2.5-flash:generateContent" {
|
||||||
|
t.Fatalf("Gemini OpenAI-compatible base URL should normalize to native endpoint, got %s", gotPath)
|
||||||
|
}
|
||||||
|
contents, _ := captured["contents"].([]any)
|
||||||
|
if len(contents) != 1 {
|
||||||
|
t.Fatalf("unexpected Gemini contents: %+v", captured)
|
||||||
|
}
|
||||||
|
turn, _ := contents[0].(map[string]any)
|
||||||
|
parts, _ := turn["parts"].([]any)
|
||||||
|
if len(parts) != 3 {
|
||||||
|
t.Fatalf("expected text, video, and audio parts, got %+v", turn)
|
||||||
|
}
|
||||||
|
video, _ := parts[1].(map[string]any)
|
||||||
|
videoFile, _ := video["fileData"].(map[string]any)
|
||||||
|
if videoFile["fileUri"] != "https://cdn.example.com/input.mov" || videoFile["mimeType"] != "video/quicktime" {
|
||||||
|
t.Fatalf("video_url should become Gemini fileData, got %+v", video)
|
||||||
|
}
|
||||||
|
audio, _ := parts[2].(map[string]any)
|
||||||
|
audioInline, _ := audio["inlineData"].(map[string]any)
|
||||||
|
if audioInline["mimeType"] != "audio/wav" || audioInline["data"] != "UklGRg==" {
|
||||||
|
t.Fatalf("audio data URL should become Gemini inlineData, got %+v", audio)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGeminiClientChatRestoresToolContext(t *testing.T) {
|
func TestGeminiClientChatRestoresToolContext(t *testing.T) {
|
||||||
var captured map[string]any
|
var captured map[string]any
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|||||||
@ -5,8 +5,10 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -58,6 +60,7 @@ func geminiURL(baseURL string, model string, apiKey string) string {
|
|||||||
if base == "" {
|
if base == "" {
|
||||||
base = "https://generativelanguage.googleapis.com"
|
base = "https://generativelanguage.googleapis.com"
|
||||||
}
|
}
|
||||||
|
base = strings.TrimSuffix(base, "/openai")
|
||||||
if strings.HasSuffix(base, "/v1beta") {
|
if strings.HasSuffix(base, "/v1beta") {
|
||||||
base = strings.TrimSuffix(base, "/v1beta")
|
base = strings.TrimSuffix(base, "/v1beta")
|
||||||
}
|
}
|
||||||
@ -121,7 +124,7 @@ func geminiContentsFromMessages(body map[string]any) []any {
|
|||||||
})
|
})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
parts := geminiTextParts(message["content"])
|
parts := geminiContentParts(message["content"])
|
||||||
if role == "assistant" {
|
if role == "assistant" {
|
||||||
for _, rawToolCall := range toolCallsSlice(message["tool_calls"]) {
|
for _, rawToolCall := range toolCallsSlice(message["tool_calls"]) {
|
||||||
toolCall, _ := rawToolCall.(map[string]any)
|
toolCall, _ := rawToolCall.(map[string]any)
|
||||||
@ -157,7 +160,7 @@ func geminiRole(role string) string {
|
|||||||
return "user"
|
return "user"
|
||||||
}
|
}
|
||||||
|
|
||||||
func geminiTextParts(content any) []any {
|
func geminiContentParts(content any) []any {
|
||||||
parts := make([]any, 0)
|
parts := make([]any, 0)
|
||||||
switch typed := content.(type) {
|
switch typed := content.(type) {
|
||||||
case string:
|
case string:
|
||||||
@ -167,14 +170,146 @@ func geminiTextParts(content any) []any {
|
|||||||
case []any:
|
case []any:
|
||||||
for _, rawPart := range typed {
|
for _, rawPart := range typed {
|
||||||
part, _ := rawPart.(map[string]any)
|
part, _ := rawPart.(map[string]any)
|
||||||
if text := stringFromAny(firstPresent(part["text"], part["content"])); strings.TrimSpace(text) != "" {
|
if len(part) == 0 {
|
||||||
parts = append(parts, map[string]any{"text": text})
|
continue
|
||||||
|
}
|
||||||
|
switch stringFromAny(part["type"]) {
|
||||||
|
case "text":
|
||||||
|
if text := strings.TrimSpace(stringFromAny(firstPresent(part["text"], part["content"]))); text != "" {
|
||||||
|
parts = append(parts, map[string]any{"text": text})
|
||||||
|
}
|
||||||
|
case "image_url":
|
||||||
|
if media := geminiMediaPart(part, "image_url", "image"); media != nil {
|
||||||
|
parts = append(parts, media)
|
||||||
|
}
|
||||||
|
case "video_url":
|
||||||
|
if media := geminiMediaPart(part, "video_url", "video"); media != nil {
|
||||||
|
parts = append(parts, media)
|
||||||
|
}
|
||||||
|
case "audio_url":
|
||||||
|
if media := geminiMediaPart(part, "audio_url", "audio"); media != nil {
|
||||||
|
parts = append(parts, media)
|
||||||
|
}
|
||||||
|
case "input_audio":
|
||||||
|
if media := geminiInputAudioPart(part); media != nil {
|
||||||
|
parts = append(parts, media)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if text := strings.TrimSpace(stringFromAny(firstPresent(part["text"], part["content"]))); text != "" {
|
||||||
|
parts = append(parts, map[string]any{"text": text})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return parts
|
return parts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func geminiMediaPart(part map[string]any, key string, mediaType string) map[string]any {
|
||||||
|
nested := mapFromAny(part[key])
|
||||||
|
uri := firstNonEmptyString(nested["url"], part["url"], part[key])
|
||||||
|
if uri == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mimeType := firstNonEmptyString(nested["mime_type"], nested["mimeType"], part["mime_type"], part["mimeType"])
|
||||||
|
return geminiMediaURLPart(uri, mimeType, mediaType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiInputAudioPart(part map[string]any) map[string]any {
|
||||||
|
audio := mapFromAny(part["input_audio"])
|
||||||
|
uri := firstNonEmptyString(audio["data"], audio["url"])
|
||||||
|
if uri == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mimeType := firstNonEmptyString(audio["mime_type"], audio["mimeType"])
|
||||||
|
if mimeType == "" {
|
||||||
|
format := strings.ToLower(strings.TrimPrefix(stringFromAny(audio["format"]), "."))
|
||||||
|
if strings.Contains(format, "/") {
|
||||||
|
mimeType = format
|
||||||
|
} else if format == "mp3" {
|
||||||
|
mimeType = "audio/mpeg"
|
||||||
|
} else if format != "" {
|
||||||
|
mimeType = "audio/" + format
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return geminiMediaURLPart(uri, mimeType, "audio")
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiMediaURLPart(uri string, explicitMimeType string, mediaType string) map[string]any {
|
||||||
|
if parsed := geminiDataURL(uri); parsed != nil {
|
||||||
|
return map[string]any{"inlineData": map[string]any{
|
||||||
|
"mimeType": geminiMediaMime(firstNonEmptyString(explicitMimeType, parsed.mimeType), mediaType),
|
||||||
|
"data": parsed.data,
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
return map[string]any{"fileData": map[string]any{
|
||||||
|
"fileUri": uri,
|
||||||
|
"mimeType": geminiMediaMime(firstNonEmptyString(explicitMimeType, mimeFromURI(uri)), mediaType),
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
type geminiParsedDataURL struct {
|
||||||
|
mimeType string
|
||||||
|
data string
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiDataURL(value string) *geminiParsedDataURL {
|
||||||
|
if !strings.HasPrefix(value, "data:") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
prefix, data, ok := strings.Cut(value, ",")
|
||||||
|
if !ok || !strings.Contains(prefix, ";base64") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
mimeType := strings.TrimPrefix(strings.Split(prefix, ";")[0], "data:")
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = "application/octet-stream"
|
||||||
|
}
|
||||||
|
return &geminiParsedDataURL{mimeType: mimeType, data: data}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mimeFromURI(value string) string {
|
||||||
|
pathValue := value
|
||||||
|
if parsed, err := url.Parse(value); err == nil && parsed.Path != "" {
|
||||||
|
pathValue = parsed.Path
|
||||||
|
}
|
||||||
|
extension := strings.ToLower(path.Ext(pathValue))
|
||||||
|
if extension == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return mime.TypeByExtension(extension)
|
||||||
|
}
|
||||||
|
|
||||||
|
func geminiMediaMime(mimeType string, mediaType string) string {
|
||||||
|
normalized := strings.ToLower(strings.TrimSpace(strings.Split(mimeType, ";")[0]))
|
||||||
|
switch mediaType {
|
||||||
|
case "image":
|
||||||
|
if strings.HasPrefix(normalized, "image/") && normalized != "image/svg+xml" {
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
return "image/png"
|
||||||
|
case "video":
|
||||||
|
switch normalized {
|
||||||
|
case "video/x-msvideo":
|
||||||
|
return "video/avi"
|
||||||
|
case "video/quicktime", "video/mpeg", "video/mp4", "video/avi", "video/x-flv", "video/mpg", "video/webm", "video/wmv", "video/3gpp":
|
||||||
|
return normalized
|
||||||
|
default:
|
||||||
|
return "video/mp4"
|
||||||
|
}
|
||||||
|
case "audio":
|
||||||
|
switch normalized {
|
||||||
|
case "audio/x-wav", "audio/wave":
|
||||||
|
return "audio/wav"
|
||||||
|
case "audio/mpeg", "audio/mp3", "audio/wav", "audio/aiff", "audio/aac", "audio/ogg", "audio/flac", "audio/mp4", "audio/webm":
|
||||||
|
return normalized
|
||||||
|
default:
|
||||||
|
return "audio/mpeg"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return "application/octet-stream"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func toolCallsSlice(value any) []any {
|
func toolCallsSlice(value any) []any {
|
||||||
switch typed := value.(type) {
|
switch typed := value.(type) {
|
||||||
case []any:
|
case []any:
|
||||||
|
|||||||
@ -129,6 +129,13 @@ func TestCoreLocalFlow(t *testing.T) {
|
|||||||
if _, err := testPool.Exec(ctx, `UPDATE gateway_users SET roles = '["admin"]'::jsonb WHERE username = $1`, username); err != nil {
|
if _, err := testPool.Exec(ctx, `UPDATE gateway_users SET roles = '["admin"]'::jsonb WHERE username = $1`, username); err != nil {
|
||||||
t.Fatalf("promote smoke user: %v", err)
|
t.Fatalf("promote smoke user: %v", err)
|
||||||
}
|
}
|
||||||
|
doJSON(t, server.URL, http.MethodPost, "/api/v1/auth/login", "", map[string]any{
|
||||||
|
"account": username,
|
||||||
|
"password": password,
|
||||||
|
}, http.StatusOK, &loginResponse)
|
||||||
|
if loginResponse.AccessToken == "" {
|
||||||
|
t.Fatal("admin login did not return access token")
|
||||||
|
}
|
||||||
var smokeGatewayUserID string
|
var smokeGatewayUserID string
|
||||||
if err := testPool.QueryRow(ctx, `SELECT id::text FROM gateway_users WHERE username = $1`, username).Scan(&smokeGatewayUserID); err != nil {
|
if err := testPool.QueryRow(ctx, `SELECT id::text FROM gateway_users WHERE username = $1`, username).Scan(&smokeGatewayUserID); err != nil {
|
||||||
t.Fatalf("read smoke gateway user id: %v", err)
|
t.Fatalf("read smoke gateway user id: %v", err)
|
||||||
@ -1402,14 +1409,41 @@ func applyMigration(t *testing.T, ctx context.Context, databaseURL string) {
|
|||||||
t.Fatalf("connect migration db: %v", err)
|
t.Fatalf("connect migration db: %v", err)
|
||||||
}
|
}
|
||||||
defer pool.Close()
|
defer pool.Close()
|
||||||
|
if _, err := pool.Exec(ctx, `
|
||||||
|
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||||
|
version text PRIMARY KEY,
|
||||||
|
applied_at timestamptz NOT NULL DEFAULT now()
|
||||||
|
);`); err != nil {
|
||||||
|
t.Fatalf("ensure schema migrations: %v", err)
|
||||||
|
}
|
||||||
for _, migrationPath := range migrationFiles {
|
for _, migrationPath := range migrationFiles {
|
||||||
|
version := strings.TrimSuffix(filepath.Base(migrationPath), filepath.Ext(migrationPath))
|
||||||
|
var exists bool
|
||||||
|
if err := pool.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM schema_migrations WHERE version = $1)", version).Scan(&exists); err != nil {
|
||||||
|
t.Fatalf("check migration %s: %v", filepath.Base(migrationPath), err)
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
migration, err := os.ReadFile(migrationPath)
|
migration, err := os.ReadFile(migrationPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("read migration %s: %v", filepath.Base(migrationPath), err)
|
t.Fatalf("read migration %s: %v", filepath.Base(migrationPath), err)
|
||||||
}
|
}
|
||||||
if _, err := pool.Exec(ctx, string(migration)); err != nil {
|
tx, err := pool.Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("begin migration %s: %v", filepath.Base(migrationPath), err)
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(ctx, string(migration)); err != nil {
|
||||||
|
_ = tx.Rollback(ctx)
|
||||||
t.Fatalf("apply migration %s: %v", filepath.Base(migrationPath), err)
|
t.Fatalf("apply migration %s: %v", filepath.Base(migrationPath), err)
|
||||||
}
|
}
|
||||||
|
if _, err := tx.Exec(ctx, "INSERT INTO schema_migrations(version) VALUES($1)", version); err != nil {
|
||||||
|
_ = tx.Rollback(ctx)
|
||||||
|
t.Fatalf("record migration %s: %v", filepath.Base(migrationPath), err)
|
||||||
|
}
|
||||||
|
if err := tx.Commit(ctx); err != nil {
|
||||||
|
t.Fatalf("commit migration %s: %v", filepath.Base(migrationPath), err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -160,6 +160,13 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
|||||||
normalizedModelType := modelType
|
normalizedModelType := modelType
|
||||||
attemptNo := task.AttemptCount
|
attemptNo := task.AttemptCount
|
||||||
var firstPreprocessing parameterPreprocessingLog
|
var firstPreprocessing parameterPreprocessingLog
|
||||||
|
var walletReservations []store.WalletBillingReservation
|
||||||
|
walletReservationFinalized := false
|
||||||
|
defer func() {
|
||||||
|
if !walletReservationFinalized && len(walletReservations) > 0 {
|
||||||
|
_ = s.store.ReleaseTaskBillingReservations(context.WithoutCancel(ctx), walletReservations, "task_not_settled")
|
||||||
|
}
|
||||||
|
}()
|
||||||
if len(candidates) > 0 {
|
if len(candidates) > 0 {
|
||||||
preprocessing := s.preprocessRequestWithScripts(ctx, task.Kind, body, candidates[0])
|
preprocessing := s.preprocessRequestWithScripts(ctx, task.Kind, body, candidates[0])
|
||||||
firstCandidateBody = preprocessing.Body
|
firstCandidateBody = preprocessing.Body
|
||||||
@ -191,15 +198,17 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
|||||||
return Result{}, err
|
return Result{}, err
|
||||||
}
|
}
|
||||||
estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, firstCandidateBody, candidates[0])
|
estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, firstCandidateBody, candidates[0])
|
||||||
if err := s.ensureWalletBalance(ctx, user, estimatedBillings); err != nil {
|
var reserveErr error
|
||||||
if errors.Is(err, store.ErrInsufficientWalletBalance) {
|
walletReservations, reserveErr = s.store.ReserveTaskBilling(ctx, task, user, estimatedBillings)
|
||||||
|
if reserveErr != nil {
|
||||||
|
if errors.Is(reserveErr, store.ErrInsufficientWalletBalance) {
|
||||||
attemptNo = s.recordFailedAttempt(ctx, failedAttemptRecord{
|
attemptNo = s.recordFailedAttempt(ctx, failedAttemptRecord{
|
||||||
Task: task,
|
Task: task,
|
||||||
Body: firstCandidateBody,
|
Body: firstCandidateBody,
|
||||||
Candidate: &candidates[0],
|
Candidate: &candidates[0],
|
||||||
AttemptNo: attemptNo + 1,
|
AttemptNo: attemptNo + 1,
|
||||||
Code: "insufficient_balance",
|
Code: "insufficient_balance",
|
||||||
Cause: err,
|
Cause: reserveErr,
|
||||||
Simulated: task.RunMode == "simulation",
|
Simulated: task.RunMode == "simulation",
|
||||||
Scope: "wallet_balance",
|
Scope: "wallet_balance",
|
||||||
Reason: "wallet_balance_check_failed",
|
Reason: "wallet_balance_check_failed",
|
||||||
@ -207,13 +216,13 @@ func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *aut
|
|||||||
Preprocessing: &firstPreprocessing,
|
Preprocessing: &firstPreprocessing,
|
||||||
ModelType: normalizedModelType,
|
ModelType: normalizedModelType,
|
||||||
})
|
})
|
||||||
failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", err.Error(), task.RunMode == "simulation", err, parameterPreprocessingMetrics(firstPreprocessing))
|
failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", reserveErr.Error(), task.RunMode == "simulation", reserveErr, parameterPreprocessingMetrics(firstPreprocessing))
|
||||||
if finishErr != nil {
|
if finishErr != nil {
|
||||||
return Result{}, finishErr
|
return Result{}, finishErr
|
||||||
}
|
}
|
||||||
return Result{Task: failed, Output: failed.Result}, err
|
return Result{Task: failed, Output: failed.Result}, reserveErr
|
||||||
}
|
}
|
||||||
return Result{}, err
|
return Result{}, reserveErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": normalizedModelType}, task.RunMode == "simulation"); err != nil {
|
if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": normalizedModelType}, task.RunMode == "simulation"); err != nil {
|
||||||
@ -286,9 +295,18 @@ candidatesLoop:
|
|||||||
if finishErr != nil {
|
if finishErr != nil {
|
||||||
return Result{}, finishErr
|
return Result{}, finishErr
|
||||||
}
|
}
|
||||||
if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil {
|
if finished.FinalChargeAmount > 0 {
|
||||||
return Result{}, settleErr
|
walletReservationFinalized = true
|
||||||
|
if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil {
|
||||||
|
return Result{}, settleErr
|
||||||
|
}
|
||||||
|
} else if len(walletReservations) > 0 {
|
||||||
|
if releaseErr := s.store.ReleaseTaskBillingReservations(ctx, walletReservations, "task_billing_zero"); releaseErr != nil {
|
||||||
|
return Result{}, releaseErr
|
||||||
|
}
|
||||||
|
walletReservationFinalized = true
|
||||||
}
|
}
|
||||||
|
walletReservationFinalized = true
|
||||||
if finished.FinalChargeAmount > 0 {
|
if finished.FinalChargeAmount > 0 {
|
||||||
if err := s.emit(ctx, task.ID, "task.billing.settled", "succeeded", "billing", 0.98, "task billing settled", map[string]any{
|
if err := s.emit(ctx, task.ID, "task.billing.settled", "succeeded", "billing", 0.98, "task billing settled", map[string]any{
|
||||||
"amount": finished.FinalChargeAmount,
|
"amount": finished.FinalChargeAmount,
|
||||||
@ -695,6 +713,11 @@ func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated boo
|
|||||||
if key == "" {
|
if key == "" {
|
||||||
key = strings.ToLower(strings.TrimSpace(candidate.Provider))
|
key = strings.ToLower(strings.TrimSpace(candidate.Provider))
|
||||||
}
|
}
|
||||||
|
provider := strings.ToLower(strings.TrimSpace(candidate.Provider))
|
||||||
|
baseURL := strings.ToLower(strings.TrimSpace(candidate.BaseURL))
|
||||||
|
if key == "google-gemini" || provider == "gemini" || provider == "google-gemini" || provider == "gemini-openai" || strings.Contains(baseURL, "generativelanguage.googleapis.com") {
|
||||||
|
key = "gemini"
|
||||||
|
}
|
||||||
if client, ok := s.clients[key]; ok {
|
if client, ok := s.clients[key]; ok {
|
||||||
return client
|
return client
|
||||||
}
|
}
|
||||||
|
|||||||
34
apps/api/internal/runner/service_test.go
Normal file
34
apps/api/internal/runner/service_test.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
type namedClient string
|
||||||
|
|
||||||
|
func (namedClient) Run(context.Context, clients.Request) (clients.Response, error) {
|
||||||
|
return clients.Response{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientForMapsGoogleGeminiSpecToGeminiClient(t *testing.T) {
|
||||||
|
service := &Service{clients: map[string]clients.Client{
|
||||||
|
"gemini": namedClient("gemini"),
|
||||||
|
"openai": namedClient("openai"),
|
||||||
|
}}
|
||||||
|
|
||||||
|
candidates := []store.RuntimeModelCandidate{
|
||||||
|
{SpecType: "google-gemini"},
|
||||||
|
{SpecType: "openai", Provider: "gemini-openai"},
|
||||||
|
{SpecType: "openai", BaseURL: "https://generativelanguage.googleapis.com/v1beta/openai"},
|
||||||
|
}
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
client := service.clientFor(candidate, false)
|
||||||
|
if client != namedClient("gemini") {
|
||||||
|
t.Fatalf("Gemini candidate should use gemini client, candidate=%+v got %T %[2]v", candidate, client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,38 +0,0 @@
|
|||||||
package runner
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (s *Service) ensureWalletBalance(ctx context.Context, user *auth.User, billings []any) error {
|
|
||||||
amounts := map[string]float64{}
|
|
||||||
for _, raw := range billings {
|
|
||||||
line, _ := raw.(map[string]any)
|
|
||||||
if line == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
currency := strings.TrimSpace(stringFromAny(line["currency"]))
|
|
||||||
if currency == "" {
|
|
||||||
currency = "resource"
|
|
||||||
}
|
|
||||||
amounts[currency] = roundPrice(amounts[currency] + floatFromAny(line["amount"]))
|
|
||||||
}
|
|
||||||
for currency, amount := range amounts {
|
|
||||||
if amount <= 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
availability, err := s.store.WalletAvailability(ctx, user, currency, amount)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !availability.Enough {
|
|
||||||
return fmt.Errorf("%w: required %.6f %s, available %.6f", store.ErrInsufficientWalletBalance, amount, currency, availability.AvailableAmount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
202
apps/api/internal/runner/wallet_execute_test.go
Normal file
202
apps/api/internal/runner/wallet_execute_test.go
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
type walletExecuteMockClient struct {
|
||||||
|
calls atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (client *walletExecuteMockClient) Run(context.Context, clients.Request) (clients.Response, error) {
|
||||||
|
client.calls.Add(1)
|
||||||
|
return clients.Response{
|
||||||
|
Result: map[string]any{"mock": true},
|
||||||
|
RequestID: "mock-wallet-execute",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteWithMockClientRejectsConcurrentTasksBeyondWalletBalance(t *testing.T) {
|
||||||
|
databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL"))
|
||||||
|
if databaseURL == "" {
|
||||||
|
t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet execute integration test")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
db, err := store.Connect(ctx, databaseURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("connect store: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(db.Close)
|
||||||
|
|
||||||
|
suffix := strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||||
|
tenant, err := db.CreateTenant(ctx, store.GatewayTenantInput{
|
||||||
|
TenantKey: "wallet-execute-" + suffix,
|
||||||
|
Name: "Wallet Execute Test " + suffix,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create tenant: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = db.DeleteTenant(context.Background(), tenant.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
gatewayUser, err := db.CreateGatewayUser(ctx, store.GatewayUserInput{
|
||||||
|
UserKey: "wallet-execute-user-" + suffix,
|
||||||
|
Username: "wallet_execute_" + suffix,
|
||||||
|
GatewayTenantID: tenant.ID,
|
||||||
|
TenantKey: tenant.TenantKey,
|
||||||
|
Roles: []string{"user"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create gateway user: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = db.DeleteGatewayUser(context.Background(), gatewayUser.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
platform, err := db.CreatePlatform(ctx, store.CreatePlatformInput{
|
||||||
|
Provider: "mock",
|
||||||
|
PlatformKey: "wallet-execute-mock-" + suffix,
|
||||||
|
Name: "Wallet Execute Mock " + suffix,
|
||||||
|
AuthType: "none",
|
||||||
|
Config: map[string]any{"specType": "mock"},
|
||||||
|
Status: "enabled",
|
||||||
|
Priority: 1,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create mock platform: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = db.DeletePlatform(context.Background(), platform.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
if _, err := db.CreatePlatformModel(ctx, store.CreatePlatformModelInput{
|
||||||
|
PlatformID: platform.ID,
|
||||||
|
ModelName: "mock-wallet-image",
|
||||||
|
ProviderModelName: "mock-wallet-image",
|
||||||
|
ModelType: store.StringList{"image_generate"},
|
||||||
|
DisplayName: "Mock Wallet Image",
|
||||||
|
BillingConfig: map[string]any{
|
||||||
|
"image": map[string]any{"basePrice": 10},
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("create mock platform model: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user := &auth.User{
|
||||||
|
ID: gatewayUser.ID,
|
||||||
|
Source: "gateway",
|
||||||
|
GatewayUserID: gatewayUser.ID,
|
||||||
|
GatewayTenantID: tenant.ID,
|
||||||
|
TenantKey: tenant.TenantKey,
|
||||||
|
Roles: gatewayUser.Roles,
|
||||||
|
}
|
||||||
|
if _, err := db.SetUserWalletBalance(ctx, store.WalletBalanceAdjustmentInput{
|
||||||
|
GatewayUserID: gatewayUser.ID,
|
||||||
|
Currency: "resource",
|
||||||
|
Balance: 10,
|
||||||
|
Reason: "seed wallet execute test",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("seed wallet balance: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks := make([]store.GatewayTask, 0, 2)
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
task, err := db.CreateTask(ctx, store.CreateTaskInput{
|
||||||
|
Kind: "images.generations",
|
||||||
|
Model: "mock-wallet-image",
|
||||||
|
Request: map[string]any{
|
||||||
|
"count": 1,
|
||||||
|
"prompt": "wallet execute test",
|
||||||
|
},
|
||||||
|
}, user)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create task: %v", err)
|
||||||
|
}
|
||||||
|
tasks = append(tasks, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockClient := &walletExecuteMockClient{}
|
||||||
|
service := New(config.Config{}, db, slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||||
|
service.clients["mock"] = mockClient
|
||||||
|
|
||||||
|
type executeResult struct {
|
||||||
|
result Result
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
results := make(chan executeResult, len(tasks))
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, task := range tasks {
|
||||||
|
task := task
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
result, err := service.Execute(ctx, task, user)
|
||||||
|
results <- executeResult{result: result, err: err}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
close(results)
|
||||||
|
|
||||||
|
successCount := 0
|
||||||
|
insufficientCount := 0
|
||||||
|
for item := range results {
|
||||||
|
if item.err == nil {
|
||||||
|
successCount++
|
||||||
|
if item.result.Task.Status != "succeeded" {
|
||||||
|
t.Fatalf("successful execution status = %s, want succeeded", item.result.Task.Status)
|
||||||
|
}
|
||||||
|
if !walletExecuteFloatNear(item.result.Task.FinalChargeAmount, 10) {
|
||||||
|
t.Fatalf("successful execution final charge = %f, want 10", item.result.Task.FinalChargeAmount)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if errors.Is(item.err, store.ErrInsufficientWalletBalance) {
|
||||||
|
insufficientCount++
|
||||||
|
if item.result.Task.Status != "failed" || item.result.Task.ErrorCode != "insufficient_balance" {
|
||||||
|
t.Fatalf("insufficient execution task = %+v", item.result.Task)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t.Fatalf("unexpected execute error: %v", item.err)
|
||||||
|
}
|
||||||
|
if successCount != 1 || insufficientCount != 1 {
|
||||||
|
t.Fatalf("expected one successful mock execution and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount)
|
||||||
|
}
|
||||||
|
if got := mockClient.calls.Load(); got != 1 {
|
||||||
|
t.Fatalf("mock client calls = %d, want 1", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
summary, err := db.GetWalletSummary(ctx, user, "resource")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("get wallet summary: %v", err)
|
||||||
|
}
|
||||||
|
account := summary.PrimaryAccount
|
||||||
|
if !walletExecuteFloatNear(account.Balance, 0) || !walletExecuteFloatNear(account.FrozenBalance, 0) || !walletExecuteFloatNear(account.TotalSpent, 10) {
|
||||||
|
t.Fatalf("wallet after concurrent mock execution balance=%f frozen=%f spent=%f, want 0/0/10", account.Balance, account.FrozenBalance, account.TotalSpent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func walletExecuteFloatNear(a float64, b float64) bool {
|
||||||
|
delta := a - b
|
||||||
|
if delta < 0 {
|
||||||
|
delta = -delta
|
||||||
|
}
|
||||||
|
return delta < 0.000001
|
||||||
|
}
|
||||||
@ -3,13 +3,13 @@ package store
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type TaskListFilter struct {
|
type TaskListFilter struct {
|
||||||
@ -687,14 +687,15 @@ func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error {
|
|||||||
if currency == "" || currency == "mixed" {
|
if currency == "" || currency == "mixed" {
|
||||||
currency = "resource"
|
currency = "resource"
|
||||||
}
|
}
|
||||||
metadata, _ := json.Marshal(map[string]any{
|
metadataMap := map[string]any{
|
||||||
"taskId": task.ID,
|
"taskId": task.ID,
|
||||||
"kind": task.Kind,
|
"kind": task.Kind,
|
||||||
"model": task.Model,
|
"model": task.Model,
|
||||||
"resolvedModel": task.ResolvedModel,
|
"resolvedModel": task.ResolvedModel,
|
||||||
"billings": task.Billings,
|
"billings": task.Billings,
|
||||||
"billingSummary": task.BillingSummary,
|
"billingSummary": task.BillingSummary,
|
||||||
})
|
}
|
||||||
|
metadata, _ := json.Marshal(metadataMap)
|
||||||
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
||||||
if _, err := tx.Exec(ctx, `
|
if _, err := tx.Exec(ctx, `
|
||||||
INSERT INTO gateway_wallet_accounts (
|
INSERT INTO gateway_wallet_accounts (
|
||||||
@ -706,42 +707,85 @@ ON CONFLICT (gateway_user_id, currency) DO NOTHING`,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var exists bool
|
var exists bool
|
||||||
|
var accountID string
|
||||||
|
var balanceBefore float64
|
||||||
|
var frozenBefore float64
|
||||||
|
var gatewayTenantID string
|
||||||
|
if err := tx.QueryRow(ctx, `
|
||||||
|
SELECT id::text, balance::float8, frozen_balance::float8, COALESCE(gateway_tenant_id::text, '')
|
||||||
|
FROM gateway_wallet_accounts
|
||||||
|
WHERE gateway_user_id = $1::uuid
|
||||||
|
AND currency = $2
|
||||||
|
FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &frozenBefore, &gatewayTenantID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if err := tx.QueryRow(ctx, `
|
if err := tx.QueryRow(ctx, `
|
||||||
SELECT EXISTS (
|
SELECT EXISTS (
|
||||||
SELECT 1
|
SELECT 1
|
||||||
FROM gateway_wallet_transactions t
|
FROM gateway_wallet_transactions
|
||||||
JOIN gateway_wallet_accounts a ON a.id = t.account_id
|
WHERE account_id = $1::uuid
|
||||||
WHERE a.gateway_user_id = $1::uuid
|
AND idempotency_key = $2
|
||||||
AND a.currency = $2
|
)`, accountID, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil {
|
||||||
AND t.idempotency_key = $3
|
|
||||||
)`, task.GatewayUserID, currency, billingIdempotencyKey(task.ID)).Scan(&exists); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if exists {
|
if exists {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
var accountID string
|
|
||||||
var balanceBefore float64
|
amount := roundMoney(task.FinalChargeAmount)
|
||||||
var gatewayTenantID string
|
reservationKey, reservedAmount, err := activeWalletReservation(ctx, tx, accountID, task.ID)
|
||||||
if err := tx.QueryRow(ctx, `
|
if err != nil {
|
||||||
SELECT id::text, balance::float8, COALESCE(gateway_tenant_id::text, '')
|
|
||||||
FROM gateway_wallet_accounts
|
|
||||||
WHERE gateway_user_id = $1::uuid
|
|
||||||
AND currency = $2
|
|
||||||
FOR UPDATE`, task.GatewayUserID, currency).Scan(&accountID, &balanceBefore, &gatewayTenantID); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
amount := roundMoney(task.FinalChargeAmount)
|
reservedAmount = roundMoney(reservedAmount)
|
||||||
|
spendableForTask := roundMoney(balanceBefore - frozenBefore + reservedAmount)
|
||||||
|
if spendableForTask+0.000001 < amount {
|
||||||
|
return fmt.Errorf("%w: required %.6f %s, available %.6f", ErrInsufficientWalletBalance, amount, currency, spendableForTask)
|
||||||
|
}
|
||||||
|
|
||||||
balanceAfter := roundMoney(balanceBefore - amount)
|
balanceAfter := roundMoney(balanceBefore - amount)
|
||||||
|
frozenAfter := roundMoney(frozenBefore - reservedAmount)
|
||||||
|
if frozenAfter < 0 {
|
||||||
|
frozenAfter = 0
|
||||||
|
}
|
||||||
if _, err := tx.Exec(ctx, `
|
if _, err := tx.Exec(ctx, `
|
||||||
UPDATE gateway_wallet_accounts
|
UPDATE gateway_wallet_accounts
|
||||||
SET balance = $2,
|
SET balance = $2,
|
||||||
total_spent = total_spent + $3,
|
total_spent = total_spent + $3,
|
||||||
|
frozen_balance = $4,
|
||||||
updated_at = now()
|
updated_at = now()
|
||||||
WHERE id = $1::uuid`, accountID, balanceAfter, amount); err != nil {
|
WHERE id = $1::uuid`, accountID, balanceAfter, amount, frozenAfter); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err := tx.Exec(ctx, `
|
if reservedAmount > 0 {
|
||||||
|
releaseMetadata, _ := json.Marshal(map[string]any{
|
||||||
|
"taskId": task.ID,
|
||||||
|
"reason": "task_billing_settled",
|
||||||
|
"reserved": reservedAmount,
|
||||||
|
"frozenBefore": roundMoney(frozenBefore),
|
||||||
|
"frozenAfter": frozenAfter,
|
||||||
|
})
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
|
INSERT INTO gateway_wallet_transactions (
|
||||||
|
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
||||||
|
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
||||||
|
)
|
||||||
|
VALUES (
|
||||||
|
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'credit', 'release',
|
||||||
|
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
||||||
|
)
|
||||||
|
ON CONFLICT (account_id, idempotency_key) WHERE idempotency_key IS NOT NULL DO NOTHING`,
|
||||||
|
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, reservedAmount, roundMoney(balanceBefore), roundMoney(balanceBefore), billingReservationReleaseIdempotencyKey(reservationKey), task.ID, string(releaseMetadata)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
billingMetadata := mergeObjects(metadataMap, map[string]any{
|
||||||
|
"reservedAmount": reservedAmount,
|
||||||
|
"frozenBefore": roundMoney(frozenBefore),
|
||||||
|
"frozenAfter": frozenAfter,
|
||||||
|
})
|
||||||
|
metadata, _ = json.Marshal(billingMetadata)
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
INSERT INTO gateway_wallet_transactions (
|
INSERT INTO gateway_wallet_transactions (
|
||||||
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
||||||
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
||||||
@ -750,11 +794,10 @@ VALUES (
|
|||||||
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing',
|
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'task_billing',
|
||||||
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
||||||
)`,
|
)`,
|
||||||
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata))
|
accountID, firstNonEmpty(gatewayTenantID, task.GatewayTenantID), task.GatewayUserID, amount, roundMoney(balanceBefore), balanceAfter, billingIdempotencyKey(task.ID), task.ID, string(metadata)); err != nil {
|
||||||
if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == "23505" {
|
return err
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return err
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -92,6 +93,16 @@ type WalletAdjustmentResult struct {
|
|||||||
Transaction GatewayWalletTransaction `json:"transaction"`
|
Transaction GatewayWalletTransaction `json:"transaction"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type WalletBillingReservation struct {
|
||||||
|
TaskID string `json:"taskId"`
|
||||||
|
AccountID string `json:"accountId"`
|
||||||
|
GatewayUserID string `json:"gatewayUserId"`
|
||||||
|
GatewayTenantID string `json:"gatewayTenantId,omitempty"`
|
||||||
|
Currency string `json:"currency"`
|
||||||
|
Amount float64 `json:"amount"`
|
||||||
|
IdempotencyKey string `json:"idempotencyKey"`
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) WalletAvailability(ctx context.Context, user *auth.User, currency string, requiredAmount float64) (WalletAvailability, error) {
|
func (s *Store) WalletAvailability(ctx context.Context, user *auth.User, currency string, requiredAmount float64) (WalletAvailability, error) {
|
||||||
gatewayUserID := localGatewayUserID(user)
|
gatewayUserID := localGatewayUserID(user)
|
||||||
if gatewayUserID == "" {
|
if gatewayUserID == "" {
|
||||||
@ -115,6 +126,223 @@ func (s *Store) WalletAvailability(ctx context.Context, user *auth.User, currenc
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) ReserveTaskBilling(ctx context.Context, task GatewayTask, user *auth.User, billings []any) ([]WalletBillingReservation, error) {
|
||||||
|
gatewayUserID := taskGatewayUserID(task, user)
|
||||||
|
if gatewayUserID == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
taskID := strings.TrimSpace(task.ID)
|
||||||
|
if taskID == "" {
|
||||||
|
return nil, fmt.Errorf("task id is required for wallet reservation")
|
||||||
|
}
|
||||||
|
|
||||||
|
amounts := walletBillingAmounts(billings)
|
||||||
|
if len(amounts) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reservations := make([]WalletBillingReservation, 0, len(amounts))
|
||||||
|
err := pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
||||||
|
for currency, rawAmount := range amounts {
|
||||||
|
amount := roundMoney(rawAmount)
|
||||||
|
if amount <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := s.ensureWalletAccount(ctx, tx, gatewayUserID, currency)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
locked, err := lockWalletAccount(ctx, tx, account.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
activeKey, activeAmount, err := activeWalletReservation(ctx, tx, locked.ID, taskID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if activeAmount > 0 {
|
||||||
|
reservation := WalletBillingReservation{
|
||||||
|
TaskID: taskID,
|
||||||
|
AccountID: locked.ID,
|
||||||
|
GatewayUserID: gatewayUserID,
|
||||||
|
GatewayTenantID: firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID),
|
||||||
|
Currency: locked.Currency,
|
||||||
|
Amount: activeAmount,
|
||||||
|
IdempotencyKey: activeKey,
|
||||||
|
}
|
||||||
|
reservations = append(reservations, reservation)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sequence, err := nextWalletReservationSequence(ctx, tx, locked.ID, taskID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
key := billingReservationIdempotencyKey(taskID, locked.Currency, sequence)
|
||||||
|
reservation := WalletBillingReservation{
|
||||||
|
TaskID: taskID,
|
||||||
|
AccountID: locked.ID,
|
||||||
|
GatewayUserID: gatewayUserID,
|
||||||
|
GatewayTenantID: firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID),
|
||||||
|
Currency: locked.Currency,
|
||||||
|
Amount: amount,
|
||||||
|
IdempotencyKey: key,
|
||||||
|
}
|
||||||
|
available := roundMoney(locked.Balance - locked.FrozenBalance)
|
||||||
|
if available+0.000001 < amount {
|
||||||
|
return fmt.Errorf("%w: required %.6f %s, available %.6f", ErrInsufficientWalletBalance, amount, locked.Currency, available)
|
||||||
|
}
|
||||||
|
|
||||||
|
frozenAfter := roundMoney(locked.FrozenBalance + amount)
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
|
UPDATE gateway_wallet_accounts
|
||||||
|
SET frozen_balance = $2,
|
||||||
|
updated_at = now()
|
||||||
|
WHERE id = $1::uuid`, locked.ID, frozenAfter); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
metadata, _ := json.Marshal(map[string]any{
|
||||||
|
"taskId": taskID,
|
||||||
|
"kind": task.Kind,
|
||||||
|
"model": task.Model,
|
||||||
|
"reserved": amount,
|
||||||
|
"balance": roundMoney(locked.Balance),
|
||||||
|
"frozenBefore": roundMoney(locked.FrozenBalance),
|
||||||
|
"frozenAfter": frozenAfter,
|
||||||
|
})
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
|
INSERT INTO gateway_wallet_transactions (
|
||||||
|
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
||||||
|
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
||||||
|
)
|
||||||
|
VALUES (
|
||||||
|
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'debit', 'reserve',
|
||||||
|
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
||||||
|
)`,
|
||||||
|
locked.ID,
|
||||||
|
firstNonEmpty(locked.GatewayTenantID, task.GatewayTenantID),
|
||||||
|
gatewayUserID,
|
||||||
|
amount,
|
||||||
|
roundMoney(locked.Balance),
|
||||||
|
roundMoney(locked.Balance),
|
||||||
|
key,
|
||||||
|
taskID,
|
||||||
|
string(metadata),
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
reservations = append(reservations, reservation)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return reservations, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Store) ReleaseTaskBillingReservations(ctx context.Context, reservations []WalletBillingReservation, reason string) error {
|
||||||
|
if len(reservations) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
reason = strings.TrimSpace(reason)
|
||||||
|
if reason == "" {
|
||||||
|
reason = "task_not_settled"
|
||||||
|
}
|
||||||
|
return pgx.BeginFunc(ctx, s.pool, func(tx pgx.Tx) error {
|
||||||
|
for _, reservation := range reservations {
|
||||||
|
if reservation.Amount <= 0 || strings.TrimSpace(reservation.AccountID) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
reserveKey := strings.TrimSpace(reservation.IdempotencyKey)
|
||||||
|
if reserveKey == "" {
|
||||||
|
reserveKey = billingReservationIdempotencyKey(reservation.TaskID, reservation.Currency, 1)
|
||||||
|
}
|
||||||
|
releaseKey := billingReservationReleaseIdempotencyKey(reserveKey)
|
||||||
|
locked, err := lockWalletAccount(ctx, tx, reservation.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
if err == pgx.ErrNoRows {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var alreadyReleased bool
|
||||||
|
if err := tx.QueryRow(ctx, `
|
||||||
|
SELECT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM gateway_wallet_transactions
|
||||||
|
WHERE account_id = $1::uuid
|
||||||
|
AND idempotency_key = $2
|
||||||
|
)`, reservation.AccountID, releaseKey).Scan(&alreadyReleased); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if alreadyReleased {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var storedReservedAmount float64
|
||||||
|
if err := tx.QueryRow(ctx, `
|
||||||
|
SELECT COALESCE((
|
||||||
|
SELECT amount::float8
|
||||||
|
FROM gateway_wallet_transactions
|
||||||
|
WHERE account_id = $1::uuid
|
||||||
|
AND idempotency_key = $2
|
||||||
|
AND transaction_type = 'reserve'
|
||||||
|
LIMIT 1
|
||||||
|
), 0)::float8`, reservation.AccountID, reserveKey).Scan(&storedReservedAmount); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if storedReservedAmount <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
amount := roundMoney(storedReservedAmount)
|
||||||
|
frozenAfter := roundMoney(locked.FrozenBalance - amount)
|
||||||
|
if frozenAfter < 0 {
|
||||||
|
frozenAfter = 0
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
|
UPDATE gateway_wallet_accounts
|
||||||
|
SET frozen_balance = $2,
|
||||||
|
updated_at = now()
|
||||||
|
WHERE id = $1::uuid`, locked.ID, frozenAfter); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
metadata, _ := json.Marshal(map[string]any{
|
||||||
|
"taskId": reservation.TaskID,
|
||||||
|
"reason": reason,
|
||||||
|
"reserved": amount,
|
||||||
|
"frozenBefore": roundMoney(locked.FrozenBalance),
|
||||||
|
"frozenAfter": frozenAfter,
|
||||||
|
})
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
|
INSERT INTO gateway_wallet_transactions (
|
||||||
|
account_id, gateway_tenant_id, gateway_user_id, direction, transaction_type,
|
||||||
|
amount, balance_before, balance_after, idempotency_key, reference_type, reference_id, metadata
|
||||||
|
)
|
||||||
|
VALUES (
|
||||||
|
$1::uuid, NULLIF($2, '')::uuid, $3::uuid, 'credit', 'release',
|
||||||
|
$4, $5, $6, $7, 'gateway_task', $8, $9::jsonb
|
||||||
|
)
|
||||||
|
ON CONFLICT (account_id, idempotency_key) WHERE idempotency_key IS NOT NULL DO NOTHING`,
|
||||||
|
locked.ID,
|
||||||
|
locked.GatewayTenantID,
|
||||||
|
locked.GatewayUserID,
|
||||||
|
amount,
|
||||||
|
roundMoney(locked.Balance),
|
||||||
|
roundMoney(locked.Balance),
|
||||||
|
releaseKey,
|
||||||
|
reservation.TaskID,
|
||||||
|
string(metadata),
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) GetWalletSummary(ctx context.Context, user *auth.User, currency string) (WalletSummary, error) {
|
func (s *Store) GetWalletSummary(ctx context.Context, user *auth.User, currency string) (WalletSummary, error) {
|
||||||
gatewayUserID := localGatewayUserID(user)
|
gatewayUserID := localGatewayUserID(user)
|
||||||
if gatewayUserID == "" {
|
if gatewayUserID == "" {
|
||||||
@ -465,6 +693,124 @@ WHERE gateway_user_id = $1::uuid
|
|||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func lockWalletAccount(ctx context.Context, tx pgx.Tx, accountID string) (GatewayWalletAccount, error) {
|
||||||
|
return scanWalletAccount(tx.QueryRow(ctx, `
|
||||||
|
SELECT id::text, COALESCE(gateway_tenant_id::text, ''), gateway_user_id::text,
|
||||||
|
COALESCE(tenant_id, ''), COALESCE(tenant_key, ''), COALESCE(user_id, ''),
|
||||||
|
currency, balance::float8, frozen_balance::float8, total_recharged::float8,
|
||||||
|
total_spent::float8, status, metadata, created_at, updated_at
|
||||||
|
FROM gateway_wallet_accounts
|
||||||
|
WHERE id = $1::uuid
|
||||||
|
FOR UPDATE`, accountID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func activeWalletReservation(ctx context.Context, tx pgx.Tx, accountID string, taskID string) (string, float64, error) {
|
||||||
|
var key string
|
||||||
|
var amount float64
|
||||||
|
err := tx.QueryRow(ctx, `
|
||||||
|
SELECT COALESCE(t.idempotency_key, ''), t.amount::float8
|
||||||
|
FROM gateway_wallet_transactions t
|
||||||
|
WHERE t.account_id = $1::uuid
|
||||||
|
AND t.reference_type = 'gateway_task'
|
||||||
|
AND t.reference_id = $2
|
||||||
|
AND t.transaction_type = 'reserve'
|
||||||
|
AND COALESCE(t.idempotency_key, '') <> ''
|
||||||
|
AND NOT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM gateway_wallet_transactions r
|
||||||
|
WHERE r.account_id = t.account_id
|
||||||
|
AND r.transaction_type = 'release'
|
||||||
|
AND r.idempotency_key = t.idempotency_key || ':release'
|
||||||
|
)
|
||||||
|
ORDER BY t.created_at DESC
|
||||||
|
LIMIT 1`, accountID, taskID).Scan(&key, &amount)
|
||||||
|
if err == pgx.ErrNoRows {
|
||||||
|
return "", 0, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
return key, roundMoney(amount), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func nextWalletReservationSequence(ctx context.Context, tx pgx.Tx, accountID string, taskID string) (int, error) {
|
||||||
|
var count int
|
||||||
|
if err := tx.QueryRow(ctx, `
|
||||||
|
SELECT COUNT(*)::int
|
||||||
|
FROM gateway_wallet_transactions
|
||||||
|
WHERE account_id = $1::uuid
|
||||||
|
AND reference_type = 'gateway_task'
|
||||||
|
AND reference_id = $2
|
||||||
|
AND transaction_type = 'reserve'`, accountID, taskID).Scan(&count); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return count + 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func walletBillingAmounts(billings []any) map[string]float64 {
|
||||||
|
amounts := map[string]float64{}
|
||||||
|
for _, raw := range billings {
|
||||||
|
line, _ := raw.(map[string]any)
|
||||||
|
if line == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
amount := roundMoney(walletFloat(line["amount"]))
|
||||||
|
if amount <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currency := normalizeWalletCurrency(walletString(line["currency"]))
|
||||||
|
amounts[currency] = roundMoney(amounts[currency] + amount)
|
||||||
|
}
|
||||||
|
return amounts
|
||||||
|
}
|
||||||
|
|
||||||
|
func taskGatewayUserID(task GatewayTask, user *auth.User) string {
|
||||||
|
return firstNonEmpty(strings.TrimSpace(task.GatewayUserID), localGatewayUserID(user))
|
||||||
|
}
|
||||||
|
|
||||||
|
func billingReservationIdempotencyKey(taskID string, currency string, sequence int) string {
|
||||||
|
if sequence <= 0 {
|
||||||
|
sequence = 1
|
||||||
|
}
|
||||||
|
return "task:" + strings.TrimSpace(taskID) + ":wallet-reservation:" + normalizeWalletCurrency(currency) + ":" + strconv.Itoa(sequence)
|
||||||
|
}
|
||||||
|
|
||||||
|
func billingReservationReleaseIdempotencyKey(reservationKey string) string {
|
||||||
|
return strings.TrimSpace(reservationKey) + ":release"
|
||||||
|
}
|
||||||
|
|
||||||
|
func walletString(value any) string {
|
||||||
|
if text, ok := value.(string); ok {
|
||||||
|
return strings.TrimSpace(text)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func walletFloat(value any) float64 {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case float64:
|
||||||
|
return typed
|
||||||
|
case float32:
|
||||||
|
return float64(typed)
|
||||||
|
case int:
|
||||||
|
return float64(typed)
|
||||||
|
case int64:
|
||||||
|
return float64(typed)
|
||||||
|
case json.Number:
|
||||||
|
next, _ := typed.Float64()
|
||||||
|
return next
|
||||||
|
case string:
|
||||||
|
next := strings.TrimSpace(typed)
|
||||||
|
if next == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
parsed, _ := strconv.ParseFloat(next, 64)
|
||||||
|
return parsed
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeWalletCurrency(currency string) string {
|
func normalizeWalletCurrency(currency string) string {
|
||||||
currency = strings.TrimSpace(currency)
|
currency = strings.TrimSpace(currency)
|
||||||
if currency == "" {
|
if currency == "" {
|
||||||
|
|||||||
171
apps/api/internal/store/wallet_reservation_test.go
Normal file
171
apps/api/internal/store/wallet_reservation_test.go
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReserveTaskBillingSerializesConcurrentWalletReservations(t *testing.T) {
|
||||||
|
databaseURL := strings.TrimSpace(os.Getenv("AI_GATEWAY_TEST_DATABASE_URL"))
|
||||||
|
if databaseURL == "" {
|
||||||
|
t.Skip("set AI_GATEWAY_TEST_DATABASE_URL to run the wallet reservation integration test")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
db, err := Connect(ctx, databaseURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("connect store: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
tenantID, userID := seedWalletReservationUser(t, ctx, db)
|
||||||
|
if _, err := db.SetUserWalletBalance(ctx, WalletBalanceAdjustmentInput{
|
||||||
|
GatewayUserID: userID,
|
||||||
|
Currency: "resource",
|
||||||
|
Balance: 10,
|
||||||
|
Reason: "seed wallet reservation test",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("seed wallet balance: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
firstTaskID := newWalletReservationTestUUID(t, ctx, db)
|
||||||
|
secondTaskID := newWalletReservationTestUUID(t, ctx, db)
|
||||||
|
billings := []any{map[string]any{"currency": "resource", "amount": float64(10)}}
|
||||||
|
user := &auth.User{GatewayUserID: userID, GatewayTenantID: tenantID}
|
||||||
|
tasks := []GatewayTask{
|
||||||
|
{ID: firstTaskID, GatewayUserID: userID, GatewayTenantID: tenantID, Kind: "images.generations", Model: "mock-image"},
|
||||||
|
{ID: secondTaskID, GatewayUserID: userID, GatewayTenantID: tenantID, Kind: "videos.generations", Model: "mock-video"},
|
||||||
|
}
|
||||||
|
|
||||||
|
type reserveResult struct {
|
||||||
|
reservations []WalletBillingReservation
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
results := make(chan reserveResult, len(tasks))
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, task := range tasks {
|
||||||
|
task := task
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
reservations, err := db.ReserveTaskBilling(ctx, task, user, billings)
|
||||||
|
results <- reserveResult{reservations: reservations, err: err}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
close(results)
|
||||||
|
|
||||||
|
var successReservations []WalletBillingReservation
|
||||||
|
successCount := 0
|
||||||
|
insufficientCount := 0
|
||||||
|
for result := range results {
|
||||||
|
if result.err == nil {
|
||||||
|
successCount++
|
||||||
|
successReservations = result.reservations
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if errors.Is(result.err, ErrInsufficientWalletBalance) {
|
||||||
|
insufficientCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t.Fatalf("unexpected reservation error: %v", result.err)
|
||||||
|
}
|
||||||
|
if successCount != 1 || insufficientCount != 1 {
|
||||||
|
t.Fatalf("expected one successful reservation and one insufficient balance rejection, got success=%d insufficient=%d", successCount, insufficientCount)
|
||||||
|
}
|
||||||
|
if len(successReservations) != 1 || !walletFloatNear(successReservations[0].Amount, 10) {
|
||||||
|
t.Fatalf("unexpected successful reservations: %+v", successReservations)
|
||||||
|
}
|
||||||
|
|
||||||
|
balance, frozen, spent := readWalletReservationAccount(t, ctx, db, userID)
|
||||||
|
if !walletFloatNear(balance, 10) || !walletFloatNear(frozen, 10) || !walletFloatNear(spent, 0) {
|
||||||
|
t.Fatalf("reservation should freeze balance without spending it, balance=%f frozen=%f spent=%f", balance, frozen, spent)
|
||||||
|
}
|
||||||
|
|
||||||
|
settleTask := GatewayTask{
|
||||||
|
ID: successReservations[0].TaskID,
|
||||||
|
GatewayUserID: userID,
|
||||||
|
GatewayTenantID: tenantID,
|
||||||
|
Kind: "images.generations",
|
||||||
|
Model: "mock-image",
|
||||||
|
ResolvedModel: "mock-image",
|
||||||
|
Billings: billings,
|
||||||
|
BillingSummary: map[string]any{"currency": "resource", "totalAmount": float64(10)},
|
||||||
|
FinalChargeAmount: 10,
|
||||||
|
}
|
||||||
|
if err := db.SettleTaskBilling(ctx, settleTask); err != nil {
|
||||||
|
t.Fatalf("settle reserved task billing: %v", err)
|
||||||
|
}
|
||||||
|
if err := db.SettleTaskBilling(ctx, settleTask); err != nil {
|
||||||
|
t.Fatalf("settle reserved task billing should be idempotent: %v", err)
|
||||||
|
}
|
||||||
|
balance, frozen, spent = readWalletReservationAccount(t, ctx, db, userID)
|
||||||
|
if !walletFloatNear(balance, 0) || !walletFloatNear(frozen, 0) || !walletFloatNear(spent, 10) {
|
||||||
|
t.Fatalf("settlement should release reservation and debit once, balance=%f frozen=%f spent=%f", balance, frozen, spent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func seedWalletReservationUser(t *testing.T, ctx context.Context, db *Store) (string, string) {
|
||||||
|
t.Helper()
|
||||||
|
suffix := strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||||
|
var tenantID string
|
||||||
|
if err := db.pool.QueryRow(ctx, `
|
||||||
|
INSERT INTO gateway_tenants (tenant_key, name)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
RETURNING id::text`, "wallet-reservation-"+suffix, "Wallet Reservation Test "+suffix).Scan(&tenantID); err != nil {
|
||||||
|
t.Fatalf("insert test tenant: %v", err)
|
||||||
|
}
|
||||||
|
var userID string
|
||||||
|
if err := db.pool.QueryRow(ctx, `
|
||||||
|
INSERT INTO gateway_users (user_key, username, gateway_tenant_id, tenant_key, roles)
|
||||||
|
VALUES ($1, $2, $3::uuid, $4, '["basic"]'::jsonb)
|
||||||
|
RETURNING id::text`, "wallet-reservation-user-"+suffix, "wallet_reservation_"+suffix, tenantID, "wallet-reservation-"+suffix).Scan(&userID); err != nil {
|
||||||
|
t.Fatalf("insert test user: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cleanupCtx := context.Background()
|
||||||
|
_, _ = db.pool.Exec(cleanupCtx, `DELETE FROM gateway_users WHERE id = $1::uuid`, userID)
|
||||||
|
_, _ = db.pool.Exec(cleanupCtx, `DELETE FROM gateway_tenants WHERE id = $1::uuid`, tenantID)
|
||||||
|
})
|
||||||
|
return tenantID, userID
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWalletReservationTestUUID(t *testing.T, ctx context.Context, db *Store) string {
|
||||||
|
t.Helper()
|
||||||
|
var id string
|
||||||
|
if err := db.pool.QueryRow(ctx, `SELECT gen_random_uuid()::text`).Scan(&id); err != nil {
|
||||||
|
t.Fatalf("generate uuid: %v", err)
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func readWalletReservationAccount(t *testing.T, ctx context.Context, db *Store, userID string) (float64, float64, float64) {
|
||||||
|
t.Helper()
|
||||||
|
var balance float64
|
||||||
|
var frozen float64
|
||||||
|
var spent float64
|
||||||
|
if err := db.pool.QueryRow(ctx, `
|
||||||
|
SELECT balance::float8, frozen_balance::float8, total_spent::float8
|
||||||
|
FROM gateway_wallet_accounts
|
||||||
|
WHERE gateway_user_id = $1::uuid
|
||||||
|
AND currency = 'resource'`, userID).Scan(&balance, &frozen, &spent); err != nil {
|
||||||
|
t.Fatalf("read wallet account: %v", err)
|
||||||
|
}
|
||||||
|
return balance, frozen, spent
|
||||||
|
}
|
||||||
|
|
||||||
|
func walletFloatNear(a float64, b float64) bool {
|
||||||
|
delta := a - b
|
||||||
|
if delta < 0 {
|
||||||
|
delta = -delta
|
||||||
|
}
|
||||||
|
return delta < 0.000001
|
||||||
|
}
|
||||||
@ -684,11 +684,13 @@ export interface VideoGenerationContent {
|
|||||||
};
|
};
|
||||||
video_url?: {
|
video_url?: {
|
||||||
url: string;
|
url: string;
|
||||||
|
mime_type?: string;
|
||||||
refer_type?: 'feature' | 'base';
|
refer_type?: 'feature' | 'base';
|
||||||
keep_original_sound?: 'yes' | 'no';
|
keep_original_sound?: 'yes' | 'no';
|
||||||
};
|
};
|
||||||
audio_url?: {
|
audio_url?: {
|
||||||
url: string;
|
url: string;
|
||||||
|
mime_type?: string;
|
||||||
};
|
};
|
||||||
role?: VideoGenerationContentRole;
|
role?: VideoGenerationContentRole;
|
||||||
shot_index?: number;
|
shot_index?: number;
|
||||||
|
|||||||
@ -32,8 +32,8 @@ export interface PlaygroundUpload {
|
|||||||
export type OpenAIChatContentPart =
|
export type OpenAIChatContentPart =
|
||||||
| { type: 'text'; text: string }
|
| { type: 'text'; text: string }
|
||||||
| { type: 'image_url'; image_url: { url: string } }
|
| { type: 'image_url'; image_url: { url: string } }
|
||||||
| { type: 'video_url'; video_url: { url: string } }
|
| { type: 'video_url'; video_url: { mime_type?: string; url: string } }
|
||||||
| { type: 'audio_url'; audio_url: { url: string } }
|
| { type: 'audio_url'; audio_url: { mime_type?: string; url: string } }
|
||||||
| { type: 'file_url'; file_url: { filename: string; url: string } };
|
| { type: 'file_url'; file_url: { filename: string; url: string } };
|
||||||
|
|
||||||
export const mediaUploadAccept = 'image/*,video/*,audio/*';
|
export const mediaUploadAccept = 'image/*,video/*,audio/*';
|
||||||
@ -518,11 +518,17 @@ export function openAIContentFromPromptAndUploads(prompt: string, uploads: Playg
|
|||||||
function openAIContentPartFromUpload(item: PlaygroundUpload): OpenAIChatContentPart | undefined {
|
function openAIContentPartFromUpload(item: PlaygroundUpload): OpenAIChatContentPart | undefined {
|
||||||
if (!item.url) return undefined;
|
if (!item.url) return undefined;
|
||||||
if (item.kind === 'image') return { type: 'image_url', image_url: { url: item.url } };
|
if (item.kind === 'image') return { type: 'image_url', image_url: { url: item.url } };
|
||||||
if (item.kind === 'video') return { type: 'video_url', video_url: { url: item.url } };
|
if (item.kind === 'video') return { type: 'video_url', video_url: mediaURLPayload(item) };
|
||||||
if (item.kind === 'audio') return { type: 'audio_url', audio_url: { url: item.url } };
|
if (item.kind === 'audio') return { type: 'audio_url', audio_url: mediaURLPayload(item) };
|
||||||
return { type: 'file_url', file_url: { filename: item.name, url: item.url } };
|
return { type: 'file_url', file_url: { filename: item.name, url: item.url } };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function mediaURLPayload(item: PlaygroundUpload) {
|
||||||
|
const payload: { mime_type?: string; url: string } = { url: item.url };
|
||||||
|
if (item.contentType) payload.mime_type = item.contentType;
|
||||||
|
return payload;
|
||||||
|
}
|
||||||
|
|
||||||
export function mediaUploadRequestPayload(uploads: PlaygroundUpload[], mode: Exclude<PlaygroundMode, 'chat'>) {
|
export function mediaUploadRequestPayload(uploads: PlaygroundUpload[], mode: Exclude<PlaygroundMode, 'chat'>) {
|
||||||
const images = uploads.filter((item) => item.kind === 'image').map((item) => item.url);
|
const images = uploads.filter((item) => item.kind === 'image').map((item) => item.url);
|
||||||
const payload: Record<string, string | string[]> = {};
|
const payload: Record<string, string | string[]> = {};
|
||||||
@ -570,10 +576,10 @@ function videoGenerationContentFromUpload(item: PlaygroundUpload): VideoGenerati
|
|||||||
return { type: 'image_url', role: 'reference_image', image_url: { url: item.url } };
|
return { type: 'image_url', role: 'reference_image', image_url: { url: item.url } };
|
||||||
}
|
}
|
||||||
if (item.kind === 'video') {
|
if (item.kind === 'video') {
|
||||||
return { type: 'video_url', role: 'reference_video', video_url: { url: item.url, refer_type: 'feature' } };
|
return { type: 'video_url', role: 'reference_video', video_url: { ...mediaURLPayload(item), refer_type: 'feature' } };
|
||||||
}
|
}
|
||||||
if (item.kind === 'audio') {
|
if (item.kind === 'audio') {
|
||||||
return { type: 'audio_url', role: 'reference_audio', audio_url: { url: item.url } };
|
return { type: 'audio_url', role: 'reference_audio', audio_url: mediaURLPayload(item) };
|
||||||
}
|
}
|
||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user