From 37d0f919e592cf6d4293082be0e133fe7c075686 Mon Sep 17 00:00:00 2001 From: wangbo Date: Fri, 15 May 2026 01:53:52 +0800 Subject: [PATCH] feat: add gateway billing estimate and rate limit details --- apps/api/internal/httpapi/handlers.go | 139 +++++++++- .../httpapi/rate_limit_error_detail_test.go | 72 +++++ apps/api/internal/httpapi/response.go | 7 + apps/api/internal/runner/limits.go | 31 ++- apps/api/internal/runner/pricing.go | 259 ++++++++++++++++-- apps/api/internal/runner/pricing_test.go | 132 +++++++++ apps/api/internal/store/rate_limits.go | 63 ++++- apps/api/internal/store/runtime_types.go | 24 +- apps/api/internal/store/user_group_policy.go | 5 +- apps/web/src/api.ts | 5 +- apps/web/src/pages/PlaygroundPage.tsx | 133 ++++++++- .../pages/admin/IdentityManagementPanels.tsx | 77 +++++- apps/web/src/styles/playground.css | 9 + packages/contracts/src/index.ts | 26 ++ 14 files changed, 916 insertions(+), 66 deletions(-) create mode 100644 apps/api/internal/httpapi/rate_limit_error_detail_test.go create mode 100644 apps/api/internal/runner/pricing_test.go diff --git a/apps/api/internal/httpapi/handlers.go b/apps/api/internal/httpapi/handlers.go index 16ac3fd..a8ea8de 100644 --- a/apps/api/internal/httpapi/handlers.go +++ b/apps/api/internal/httpapi/handlers.go @@ -597,7 +597,7 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler { status := statusFromRunError(runErr) errorPayload := map[string]any{ "code": runErrorCode(runErr), - "message": runErr.Error(), + "message": runErrorMessage(runErr), "status": status, } if result.Task.ID != "" { @@ -606,6 +606,9 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler { if result.Task.RequestID != "" { errorPayload["requestId"] = result.Task.RequestID } + for key, value := range runErrorDetails(runErr) { + errorPayload[key] = value + } sendSSE(w, "error", map[string]any{"error": errorPayload}) if flusher != nil { flusher.Flush() @@ -626,7 +629,7 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler { if !requestStillConnected(r) { return } - writeError(w, statusFromRunError(runErr), runErr.Error(), runErrorCode(runErr)) + writeErrorWithDetails(w, statusFromRunError(runErr), runErrorMessage(runErr), runErrorDetails(runErr), runErrorCode(runErr)) return } if !requestStillConnected(r) { @@ -742,6 +745,138 @@ func runErrorCode(err error) string { return clients.ErrorCode(err) } +func runErrorMessage(err error) string { + if err == nil { + return "" + } + if summary := rateLimitErrorSummary(err); summary != "" { + return err.Error() + ";" + summary + } + return err.Error() +} + +func runErrorDetails(err error) map[string]any { + if detail := rateLimitErrorDetail(err); len(detail) > 0 { + return map[string]any{"rateLimit": detail} + } + return nil +} + +func rateLimitErrorSummary(err error) string { + var limitErr *store.RateLimitExceededError + if !errors.As(err, &limitErr) { + return "" + } + scopeLabel := "限流对象" + switch limitErr.ScopeType { + case "user_group": + scopeLabel = "用户组" + case "platform_model": + scopeLabel = "平台模型" + } + scopeName := strings.TrimSpace(limitErr.ScopeName) + if scopeName == "" { + scopeName = strings.TrimSpace(limitErr.ScopeKey) + } + if groupKey := stringValue(limitErr.ScopeMetadata["groupKey"]); limitErr.ScopeType == "user_group" && groupKey != "" && groupKey != scopeName { + scopeName = fmt.Sprintf("%s(%s)", scopeName, groupKey) + } + projected := limitErr.Projected + if projected <= 0 { + projected = limitErr.Current + limitErr.Amount + } + parts := []string{ + fmt.Sprintf("限流摘要:%s %s 的 %s 超限", scopeLabel, scopeName, limitErr.Metric), + fmt.Sprintf("当前 %s,本次 %s,预计 %s,限制 %s", formatRateLimitValue(limitErr.Current), formatRateLimitValue(limitErr.Amount), formatRateLimitValue(projected), formatRateLimitValue(limitErr.Limit)), + } + if limitErr.WindowSeconds > 0 { + parts = append(parts, fmt.Sprintf("窗口 %d 秒", limitErr.WindowSeconds)) + } + if limitErr.RetryAfter > 0 { + parts = append(parts, fmt.Sprintf("约%s后可重试", formatRateLimitDuration(limitErr.RetryAfter))) + } else if !limitErr.Retryable { + parts = append(parts, "该请求超过单次限额,不能排队重试") + } + return strings.Join(parts, ",") +} + +func rateLimitErrorDetail(err error) map[string]any { + var limitErr *store.RateLimitExceededError + if !errors.As(err, &limitErr) { + return nil + } + detail := map[string]any{ + "scopeType": limitErr.ScopeType, + "scopeKey": limitErr.ScopeKey, + "scopeName": limitErr.ScopeName, + "metric": limitErr.Metric, + "limit": limitErr.Limit, + "amount": limitErr.Amount, + "current": limitErr.Current, + "used": limitErr.Used, + "reserved": limitErr.Reserved, + "projected": limitErr.Projected, + "windowSeconds": limitErr.WindowSeconds, + "retryable": limitErr.Retryable, + "exceeded": map[string]any{ + "metric": limitErr.Metric, + "current": limitErr.Current, + "amount": limitErr.Amount, + "projected": limitErr.Projected, + "limit": limitErr.Limit, + }, + } + if limitErr.RetryAfter > 0 { + detail["retryAfterMs"] = limitErr.RetryAfter.Milliseconds() + } + if !limitErr.ResetAt.IsZero() { + detail["resetAt"] = limitErr.ResetAt.UTC().Format(time.RFC3339Nano) + } + if len(limitErr.Policy) > 0 { + detail["rateLimitPolicy"] = limitErr.Policy + if matchedRule := matchedRateLimitRule(limitErr.Policy, limitErr.Metric); len(matchedRule) > 0 { + detail["matchedRule"] = matchedRule + } + } + if len(limitErr.ScopeMetadata) > 0 { + detail["scopeMetadata"] = limitErr.ScopeMetadata + } + if limitErr.ScopeType == "user_group" { + userGroup := map[string]any{ + "id": limitErr.ScopeKey, + "name": limitErr.ScopeName, + } + if groupKey := stringValue(limitErr.ScopeMetadata["groupKey"]); groupKey != "" { + userGroup["groupKey"] = groupKey + } + detail["userGroup"] = userGroup + } + return detail +} + +func formatRateLimitValue(value float64) string { + return strconv.FormatFloat(value, 'f', -1, 64) +} + +func formatRateLimitDuration(duration time.Duration) string { + if duration < time.Second { + return strconv.FormatInt(duration.Milliseconds(), 10) + "毫秒" + } + seconds := duration.Seconds() + return strconv.FormatFloat(seconds, 'f', -1, 64) + "秒" +} + +func matchedRateLimitRule(policy map[string]any, metric string) map[string]any { + rules, _ := policy["rules"].([]any) + for _, rawRule := range rules { + rule, _ := rawRule.(map[string]any) + if stringValue(rule["metric"]) == metric { + return rule + } + } + return nil +} + func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) { user, ok := auth.UserFromContext(r.Context()) if !ok { diff --git a/apps/api/internal/httpapi/rate_limit_error_detail_test.go b/apps/api/internal/httpapi/rate_limit_error_detail_test.go new file mode 100644 index 0000000..fe1222b --- /dev/null +++ b/apps/api/internal/httpapi/rate_limit_error_detail_test.go @@ -0,0 +1,72 @@ +package httpapi + +import ( + "strings" + "testing" + "time" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func TestRateLimitErrorDetailIncludesUserGroupAndExceededMetric(t *testing.T) { + resetAt := time.Date(2026, 5, 15, 10, 30, 0, 0, time.UTC) + detail := rateLimitErrorDetail(&store.RateLimitExceededError{ + ScopeType: "user_group", + ScopeKey: "group-1", + ScopeName: "VIP 用户组", + ScopeMetadata: map[string]any{"groupKey": "vip"}, + Metric: "rpm", + Limit: 2, + Amount: 1, + Current: 2, + Used: 1, + Reserved: 1, + Projected: 3, + WindowSeconds: 60, + ResetAt: resetAt, + RetryAfter: 5 * time.Second, + Retryable: true, + Policy: map[string]any{ + "rules": []any{ + map[string]any{"metric": "rpm", "limit": float64(2), "windowSeconds": float64(60)}, + }, + }, + }) + if detail["metric"] != "rpm" || detail["projected"] != float64(3) || detail["limit"] != float64(2) { + t.Fatalf("unexpected exceeded detail: %+v", detail) + } + userGroup, _ := detail["userGroup"].(map[string]any) + if userGroup["id"] != "group-1" || userGroup["groupKey"] != "vip" || userGroup["name"] != "VIP 用户组" { + t.Fatalf("missing user group detail: %+v", detail) + } + matchedRule, _ := detail["matchedRule"].(map[string]any) + if matchedRule["metric"] != "rpm" { + t.Fatalf("missing matched rule: %+v", detail) + } + if detail["retryAfterMs"] != int64(5000) || detail["resetAt"] != resetAt.Format(time.RFC3339Nano) { + t.Fatalf("missing retry/reset detail: %+v", detail) + } +} + +func TestRunErrorMessageIncludesRateLimitSummary(t *testing.T) { + message := runErrorMessage(&store.RateLimitExceededError{ + ScopeType: "user_group", + ScopeKey: "group-1", + ScopeName: "VIP 用户组", + ScopeMetadata: map[string]any{"groupKey": "vip"}, + Metric: "rpm", + Limit: 2, + Amount: 1, + Current: 2, + Projected: 3, + WindowSeconds: 60, + RetryAfter: 5 * time.Second, + Retryable: true, + Message: "rate limit exceeded: rpm window has no remaining capacity", + }) + for _, expected := range []string{"限流摘要", "用户组 VIP 用户组(vip)", "rpm 超限", "当前 2", "本次 1", "预计 3", "限制 2", "窗口 60 秒", "约5秒后可重试"} { + if !strings.Contains(message, expected) { + t.Fatalf("message %q should contain %q", message, expected) + } + } +} diff --git a/apps/api/internal/httpapi/response.go b/apps/api/internal/httpapi/response.go index 7939d48..5f3a004 100644 --- a/apps/api/internal/httpapi/response.go +++ b/apps/api/internal/httpapi/response.go @@ -14,6 +14,10 @@ func writeJSON(w http.ResponseWriter, status int, value any) { } func writeError(w http.ResponseWriter, status int, message string, codes ...string) { + writeErrorWithDetails(w, status, message, nil, codes...) +} + +func writeErrorWithDetails(w http.ResponseWriter, status int, message string, details map[string]any, codes ...string) { errorPayload := map[string]any{ "message": message, "status": status, @@ -23,6 +27,9 @@ func writeError(w http.ResponseWriter, status int, message string, codes ...stri errorPayload["code"] = code } } + for key, value := range details { + errorPayload[key] = value + } writeJSON(w, status, map[string]any{"error": errorPayload}) } diff --git a/apps/api/internal/runner/limits.go b/apps/api/internal/runner/limits.go index 24a95af..9a47484 100644 --- a/apps/api/internal/runner/limits.go +++ b/apps/api/internal/runner/limits.go @@ -52,9 +52,31 @@ func isLocalRateLimitError(err error) bool { func (s *Service) rateLimitReservations(ctx context.Context, user *auth.User, candidate store.RuntimeModelCandidate, body map[string]any) []store.RateLimitReservation { out := make([]store.RateLimitReservation, 0) - out = append(out, reservationsFromPolicy("platform_model", candidate.PlatformModelID, effectiveRateLimitPolicy(candidate), body)...) + out = append(out, reservationsFromPolicy( + "platform_model", + candidate.PlatformModelID, + firstNonEmptyString(candidate.DisplayName, candidate.ModelAlias, candidate.ModelName), + map[string]any{ + "platformId": candidate.PlatformID, + "platformName": candidate.PlatformName, + "modelAlias": candidate.ModelAlias, + "modelName": candidate.ModelName, + }, + effectiveRateLimitPolicy(candidate), + body, + )...) if group, err := s.store.ResolveUserGroupPolicy(ctx, user); err == nil && group.ID != "" { - out = append(out, reservationsFromPolicy("user_group", group.ID, group.RateLimitPolicy, body)...) + out = append(out, reservationsFromPolicy( + "user_group", + group.ID, + firstNonEmptyString(group.Name, group.GroupKey), + map[string]any{ + "groupKey": group.GroupKey, + "name": group.Name, + }, + group.RateLimitPolicy, + body, + )...) } return out } @@ -90,7 +112,7 @@ func effectiveRetryPolicy(candidate store.RuntimeModelCandidate) map[string]any return policy } -func reservationsFromPolicy(scopeType string, scopeKey string, policy map[string]any, body map[string]any) []store.RateLimitReservation { +func reservationsFromPolicy(scopeType string, scopeKey string, scopeName string, scopeMetadata map[string]any, policy map[string]any, body map[string]any) []store.RateLimitReservation { if scopeKey == "" || !hasRules(policy) { return nil } @@ -108,11 +130,14 @@ func reservationsFromPolicy(scopeType string, scopeKey string, policy map[string out = append(out, store.RateLimitReservation{ ScopeType: scopeType, ScopeKey: scopeKey, + ScopeName: scopeName, + ScopeMetadata: scopeMetadata, Metric: metric, Limit: limit, Amount: amount, WindowSeconds: int(floatFromAny(rule["windowSeconds"])), LeaseTTLSeconds: int(floatFromAny(rule["leaseTtlSeconds"])), + Policy: policy, }) } return out diff --git a/apps/api/internal/runner/pricing.go b/apps/api/internal/runner/pricing.go index 0eaf684..5c8ac6f 100644 --- a/apps/api/internal/runner/pricing.go +++ b/apps/api/internal/runner/pricing.go @@ -11,8 +11,10 @@ import ( ) type EstimateResult struct { - Items []any `json:"items"` - Resolver string `json:"resolver"` + Items []any `json:"items"` + Resolver string `json:"resolver"` + TotalAmount float64 `json:"totalAmount"` + Currency string `json:"currency"` } func (s *Service) Estimate(ctx context.Context, kind string, model string, body map[string]any, user *auth.User) (EstimateResult, error) { @@ -23,9 +25,12 @@ func (s *Service) Estimate(ctx context.Context, kind string, model string, body } candidate := candidates[0] body = preprocessRequest(kind, body, candidate) + items := s.estimatedBillings(ctx, user, kind, body, candidate) return EstimateResult{ - Items: s.estimatedBillings(ctx, user, kind, body, candidate), - Resolver: "effective-pricing-v1", + Items: items, + Resolver: "effective-pricing-v1", + TotalAmount: totalBillingAmount(items), + Currency: billingCurrency(items), }, nil } @@ -60,10 +65,7 @@ func (s *Service) billings(ctx context.Context, user *auth.User, kind string, bo billingLine(candidate, "text_output", "1k_tokens", outputTokens, outputAmount, discount, simulated), } } - count := int(floatFromAny(body["n"])) - if count <= 0 { - count = 1 - } + count := requestOutputCount(body) resource := "image" unit := "image" baseKey := "imageBase" @@ -73,8 +75,24 @@ func (s *Service) billings(ctx context.Context, user *auth.User, kind string, bo } if kind == "videos.generations" { resource = "video" - unit = "video" + unit = "5s_video" baseKey = "videoBase" + duration := requestDurationSeconds(body) + durationUnits := math.Max(1, math.Ceil(duration/5)) + amount := float64(count) * + durationUnits * + resourcePrice(config, resource, baseKey, "basePrice") * + resourceWeight(config, resource, "resolutionWeights", firstNonEmptyString(stringFromMap(body, "resolution"), stringFromMap(body, "size"))) * + resourceWeight(config, resource, "audioWeights", boolWeightKey(boolishValue(body["audio"]))) * + resourceWeight(config, resource, "referenceVideoWeights", boolWeightKey(requestHasReferenceVideo(body))) * + resourceWeight(config, resource, "voiceSpecifiedWeights", boolWeightKey(requestHasVoiceID(body))) * + discount + return []any{billingLineWithDetails(candidate, resource, unit, count*int(durationUnits), roundPrice(amount), discount, simulated, map[string]any{ + "count": count, + "durationSeconds": duration, + "durationUnit": "5s", + "durationUnitCount": durationUnits, + })} } amount := float64(count) * resourcePrice(config, resource, baseKey, "basePrice") * resourceWeight(config, resource, "qualityWeights", stringFromMap(body, "quality")) * resourceWeight(config, resource, "sizeWeights", stringFromMap(body, "size")) * resourceWeight(config, resource, "resolutionWeights", firstNonEmptyString(stringFromMap(body, "resolution"), stringFromMap(body, "size"))) * discount return []any{billingLine(candidate, resource, unit, count, roundPrice(amount), discount, simulated)} @@ -109,17 +127,23 @@ func effectiveDiscount(ctx context.Context, db *store.Store, user *auth.User, ca if discount <= 0 { discount = 1 } - if group, err := db.ResolveUserGroupPolicy(ctx, user); err == nil { - groupDiscount := floatFromAny(group.BillingDiscountPolicy["discountFactor"]) - if groupDiscount > 0 { - discount *= groupDiscount + if db != nil { + if group, err := db.ResolveUserGroupPolicy(ctx, user); err == nil { + groupDiscount := floatFromAny(group.BillingDiscountPolicy["discountFactor"]) + if groupDiscount > 0 { + discount *= groupDiscount + } } } return discount } func billingLine(candidate store.RuntimeModelCandidate, resourceType string, unit string, quantity any, amount float64, discount float64, simulated bool) map[string]any { - return map[string]any{ + return billingLineWithDetails(candidate, resourceType, unit, quantity, amount, discount, simulated, nil) +} + +func billingLineWithDetails(candidate store.RuntimeModelCandidate, resourceType string, unit string, quantity any, amount float64, discount float64, simulated bool, details map[string]any) map[string]any { + line := map[string]any{ "model": candidate.ModelName, "modelAlias": candidate.ModelAlias, "provider": candidate.Provider, @@ -133,6 +157,10 @@ func billingLine(candidate store.RuntimeModelCandidate, resourceType string, uni "discountFactor": discount, "simulated": simulated, } + for key, value := range details { + line[key] = value + } + return line } func price(config map[string]any, key string) float64 { @@ -177,7 +205,16 @@ func weighted(config map[string]any, key string, name string) float64 { } func resourceWeight(config map[string]any, resource string, key string, name string) float64 { - if value := weighted(config, key, name); value != 1 { + keys := weightKeyAliases(key) + names := weightValueAliases(key, name) + for _, candidateKey := range keys { + for _, candidateName := range names { + if value := weighted(config, candidateKey, candidateName); value != 1 { + return value + } + } + } + if value := dynamicWeight(config["dynamicWeight"], keys, names); value != 1 { return value } if strings.TrimSpace(name) == "" { @@ -187,19 +224,201 @@ func resourceWeight(config map[string]any, resource string, key string, name str if len(resourceConfig) == 0 && resource == "image_edit" { resourceConfig, _ = config["image"].(map[string]any) } - if weights, ok := resourceConfig["dynamicWeight"].(map[string]any); ok { - if value := floatFromAny(weights[name]); value > 0 { - return value + if value := dynamicWeight(resourceConfig["dynamicWeight"], keys, names); value != 1 { + return value + } + for _, candidateKey := range keys { + if weights, ok := resourceConfig[candidateKey].(map[string]any); ok { + for _, candidateName := range names { + if value := floatFromAny(weights[candidateName]); value > 0 { + return value + } + } } } - if weights, ok := resourceConfig[key].(map[string]any); ok { - if value := floatFromAny(weights[name]); value > 0 { + return 1 +} + +func dynamicWeight(value any, keys []string, names []string) float64 { + if len(names) == 0 { + return 1 + } + weights, _ := value.(map[string]any) + if len(weights) == 0 { + return 1 + } + for _, name := range names { + if direct := floatFromAny(weights[name]); direct > 0 { + return direct + } + } + for _, key := range keys { + if nested, ok := weights[key].(map[string]any); ok { + for _, name := range names { + if nestedValue := floatFromAny(nested[name]); nestedValue > 0 { + return nestedValue + } + } + } + } + return 1 +} + +func weightKeyAliases(key string) []string { + switch key { + case "qualityWeights": + return []string{"qualityWeights", "qualityFactors"} + case "resolutionWeights": + return []string{"resolutionWeights", "resolutionFactors"} + case "audioWeights": + return []string{"audioWeights", "audioFactors"} + case "referenceVideoWeights": + return []string{"referenceVideoWeights", "referenceVideoFactors"} + case "voiceSpecifiedWeights": + return []string{"voiceSpecifiedWeights", "voiceSpecifiedFactors"} + default: + return []string{key} + } +} + +func weightValueAliases(key string, name string) []string { + name = strings.TrimSpace(name) + if name == "" { + return nil + } + switch key { + case "audioWeights": + return []string{name, "audio-" + name} + case "referenceVideoWeights": + return []string{name, "reference-video-" + name} + case "voiceSpecifiedWeights": + return []string{name, "voice-specified-" + name} + default: + return []string{name} + } +} + +func requestOutputCount(body map[string]any) int { + for _, key := range []string{"n", "count", "batch_size", "batchSize"} { + if value := int(math.Ceil(floatFromAny(body[key]))); value > 0 { return value } } return 1 } +func requestDurationSeconds(body map[string]any) float64 { + for _, key := range []string{"duration", "durationSeconds", "duration_seconds"} { + if value := floatFromAny(body[key]); value > 0 { + return value + } + } + for _, value := range body { + items, ok := value.([]any) + if !ok || len(items) == 0 { + continue + } + total := 0.0 + allDurationItems := true + for _, item := range items { + record, ok := item.(map[string]any) + if !ok { + allDurationItems = false + break + } + duration := floatFromAny(record["duration"]) + if duration <= 0 { + allDurationItems = false + break + } + total += duration + } + if allDurationItems && total > 0 { + return total + } + } + return 5 +} + +func requestHasReferenceVideo(body map[string]any) bool { + if hasNonEmptyArray(body["video_list"]) || hasNonEmptyArray(body["videoList"]) { + return true + } + if firstNonEmptyStringValue(body, "video", "video_url", "videoUrl", "reference_video", "referenceVideo") != "" { + return true + } + content, _ := body["content"].([]any) + for _, item := range content { + record, _ := item.(map[string]any) + if len(record) == 0 { + continue + } + itemType := strings.TrimSpace(stringFromAny(record["type"])) + role := strings.TrimSpace(stringFromAny(record["role"])) + if itemType == "video_url" || role == "video_feature" || role == "video_base" || role == "reference_video" { + return true + } + } + return false +} + +func requestHasVoiceID(body map[string]any) bool { + return boolishValue(body["audio"]) && firstNonEmptyStringValue(body, "voice_id", "voiceId") != "" +} + +func boolWeightKey(value bool) string { + if value { + return "true" + } + return "false" +} + +func boolishValue(value any) bool { + switch typed := value.(type) { + case bool: + return typed + case string: + switch strings.ToLower(strings.TrimSpace(typed)) { + case "true", "1", "yes", "on": + return true + default: + return false + } + case int: + return typed != 0 + case int64: + return typed != 0 + case float64: + return typed != 0 + default: + return false + } +} + +func hasNonEmptyArray(value any) bool { + items, ok := value.([]any) + return ok && len(items) > 0 +} + +func totalBillingAmount(items []any) float64 { + total := 0.0 + for _, raw := range items { + line, _ := raw.(map[string]any) + total += floatFromAny(line["amount"]) + } + return roundPrice(total) +} + +func billingCurrency(items []any) string { + for _, raw := range items { + line, _ := raw.(map[string]any) + if currency := stringFromAny(line["currency"]); currency != "" { + return currency + } + } + return "resource" +} + func firstNonEmptyString(values ...string) string { for _, value := range values { if strings.TrimSpace(value) != "" { diff --git a/apps/api/internal/runner/pricing_test.go b/apps/api/internal/runner/pricing_test.go new file mode 100644 index 0000000..45280ea --- /dev/null +++ b/apps/api/internal/runner/pricing_test.go @@ -0,0 +1,132 @@ +package runner + +import ( + "context" + "testing" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/clients" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" +) + +func TestImageBillingEstimateUsesCountResolutionAndQuality(t *testing.T) { + service := &Service{} + candidate := store.RuntimeModelCandidate{ + ModelName: "image-model", + BaseBillingConfig: map[string]any{ + "image": map[string]any{ + "basePrice": 10, + "dynamicWeight": map[string]any{ + "resolutionFactors": map[string]any{"2K": 1.5}, + "qualityFactors": map[string]any{"high": 1.5}, + }, + }, + }, + } + + items := service.billings(context.Background(), nil, "images.generations", map[string]any{ + "count": 2, + "quality": "high", + "resolution": "2K", + }, candidate, clients.Response{}, true) + + line := firstBillingLine(t, items) + if got, want := floatFromAny(line["amount"]), 45.0; got != want { + t.Fatalf("image estimated amount = %v, want %v", got, want) + } + if got, want := line["quantity"], 2; got != want { + t.Fatalf("image quantity = %v, want %v", got, want) + } +} + +func TestVideoBillingEstimateUsesFiveSecondUnitsAndDynamicWeights(t *testing.T) { + service := &Service{} + candidate := store.RuntimeModelCandidate{ + ModelName: "video-model", + BaseBillingConfig: map[string]any{ + "video": map[string]any{ + "basePrice": 100, + "dynamicWeight": map[string]any{ + "resolutionWeights": map[string]any{"1080p": 1.5}, + "audioWeights": map[string]any{"true": 2}, + "referenceVideoWeights": map[string]any{"true": 1.5}, + "voiceSpecifiedWeights": map[string]any{"true": 1.2}, + "unusedCompatibilityField": map[string]any{"true": 99}, + }, + }, + }, + } + + items := service.billings(context.Background(), nil, "videos.generations", map[string]any{ + "audio": true, + "duration": 12, + "resolution": "1080p", + "voice_id": "voice-a", + "content": []any{ + map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/reference.mp4"}}, + }, + }, candidate, clients.Response{}, true) + + line := firstBillingLine(t, items) + if got, want := floatFromAny(line["amount"]), 1620.0; got != want { + t.Fatalf("video estimated amount = %v, want %v", got, want) + } + if got, want := floatFromAny(line["durationUnitCount"]), 3.0; got != want { + t.Fatalf("video duration units = %v, want %v", got, want) + } + if got, want := line["quantity"], 3; got != want { + t.Fatalf("video quantity = %v, want %v", got, want) + } +} + +func TestVideoBillingEstimateSupportsServerMainStyleDynamicKeys(t *testing.T) { + service := &Service{} + candidate := store.RuntimeModelCandidate{ + ModelName: "legacy-video-model", + BaseBillingConfig: map[string]any{ + "videoBase": 100, + "video": map[string]any{ + "dynamicWeight": map[string]any{ + "720p": 1.25, + "audio-true": 2, + "reference-video-true": 1.5, + }, + }, + }, + } + + items := service.billings(context.Background(), nil, "videos.generations", map[string]any{ + "audio": "true", + "duration": 5, + "resolution": "720p", + "video_list": []any{map[string]any{"url": "https://example.com/reference.mp4"}}, + }, candidate, clients.Response{}, true) + + line := firstBillingLine(t, items) + if got, want := floatFromAny(line["amount"]), 375.0; got != want { + t.Fatalf("legacy video estimated amount = %v, want %v", got, want) + } +} + +func TestVideoDurationEstimateSumsMultiShotDurations(t *testing.T) { + duration := requestDurationSeconds(map[string]any{ + "multi_prompt": []any{ + map[string]any{"prompt": "shot 1", "duration": 3}, + map[string]any{"prompt": "shot 2", "duration": 7}, + }, + }) + if duration != 10 { + t.Fatalf("multi-shot duration = %v, want 10", duration) + } +} + +func firstBillingLine(t *testing.T, items []any) map[string]any { + t.Helper() + if len(items) != 1 { + t.Fatalf("items length = %d, want 1: %+v", len(items), items) + } + line, ok := items[0].(map[string]any) + if !ok { + t.Fatalf("item type = %T, want map[string]any", items[0]) + } + return line +} diff --git a/apps/api/internal/store/rate_limits.go b/apps/api/internal/store/rate_limits.go index e564e61..af61f2b 100644 --- a/apps/api/internal/store/rate_limits.go +++ b/apps/api/internal/store/rate_limits.go @@ -31,9 +31,18 @@ func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID } if reservation.Metric == "" || reservation.Amount > reservation.Limit { return RateLimitResult{}, &RateLimitExceededError{ - Metric: reservation.Metric, - Message: fmt.Sprintf("rate limit exceeded: %s request amount %.0f is greater than limit %.0f", reservation.Metric, reservation.Amount, reservation.Limit), - Retryable: false, + ScopeType: reservation.ScopeType, + ScopeKey: reservation.ScopeKey, + ScopeName: reservation.ScopeName, + ScopeMetadata: reservation.ScopeMetadata, + Metric: reservation.Metric, + Limit: reservation.Limit, + Amount: reservation.Amount, + Projected: reservation.Amount, + WindowSeconds: reservation.WindowSeconds, + Policy: reservation.Policy, + Message: fmt.Sprintf("rate limit exceeded: %s request amount %.0f is greater than limit %.0f", reservation.Metric, reservation.Amount, reservation.Limit), + Retryable: false, } } if reservation.WindowSeconds <= 0 { @@ -78,10 +87,22 @@ WHERE scope_type = $1 } if active+reservation.Amount > reservation.Limit { return "", &RateLimitExceededError{ - Metric: reservation.Metric, - Message: fmt.Sprintf("rate limit exceeded: concurrent active %.0f plus request %.0f is greater than limit %.0f", active, reservation.Amount, reservation.Limit), - RetryAfter: concurrencyRetryAfter(nextAvailableAt), - Retryable: true, + ScopeType: reservation.ScopeType, + ScopeKey: reservation.ScopeKey, + ScopeName: reservation.ScopeName, + ScopeMetadata: reservation.ScopeMetadata, + Metric: reservation.Metric, + Limit: reservation.Limit, + Amount: reservation.Amount, + Current: active, + Used: active, + Projected: active + reservation.Amount, + WindowSeconds: reservation.WindowSeconds, + ResetAt: nextAvailableAt, + Policy: reservation.Policy, + Message: fmt.Sprintf("rate limit exceeded: concurrent active %.0f plus request %.0f is greater than limit %.0f", active, reservation.Amount, reservation.Limit), + RetryAfter: concurrencyRetryAfter(nextAvailableAt), + Retryable: true, } } var leaseID string @@ -135,11 +156,13 @@ RETURNING window_start`, if err != nil { if errors.Is(err, pgx.ErrNoRows) { resetAt := time.Now().Add(time.Duration(reservation.WindowSeconds) * time.Second) + currentUsed := 0.0 + currentReserved := 0.0 _ = tx.QueryRow(ctx, ` WITH bounds AS ( SELECT to_timestamp(floor(extract(epoch FROM now()) / $4::int) * $4::int) AS window_start ) -SELECT counters.reset_at +SELECT counters.used_value::float8, counters.reserved_value::float8, counters.reset_at FROM gateway_rate_limit_counters counters JOIN bounds ON counters.window_start = bounds.window_start WHERE scope_type = $1 @@ -149,12 +172,26 @@ WHERE scope_type = $1 reservation.ScopeKey, reservation.Metric, reservation.WindowSeconds, - ).Scan(&resetAt) + ).Scan(¤tUsed, ¤tReserved, &resetAt) + current := currentUsed + currentReserved return RateLimitReservation{}, &RateLimitExceededError{ - Metric: reservation.Metric, - Message: fmt.Sprintf("rate limit exceeded: %s window has no remaining capacity", reservation.Metric), - RetryAfter: retryAfterUntil(resetAt), - Retryable: true, + ScopeType: reservation.ScopeType, + ScopeKey: reservation.ScopeKey, + ScopeName: reservation.ScopeName, + ScopeMetadata: reservation.ScopeMetadata, + Metric: reservation.Metric, + Limit: reservation.Limit, + Amount: reservation.Amount, + Current: current, + Used: currentUsed, + Reserved: currentReserved, + Projected: current + reservation.Amount, + WindowSeconds: reservation.WindowSeconds, + ResetAt: resetAt, + Policy: reservation.Policy, + Message: fmt.Sprintf("rate limit exceeded: %s window has no remaining capacity", reservation.Metric), + RetryAfter: retryAfterUntil(resetAt), + Retryable: true, } } return RateLimitReservation{}, err diff --git a/apps/api/internal/store/runtime_types.go b/apps/api/internal/store/runtime_types.go index d2784a7..601be8e 100644 --- a/apps/api/internal/store/runtime_types.go +++ b/apps/api/internal/store/runtime_types.go @@ -33,10 +33,23 @@ func ModelCandidateErrorCode(err error) string { } type RateLimitExceededError struct { - Metric string - Message string - RetryAfter time.Duration - Retryable bool + ScopeType string + ScopeKey string + ScopeName string + ScopeMetadata map[string]any + Metric string + Limit float64 + Amount float64 + Current float64 + Used float64 + Reserved float64 + Projected float64 + WindowSeconds int + ResetAt time.Time + Policy map[string]any + Message string + RetryAfter time.Duration + Retryable bool } func (e *RateLimitExceededError) Error() string { @@ -166,12 +179,15 @@ type RateLimitReservation struct { ReservationID string ScopeType string ScopeKey string + ScopeName string + ScopeMetadata map[string]any Metric string Limit float64 Amount float64 WindowSeconds int LeaseTTLSeconds int WindowStart time.Time + Policy map[string]any } type RateLimitResult struct { diff --git a/apps/api/internal/store/user_group_policy.go b/apps/api/internal/store/user_group_policy.go index b3f019a..efd347b 100644 --- a/apps/api/internal/store/user_group_policy.go +++ b/apps/api/internal/store/user_group_policy.go @@ -10,6 +10,7 @@ import ( type UserGroupPolicy struct { ID string GroupKey string + Name string RateLimitPolicy map[string]any BillingDiscountPolicy map[string]any } @@ -23,12 +24,12 @@ func (s *Store) ResolveUserGroupPolicy(ctx context.Context, user *auth.User) (Us var rateLimit []byte var billing []byte err := s.pool.QueryRow(ctx, ` -SELECT id::text, group_key, rate_limit_policy, billing_discount_policy +SELECT id::text, group_key, name, rate_limit_policy, billing_discount_policy FROM gateway_user_groups WHERE status = 'active' AND (($1 <> '' AND id = NULLIF($1, '')::uuid) OR ($1 = '' AND group_key = 'default')) ORDER BY CASE WHEN id::text = $1 THEN 0 ELSE 1 END, priority ASC -LIMIT 1`, userGroupID).Scan(&item.ID, &item.GroupKey, &rateLimit, &billing) +LIMIT 1`, userGroupID).Scan(&item.ID, &item.GroupKey, &item.Name, &rateLimit, &billing) if err != nil { if err == pgx.ErrNoRows { return UserGroupPolicy{}, nil diff --git a/apps/web/src/api.ts b/apps/web/src/api.ts index b28b250..22abd77 100644 --- a/apps/web/src/api.ts +++ b/apps/web/src/api.ts @@ -19,6 +19,7 @@ import type { GatewayTenant, GatewayTenantUpsertRequest, GatewayNetworkProxyConfig, + GatewayPricingEstimate, GatewayTask, GatewayTaskParamPreprocessingLog, GatewayUser, @@ -786,8 +787,8 @@ export async function uploadFileToStorage( export async function estimatePricing( token: string, input: Record, -): Promise<{ items: unknown[]; resolver: string }> { - return request<{ items: unknown[]; resolver: string }>('/api/v1/pricing/estimate', { +): Promise { + return request('/api/v1/pricing/estimate', { body: input, method: 'POST', token, diff --git a/apps/web/src/pages/PlaygroundPage.tsx b/apps/web/src/pages/PlaygroundPage.tsx index 5a4b95e..5a2b7e8 100644 --- a/apps/web/src/pages/PlaygroundPage.tsx +++ b/apps/web/src/pages/PlaygroundPage.tsx @@ -1,8 +1,8 @@ import { useEffect, useMemo, useRef, useState } from 'react'; -import type { GatewayApiKey, GatewayTask, PlatformModel } from '@easyai-ai-gateway/contracts'; +import type { GatewayApiKey, GatewayPricingEstimate, GatewayTask, PlatformModel } from '@easyai-ai-gateway/contracts'; import { ArrowUp, ChevronDown, MessageSquarePlus, Settings2, Sparkles } from 'lucide-react'; import { Badge, Button, FormDialog, Select, Textarea } from '../components/ui'; -import { createImageEditTask, createImageGenerationTask, createVideoGenerationTask, pollTaskUntilSettled, resolveApiAssetUrl, taskIsPending } from '../api'; +import { createImageEditTask, createImageGenerationTask, createVideoGenerationTask, estimatePricing, pollTaskUntilSettled, resolveApiAssetUrl, taskIsPending } from '../api'; import type { PlaygroundMode } from '../types'; import { PlaygroundPromptMentionInput, @@ -57,6 +57,14 @@ import { const MEDIA_RUNS_STORAGE_KEY = 'easyai:playground:media-runs:v1'; const MEDIA_RUNS_STORAGE_LIMIT = 50; +type MediaEstimateState = { + amount?: number; + currency?: string; + error?: string; + resolver?: string; + status: 'idle' | 'loading' | 'ready' | 'error'; +}; + const publicWorks = [ { title: '雨夜霓虹街区', type: '图像生成', image: 'https://picsum.photos/seed/easyai-neon-city/720/960' }, { title: '玻璃温室晨光', type: '图像生成', image: 'https://picsum.photos/seed/easyai-glasshouse/720/540' }, @@ -94,6 +102,7 @@ export function PlaygroundPage(props: { const [mediaMessage, setMediaMessage] = useState(''); const [mediaUploadMessage, setMediaUploadMessage] = useState(''); const [mediaUploads, setMediaUploads] = useState([]); + const [mediaEstimate, setMediaEstimate] = useState({ status: 'idle' }); const [mediaUploading, setMediaUploading] = useState(false); const [settingsOpen, setSettingsOpen] = useState(false); const isMountedRef = useRef(false); @@ -118,6 +127,13 @@ export function PlaygroundPage(props: { : undefined, [activeModelOption, mediaSettings.resolution, props.mode, videoMode], ); + const mediaEstimatePayload = useMemo(() => { + if (props.mode === 'chat' || !selectedModel) return null; + const normalizedSettings = mediaCapabilities + ? normalizeMediaSettingsForCapabilities(mediaSettings, mediaCapabilities, props.mode) + : mediaSettings; + return buildMediaEstimatePayload(props.mode, selectedModel, prompt, normalizedSettings, mediaUploads, videoMode); + }, [mediaCapabilities, mediaSettings, mediaUploads, prompt, props.mode, selectedModel, videoMode]); useEffect(() => { setSelectedModel((current) => { @@ -155,6 +171,34 @@ export function PlaygroundPage(props: { setPrompt((current) => removeInvalidPlaygroundResourceTokens(current, mediaUploads)); }, [mediaUploadSignature, props.mode]); + useEffect(() => { + const credential = activeApiKeySecret || props.token; + if (props.mode === 'chat' || !credential || !mediaEstimatePayload) { + setMediaEstimate({ status: 'idle' }); + return; + } + let cancelled = false; + setMediaEstimate((current) => ({ ...current, error: '', status: 'loading' })); + const timer = window.setTimeout(() => { + estimatePricing(credential, mediaEstimatePayload) + .then((estimate) => { + if (cancelled) return; + setMediaEstimate(mediaEstimateFromResponse(estimate)); + }) + .catch((err) => { + if (cancelled) return; + setMediaEstimate({ + error: err instanceof Error ? err.message : '预计扣费计算失败', + status: 'error', + }); + }); + }, 260); + return () => { + cancelled = true; + window.clearTimeout(timer); + }; + }, [activeApiKeySecret, mediaEstimatePayload, props.mode, props.token]); + useEffect(() => { if (props.mode === 'image') { setMediaUploads((current) => current.some((item) => item.kind !== 'image') ? current.filter((item) => item.kind === 'image') : current); @@ -404,6 +448,7 @@ export function PlaygroundPage(props: { selectedModel={selectedModel} imageHasReference={effectiveImageHasReference} mediaSettings={mediaSettings} + mediaEstimate={mediaEstimate} mediaCapabilities={mediaCapabilities} uploadAccept={mediaUploadAcceptValue} uploadMessage={mediaUploadMessage} @@ -631,6 +676,7 @@ function Composer(props: { compact?: boolean; imageHasReference?: boolean; mediaCapabilities?: MediaModelCapabilities; + mediaEstimate?: MediaEstimateState; mediaSettings?: MediaGenerationSettings; mode: PlaygroundMode; modelOptions: ModelOption[]; @@ -727,10 +773,17 @@ function Composer(props: { onChange={props.onMediaSettingsChange} /> )} - - - 1 / 张 - + {props.mode !== 'chat' && props.mediaEstimate && ( + + + {mediaEstimateText(props.mediaEstimate)} + + )} @@ -739,6 +792,74 @@ function Composer(props: { ); } +function buildMediaEstimatePayload( + mode: Exclude, + model: string, + prompt: string, + settings: MediaGenerationSettings, + uploads: PlaygroundUpload[], + videoMode: VideoCreateMode, +): Record { + const requestPrompt = replacePlaygroundResourceTokens(prompt.trim(), uploads, mode); + if (mode === 'video') { + return { + kind: 'videos.generations', + model, + content: sharedVideoGenerationContentFromPromptAndUploads(requestPrompt, uploads, videoMode), + ...mediaRequestPayload(settings, 'video'), + }; + } + + const uploadPayload = sharedMediaUploadRequestPayload(uploads, 'image'); + return { + kind: uploads.some((item) => item.kind === 'image') ? 'images.edits' : 'images.generations', + model, + prompt: requestPrompt, + ...mediaRequestPayload(settings, 'image'), + ...uploadPayload, + }; +} + +function mediaEstimateFromResponse(response: GatewayPricingEstimate): MediaEstimateState { + return { + amount: numericFromUnknown(response.totalAmount) ?? estimateItemsTotal(response.items), + currency: response.currency || estimateItemsCurrency(response.items), + resolver: response.resolver, + status: 'ready', + }; +} + +function estimateItemsTotal(items: GatewayPricingEstimate['items']) { + const total = items.reduce((sum, item) => sum + (numericFromUnknown(item.amount) ?? 0), 0); + return Math.round(total * 1_000_000) / 1_000_000; +} + +function estimateItemsCurrency(items: GatewayPricingEstimate['items']) { + return items.find((item) => stringFromUnknown(item.currency))?.currency || 'resource'; +} + +function mediaEstimateText(estimate: MediaEstimateState) { + if (estimate.amount === undefined) return '--'; + return formatEstimateAmount(estimate.amount); +} + +function mediaEstimateHint(estimate: MediaEstimateState) { + if (estimate.status === 'error' && estimate.error) return `预计扣费计算失败:${estimate.error}`; + return '扣费为预计,实际扣费以账单为准'; +} + +function mediaEstimateAriaLabel(estimate: MediaEstimateState) { + if (estimate.status === 'error') return estimate.error ? `预计扣费计算失败:${estimate.error}` : '预计扣费计算失败'; + if (estimate.amount === undefined) return '预计扣费估算中'; + return `预计扣费 ${formatEstimateAmount(estimate.amount)} ${estimate.currency || 'resource'}`; +} + +function formatEstimateAmount(value: number) { + if (!Number.isFinite(value)) return '--'; + const digits = Math.abs(value) > 0 && Math.abs(value) < 0.01 ? 6 : 2; + return value.toFixed(digits).replace(/\.?0+$/, ''); +} + function mediaPromptPlaceholder(mode: PlaygroundMode) { if (mode === 'image') return '输入画面描述,可用 @ 或 @资产 快速引用图片资源,例如:让 @图像 1 保持人物一致...'; if (mode === 'video') return '输入镜头、运动和风格,可用 @ 或 @资产 引用图片、视频或音频资源...'; diff --git a/apps/web/src/pages/admin/IdentityManagementPanels.tsx b/apps/web/src/pages/admin/IdentityManagementPanels.tsx index a686f8d..01a65fd 100644 --- a/apps/web/src/pages/admin/IdentityManagementPanels.tsx +++ b/apps/web/src/pages/admin/IdentityManagementPanels.tsx @@ -78,8 +78,10 @@ type UserGroupForm = { description: string; source: string; priority: string; - rechargeDiscountPolicyJson: string; - billingDiscountPolicyJson: string; + rechargeDiscountFactor: string; + rechargeDiscountPolicy: Record; + billingDiscountFactor: string; + billingDiscountPolicy: Record; rateLimitPolicyJson: string; quotaPolicyJson: string; metadataJson: string; @@ -516,8 +518,8 @@ export function UserGroupsPanel(props: IdentityPanelProps) { - setForm({ ...form, rechargeDiscountPolicyJson: value })} /> - setForm({ ...form, billingDiscountPolicyJson: value })} /> + + setForm({ ...form, rateLimitPolicyJson: value })} /> setForm({ ...form, quotaPolicyJson: value })} /> setForm({ ...form, metadataJson: value })} /> @@ -769,8 +771,10 @@ function defaultUserGroupForm(): UserGroupForm { description: '', source: 'gateway', priority: '100', - rechargeDiscountPolicyJson: '{}', - billingDiscountPolicyJson: '{}', + rechargeDiscountFactor: '1', + rechargeDiscountPolicy: {}, + billingDiscountFactor: '1', + billingDiscountPolicy: {}, rateLimitPolicyJson: '{"rules":[]}', quotaPolicyJson: '{}', metadataJson: '{}', @@ -785,8 +789,10 @@ function userGroupToForm(group: UserGroup): UserGroupForm { description: group.description ?? '', source: group.source, priority: String(group.priority), - rechargeDiscountPolicyJson: stringifyJson(group.rechargeDiscountPolicy), - billingDiscountPolicyJson: stringifyJson(group.billingDiscountPolicy), + rechargeDiscountFactor: discountFactorText(group.rechargeDiscountPolicy), + rechargeDiscountPolicy: group.rechargeDiscountPolicy ?? {}, + billingDiscountFactor: discountFactorText(group.billingDiscountPolicy), + billingDiscountPolicy: group.billingDiscountPolicy ?? {}, rateLimitPolicyJson: stringifyJson(group.rateLimitPolicy), quotaPolicyJson: stringifyJson(group.quotaPolicy), metadataJson: stringifyJson(group.metadata), @@ -801,8 +807,8 @@ function formToUserGroupPayload(form: UserGroupForm): UserGroupUpsertRequest { description: form.description.trim() || undefined, source: form.source, priority: Number(form.priority) || 100, - rechargeDiscountPolicy: parseJsonObject(form.rechargeDiscountPolicyJson, '充值折扣策略 JSON'), - billingDiscountPolicy: parseJsonObject(form.billingDiscountPolicyJson, '计费折扣策略 JSON'), + rechargeDiscountPolicy: discountPolicyPayload(form.rechargeDiscountPolicy, form.rechargeDiscountFactor, '充值折扣系数'), + billingDiscountPolicy: discountPolicyPayload(form.billingDiscountPolicy, form.billingDiscountFactor, '计费折扣系数'), rateLimitPolicy: parseJsonObject(form.rateLimitPolicyJson, '限流策略 JSON'), quotaPolicy: parseJsonObject(form.quotaPolicyJson, '额度策略 JSON'), metadata: parseJsonObject(form.metadataJson, '元数据 JSON'), @@ -854,14 +860,57 @@ function newIdempotencyKey() { } function discountSummary(group: UserGroup) { - const billing = group.billingDiscountPolicy?.discountFactor ?? group.billingDiscountPolicy?.factor; - const recharge = group.rechargeDiscountPolicy?.discountFactor ?? group.rechargeDiscountPolicy?.factor; + const billing = discountFactorFromPolicy(group.billingDiscountPolicy); + const recharge = discountFactorFromPolicy(group.rechargeDiscountPolicy); const parts = []; - if (billing) parts.push(`计费 ${billing}`); - if (recharge) parts.push(`充值 ${recharge}`); + if (billing) parts.push(`计费 ${trimNumber(billing)}`); + if (recharge) parts.push(`充值 ${trimNumber(recharge)}`); return parts.join(' / ') || '未设置'; } +function discountFactorText(policy?: Record) { + const value = discountFactorFromPolicy(policy); + return value ? trimNumber(value) : '1'; +} + +function discountFactorFromPolicy(policy?: Record) { + return numberFromUnknown(policy?.discountFactor) ?? numberFromUnknown(policy?.factor); +} + +function discountPolicyPayload(basePolicy: Record, discountText: string, label: string) { + const policy = { ...basePolicy }; + delete policy.discountFactor; + delete policy.factor; + const discount = optionalPositiveNumber(discountText, label); + if (discount && discount !== 1) { + policy.discountFactor = discount; + } + return Object.keys(policy).length ? policy : undefined; +} + +function optionalPositiveNumber(value: string, label: string) { + const text = value.trim(); + if (!text) return undefined; + const parsed = Number(text); + if (!Number.isFinite(parsed) || parsed <= 0) { + throw new Error(`${label} 必须是大于 0 的数字`); + } + return parsed; +} + +function numberFromUnknown(value: unknown) { + if (typeof value === 'number' && Number.isFinite(value) && value > 0) return value; + if (typeof value === 'string' && value.trim()) { + const parsed = Number(value); + if (Number.isFinite(parsed) && parsed > 0) return parsed; + } + return undefined; +} + +function trimNumber(value: number) { + return value.toFixed(6).replace(/\.?0+$/, ''); +} + function policyKeys(value?: Record) { if (!value) return []; return Object.keys(value).slice(0, 3); diff --git a/apps/web/src/styles/playground.css b/apps/web/src/styles/playground.css index db8bad2..e6a9f1e 100644 --- a/apps/web/src/styles/playground.css +++ b/apps/web/src/styles/playground.css @@ -630,6 +630,15 @@ color: #526170; } +.composerEstimatedCharge[data-state="loading"] { + color: #6f7c8a; +} + +.composerEstimatedCharge[data-state="error"], +.composerEstimatedCharge[data-state="error"] svg { + color: #b45309; +} + .composerMediaSendButton.shButton { width: 40px; height: 40px; diff --git a/packages/contracts/src/index.ts b/packages/contracts/src/index.ts index a157b55..c3d5ece 100644 --- a/packages/contracts/src/index.ts +++ b/packages/contracts/src/index.ts @@ -506,6 +506,32 @@ export interface PlayableGatewayApiKey extends GatewayApiKey { secret: string; } +export interface GatewayPricingEstimateItem { + amount?: number; + currency?: string; + discountFactor?: number; + durationSeconds?: number; + durationUnit?: string; + durationUnitCount?: number; + model?: string; + modelAlias?: string; + platformId?: string; + platformModelId?: string; + provider?: string; + quantity?: number | string; + resourceType?: string; + simulated?: boolean; + unit?: string; + [key: string]: unknown; +} + +export interface GatewayPricingEstimate { + items: GatewayPricingEstimateItem[]; + resolver: string; + totalAmount?: number; + currency?: string; +} + export interface GatewayWalletAccount { id: string; gatewayTenantId?: string;