feat: add gateway billing estimate and rate limit details
This commit is contained in:
parent
bdc9be63d5
commit
37d0f919e5
@ -597,7 +597,7 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
||||
status := statusFromRunError(runErr)
|
||||
errorPayload := map[string]any{
|
||||
"code": runErrorCode(runErr),
|
||||
"message": runErr.Error(),
|
||||
"message": runErrorMessage(runErr),
|
||||
"status": status,
|
||||
}
|
||||
if result.Task.ID != "" {
|
||||
@ -606,6 +606,9 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
||||
if result.Task.RequestID != "" {
|
||||
errorPayload["requestId"] = result.Task.RequestID
|
||||
}
|
||||
for key, value := range runErrorDetails(runErr) {
|
||||
errorPayload[key] = value
|
||||
}
|
||||
sendSSE(w, "error", map[string]any{"error": errorPayload})
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
@ -626,7 +629,7 @@ func (s *Server) createTask(kind string, compatible bool) http.Handler {
|
||||
if !requestStillConnected(r) {
|
||||
return
|
||||
}
|
||||
writeError(w, statusFromRunError(runErr), runErr.Error(), runErrorCode(runErr))
|
||||
writeErrorWithDetails(w, statusFromRunError(runErr), runErrorMessage(runErr), runErrorDetails(runErr), runErrorCode(runErr))
|
||||
return
|
||||
}
|
||||
if !requestStillConnected(r) {
|
||||
@ -742,6 +745,138 @@ func runErrorCode(err error) string {
|
||||
return clients.ErrorCode(err)
|
||||
}
|
||||
|
||||
func runErrorMessage(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
if summary := rateLimitErrorSummary(err); summary != "" {
|
||||
return err.Error() + ";" + summary
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
func runErrorDetails(err error) map[string]any {
|
||||
if detail := rateLimitErrorDetail(err); len(detail) > 0 {
|
||||
return map[string]any{"rateLimit": detail}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func rateLimitErrorSummary(err error) string {
|
||||
var limitErr *store.RateLimitExceededError
|
||||
if !errors.As(err, &limitErr) {
|
||||
return ""
|
||||
}
|
||||
scopeLabel := "限流对象"
|
||||
switch limitErr.ScopeType {
|
||||
case "user_group":
|
||||
scopeLabel = "用户组"
|
||||
case "platform_model":
|
||||
scopeLabel = "平台模型"
|
||||
}
|
||||
scopeName := strings.TrimSpace(limitErr.ScopeName)
|
||||
if scopeName == "" {
|
||||
scopeName = strings.TrimSpace(limitErr.ScopeKey)
|
||||
}
|
||||
if groupKey := stringValue(limitErr.ScopeMetadata["groupKey"]); limitErr.ScopeType == "user_group" && groupKey != "" && groupKey != scopeName {
|
||||
scopeName = fmt.Sprintf("%s(%s)", scopeName, groupKey)
|
||||
}
|
||||
projected := limitErr.Projected
|
||||
if projected <= 0 {
|
||||
projected = limitErr.Current + limitErr.Amount
|
||||
}
|
||||
parts := []string{
|
||||
fmt.Sprintf("限流摘要:%s %s 的 %s 超限", scopeLabel, scopeName, limitErr.Metric),
|
||||
fmt.Sprintf("当前 %s,本次 %s,预计 %s,限制 %s", formatRateLimitValue(limitErr.Current), formatRateLimitValue(limitErr.Amount), formatRateLimitValue(projected), formatRateLimitValue(limitErr.Limit)),
|
||||
}
|
||||
if limitErr.WindowSeconds > 0 {
|
||||
parts = append(parts, fmt.Sprintf("窗口 %d 秒", limitErr.WindowSeconds))
|
||||
}
|
||||
if limitErr.RetryAfter > 0 {
|
||||
parts = append(parts, fmt.Sprintf("约%s后可重试", formatRateLimitDuration(limitErr.RetryAfter)))
|
||||
} else if !limitErr.Retryable {
|
||||
parts = append(parts, "该请求超过单次限额,不能排队重试")
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
func rateLimitErrorDetail(err error) map[string]any {
|
||||
var limitErr *store.RateLimitExceededError
|
||||
if !errors.As(err, &limitErr) {
|
||||
return nil
|
||||
}
|
||||
detail := map[string]any{
|
||||
"scopeType": limitErr.ScopeType,
|
||||
"scopeKey": limitErr.ScopeKey,
|
||||
"scopeName": limitErr.ScopeName,
|
||||
"metric": limitErr.Metric,
|
||||
"limit": limitErr.Limit,
|
||||
"amount": limitErr.Amount,
|
||||
"current": limitErr.Current,
|
||||
"used": limitErr.Used,
|
||||
"reserved": limitErr.Reserved,
|
||||
"projected": limitErr.Projected,
|
||||
"windowSeconds": limitErr.WindowSeconds,
|
||||
"retryable": limitErr.Retryable,
|
||||
"exceeded": map[string]any{
|
||||
"metric": limitErr.Metric,
|
||||
"current": limitErr.Current,
|
||||
"amount": limitErr.Amount,
|
||||
"projected": limitErr.Projected,
|
||||
"limit": limitErr.Limit,
|
||||
},
|
||||
}
|
||||
if limitErr.RetryAfter > 0 {
|
||||
detail["retryAfterMs"] = limitErr.RetryAfter.Milliseconds()
|
||||
}
|
||||
if !limitErr.ResetAt.IsZero() {
|
||||
detail["resetAt"] = limitErr.ResetAt.UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
if len(limitErr.Policy) > 0 {
|
||||
detail["rateLimitPolicy"] = limitErr.Policy
|
||||
if matchedRule := matchedRateLimitRule(limitErr.Policy, limitErr.Metric); len(matchedRule) > 0 {
|
||||
detail["matchedRule"] = matchedRule
|
||||
}
|
||||
}
|
||||
if len(limitErr.ScopeMetadata) > 0 {
|
||||
detail["scopeMetadata"] = limitErr.ScopeMetadata
|
||||
}
|
||||
if limitErr.ScopeType == "user_group" {
|
||||
userGroup := map[string]any{
|
||||
"id": limitErr.ScopeKey,
|
||||
"name": limitErr.ScopeName,
|
||||
}
|
||||
if groupKey := stringValue(limitErr.ScopeMetadata["groupKey"]); groupKey != "" {
|
||||
userGroup["groupKey"] = groupKey
|
||||
}
|
||||
detail["userGroup"] = userGroup
|
||||
}
|
||||
return detail
|
||||
}
|
||||
|
||||
func formatRateLimitValue(value float64) string {
|
||||
return strconv.FormatFloat(value, 'f', -1, 64)
|
||||
}
|
||||
|
||||
func formatRateLimitDuration(duration time.Duration) string {
|
||||
if duration < time.Second {
|
||||
return strconv.FormatInt(duration.Milliseconds(), 10) + "毫秒"
|
||||
}
|
||||
seconds := duration.Seconds()
|
||||
return strconv.FormatFloat(seconds, 'f', -1, 64) + "秒"
|
||||
}
|
||||
|
||||
func matchedRateLimitRule(policy map[string]any, metric string) map[string]any {
|
||||
rules, _ := policy["rules"].([]any)
|
||||
for _, rawRule := range rules {
|
||||
rule, _ := rawRule.(map[string]any)
|
||||
if stringValue(rule["metric"]) == metric {
|
||||
return rule
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := auth.UserFromContext(r.Context())
|
||||
if !ok {
|
||||
|
||||
72
apps/api/internal/httpapi/rate_limit_error_detail_test.go
Normal file
72
apps/api/internal/httpapi/rate_limit_error_detail_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
func TestRateLimitErrorDetailIncludesUserGroupAndExceededMetric(t *testing.T) {
|
||||
resetAt := time.Date(2026, 5, 15, 10, 30, 0, 0, time.UTC)
|
||||
detail := rateLimitErrorDetail(&store.RateLimitExceededError{
|
||||
ScopeType: "user_group",
|
||||
ScopeKey: "group-1",
|
||||
ScopeName: "VIP 用户组",
|
||||
ScopeMetadata: map[string]any{"groupKey": "vip"},
|
||||
Metric: "rpm",
|
||||
Limit: 2,
|
||||
Amount: 1,
|
||||
Current: 2,
|
||||
Used: 1,
|
||||
Reserved: 1,
|
||||
Projected: 3,
|
||||
WindowSeconds: 60,
|
||||
ResetAt: resetAt,
|
||||
RetryAfter: 5 * time.Second,
|
||||
Retryable: true,
|
||||
Policy: map[string]any{
|
||||
"rules": []any{
|
||||
map[string]any{"metric": "rpm", "limit": float64(2), "windowSeconds": float64(60)},
|
||||
},
|
||||
},
|
||||
})
|
||||
if detail["metric"] != "rpm" || detail["projected"] != float64(3) || detail["limit"] != float64(2) {
|
||||
t.Fatalf("unexpected exceeded detail: %+v", detail)
|
||||
}
|
||||
userGroup, _ := detail["userGroup"].(map[string]any)
|
||||
if userGroup["id"] != "group-1" || userGroup["groupKey"] != "vip" || userGroup["name"] != "VIP 用户组" {
|
||||
t.Fatalf("missing user group detail: %+v", detail)
|
||||
}
|
||||
matchedRule, _ := detail["matchedRule"].(map[string]any)
|
||||
if matchedRule["metric"] != "rpm" {
|
||||
t.Fatalf("missing matched rule: %+v", detail)
|
||||
}
|
||||
if detail["retryAfterMs"] != int64(5000) || detail["resetAt"] != resetAt.Format(time.RFC3339Nano) {
|
||||
t.Fatalf("missing retry/reset detail: %+v", detail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunErrorMessageIncludesRateLimitSummary(t *testing.T) {
|
||||
message := runErrorMessage(&store.RateLimitExceededError{
|
||||
ScopeType: "user_group",
|
||||
ScopeKey: "group-1",
|
||||
ScopeName: "VIP 用户组",
|
||||
ScopeMetadata: map[string]any{"groupKey": "vip"},
|
||||
Metric: "rpm",
|
||||
Limit: 2,
|
||||
Amount: 1,
|
||||
Current: 2,
|
||||
Projected: 3,
|
||||
WindowSeconds: 60,
|
||||
RetryAfter: 5 * time.Second,
|
||||
Retryable: true,
|
||||
Message: "rate limit exceeded: rpm window has no remaining capacity",
|
||||
})
|
||||
for _, expected := range []string{"限流摘要", "用户组 VIP 用户组(vip)", "rpm 超限", "当前 2", "本次 1", "预计 3", "限制 2", "窗口 60 秒", "约5秒后可重试"} {
|
||||
if !strings.Contains(message, expected) {
|
||||
t.Fatalf("message %q should contain %q", message, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -14,6 +14,10 @@ func writeJSON(w http.ResponseWriter, status int, value any) {
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, status int, message string, codes ...string) {
|
||||
writeErrorWithDetails(w, status, message, nil, codes...)
|
||||
}
|
||||
|
||||
func writeErrorWithDetails(w http.ResponseWriter, status int, message string, details map[string]any, codes ...string) {
|
||||
errorPayload := map[string]any{
|
||||
"message": message,
|
||||
"status": status,
|
||||
@ -23,6 +27,9 @@ func writeError(w http.ResponseWriter, status int, message string, codes ...stri
|
||||
errorPayload["code"] = code
|
||||
}
|
||||
}
|
||||
for key, value := range details {
|
||||
errorPayload[key] = value
|
||||
}
|
||||
writeJSON(w, status, map[string]any{"error": errorPayload})
|
||||
}
|
||||
|
||||
|
||||
@ -52,9 +52,31 @@ func isLocalRateLimitError(err error) bool {
|
||||
|
||||
func (s *Service) rateLimitReservations(ctx context.Context, user *auth.User, candidate store.RuntimeModelCandidate, body map[string]any) []store.RateLimitReservation {
|
||||
out := make([]store.RateLimitReservation, 0)
|
||||
out = append(out, reservationsFromPolicy("platform_model", candidate.PlatformModelID, effectiveRateLimitPolicy(candidate), body)...)
|
||||
out = append(out, reservationsFromPolicy(
|
||||
"platform_model",
|
||||
candidate.PlatformModelID,
|
||||
firstNonEmptyString(candidate.DisplayName, candidate.ModelAlias, candidate.ModelName),
|
||||
map[string]any{
|
||||
"platformId": candidate.PlatformID,
|
||||
"platformName": candidate.PlatformName,
|
||||
"modelAlias": candidate.ModelAlias,
|
||||
"modelName": candidate.ModelName,
|
||||
},
|
||||
effectiveRateLimitPolicy(candidate),
|
||||
body,
|
||||
)...)
|
||||
if group, err := s.store.ResolveUserGroupPolicy(ctx, user); err == nil && group.ID != "" {
|
||||
out = append(out, reservationsFromPolicy("user_group", group.ID, group.RateLimitPolicy, body)...)
|
||||
out = append(out, reservationsFromPolicy(
|
||||
"user_group",
|
||||
group.ID,
|
||||
firstNonEmptyString(group.Name, group.GroupKey),
|
||||
map[string]any{
|
||||
"groupKey": group.GroupKey,
|
||||
"name": group.Name,
|
||||
},
|
||||
group.RateLimitPolicy,
|
||||
body,
|
||||
)...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@ -90,7 +112,7 @@ func effectiveRetryPolicy(candidate store.RuntimeModelCandidate) map[string]any
|
||||
return policy
|
||||
}
|
||||
|
||||
func reservationsFromPolicy(scopeType string, scopeKey string, policy map[string]any, body map[string]any) []store.RateLimitReservation {
|
||||
func reservationsFromPolicy(scopeType string, scopeKey string, scopeName string, scopeMetadata map[string]any, policy map[string]any, body map[string]any) []store.RateLimitReservation {
|
||||
if scopeKey == "" || !hasRules(policy) {
|
||||
return nil
|
||||
}
|
||||
@ -108,11 +130,14 @@ func reservationsFromPolicy(scopeType string, scopeKey string, policy map[string
|
||||
out = append(out, store.RateLimitReservation{
|
||||
ScopeType: scopeType,
|
||||
ScopeKey: scopeKey,
|
||||
ScopeName: scopeName,
|
||||
ScopeMetadata: scopeMetadata,
|
||||
Metric: metric,
|
||||
Limit: limit,
|
||||
Amount: amount,
|
||||
WindowSeconds: int(floatFromAny(rule["windowSeconds"])),
|
||||
LeaseTTLSeconds: int(floatFromAny(rule["leaseTtlSeconds"])),
|
||||
Policy: policy,
|
||||
})
|
||||
}
|
||||
return out
|
||||
|
||||
@ -11,8 +11,10 @@ import (
|
||||
)
|
||||
|
||||
type EstimateResult struct {
|
||||
Items []any `json:"items"`
|
||||
Resolver string `json:"resolver"`
|
||||
Items []any `json:"items"`
|
||||
Resolver string `json:"resolver"`
|
||||
TotalAmount float64 `json:"totalAmount"`
|
||||
Currency string `json:"currency"`
|
||||
}
|
||||
|
||||
func (s *Service) Estimate(ctx context.Context, kind string, model string, body map[string]any, user *auth.User) (EstimateResult, error) {
|
||||
@ -23,9 +25,12 @@ func (s *Service) Estimate(ctx context.Context, kind string, model string, body
|
||||
}
|
||||
candidate := candidates[0]
|
||||
body = preprocessRequest(kind, body, candidate)
|
||||
items := s.estimatedBillings(ctx, user, kind, body, candidate)
|
||||
return EstimateResult{
|
||||
Items: s.estimatedBillings(ctx, user, kind, body, candidate),
|
||||
Resolver: "effective-pricing-v1",
|
||||
Items: items,
|
||||
Resolver: "effective-pricing-v1",
|
||||
TotalAmount: totalBillingAmount(items),
|
||||
Currency: billingCurrency(items),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -60,10 +65,7 @@ func (s *Service) billings(ctx context.Context, user *auth.User, kind string, bo
|
||||
billingLine(candidate, "text_output", "1k_tokens", outputTokens, outputAmount, discount, simulated),
|
||||
}
|
||||
}
|
||||
count := int(floatFromAny(body["n"]))
|
||||
if count <= 0 {
|
||||
count = 1
|
||||
}
|
||||
count := requestOutputCount(body)
|
||||
resource := "image"
|
||||
unit := "image"
|
||||
baseKey := "imageBase"
|
||||
@ -73,8 +75,24 @@ func (s *Service) billings(ctx context.Context, user *auth.User, kind string, bo
|
||||
}
|
||||
if kind == "videos.generations" {
|
||||
resource = "video"
|
||||
unit = "video"
|
||||
unit = "5s_video"
|
||||
baseKey = "videoBase"
|
||||
duration := requestDurationSeconds(body)
|
||||
durationUnits := math.Max(1, math.Ceil(duration/5))
|
||||
amount := float64(count) *
|
||||
durationUnits *
|
||||
resourcePrice(config, resource, baseKey, "basePrice") *
|
||||
resourceWeight(config, resource, "resolutionWeights", firstNonEmptyString(stringFromMap(body, "resolution"), stringFromMap(body, "size"))) *
|
||||
resourceWeight(config, resource, "audioWeights", boolWeightKey(boolishValue(body["audio"]))) *
|
||||
resourceWeight(config, resource, "referenceVideoWeights", boolWeightKey(requestHasReferenceVideo(body))) *
|
||||
resourceWeight(config, resource, "voiceSpecifiedWeights", boolWeightKey(requestHasVoiceID(body))) *
|
||||
discount
|
||||
return []any{billingLineWithDetails(candidate, resource, unit, count*int(durationUnits), roundPrice(amount), discount, simulated, map[string]any{
|
||||
"count": count,
|
||||
"durationSeconds": duration,
|
||||
"durationUnit": "5s",
|
||||
"durationUnitCount": durationUnits,
|
||||
})}
|
||||
}
|
||||
amount := float64(count) * resourcePrice(config, resource, baseKey, "basePrice") * resourceWeight(config, resource, "qualityWeights", stringFromMap(body, "quality")) * resourceWeight(config, resource, "sizeWeights", stringFromMap(body, "size")) * resourceWeight(config, resource, "resolutionWeights", firstNonEmptyString(stringFromMap(body, "resolution"), stringFromMap(body, "size"))) * discount
|
||||
return []any{billingLine(candidate, resource, unit, count, roundPrice(amount), discount, simulated)}
|
||||
@ -109,17 +127,23 @@ func effectiveDiscount(ctx context.Context, db *store.Store, user *auth.User, ca
|
||||
if discount <= 0 {
|
||||
discount = 1
|
||||
}
|
||||
if group, err := db.ResolveUserGroupPolicy(ctx, user); err == nil {
|
||||
groupDiscount := floatFromAny(group.BillingDiscountPolicy["discountFactor"])
|
||||
if groupDiscount > 0 {
|
||||
discount *= groupDiscount
|
||||
if db != nil {
|
||||
if group, err := db.ResolveUserGroupPolicy(ctx, user); err == nil {
|
||||
groupDiscount := floatFromAny(group.BillingDiscountPolicy["discountFactor"])
|
||||
if groupDiscount > 0 {
|
||||
discount *= groupDiscount
|
||||
}
|
||||
}
|
||||
}
|
||||
return discount
|
||||
}
|
||||
|
||||
func billingLine(candidate store.RuntimeModelCandidate, resourceType string, unit string, quantity any, amount float64, discount float64, simulated bool) map[string]any {
|
||||
return map[string]any{
|
||||
return billingLineWithDetails(candidate, resourceType, unit, quantity, amount, discount, simulated, nil)
|
||||
}
|
||||
|
||||
func billingLineWithDetails(candidate store.RuntimeModelCandidate, resourceType string, unit string, quantity any, amount float64, discount float64, simulated bool, details map[string]any) map[string]any {
|
||||
line := map[string]any{
|
||||
"model": candidate.ModelName,
|
||||
"modelAlias": candidate.ModelAlias,
|
||||
"provider": candidate.Provider,
|
||||
@ -133,6 +157,10 @@ func billingLine(candidate store.RuntimeModelCandidate, resourceType string, uni
|
||||
"discountFactor": discount,
|
||||
"simulated": simulated,
|
||||
}
|
||||
for key, value := range details {
|
||||
line[key] = value
|
||||
}
|
||||
return line
|
||||
}
|
||||
|
||||
func price(config map[string]any, key string) float64 {
|
||||
@ -177,7 +205,16 @@ func weighted(config map[string]any, key string, name string) float64 {
|
||||
}
|
||||
|
||||
func resourceWeight(config map[string]any, resource string, key string, name string) float64 {
|
||||
if value := weighted(config, key, name); value != 1 {
|
||||
keys := weightKeyAliases(key)
|
||||
names := weightValueAliases(key, name)
|
||||
for _, candidateKey := range keys {
|
||||
for _, candidateName := range names {
|
||||
if value := weighted(config, candidateKey, candidateName); value != 1 {
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
if value := dynamicWeight(config["dynamicWeight"], keys, names); value != 1 {
|
||||
return value
|
||||
}
|
||||
if strings.TrimSpace(name) == "" {
|
||||
@ -187,19 +224,201 @@ func resourceWeight(config map[string]any, resource string, key string, name str
|
||||
if len(resourceConfig) == 0 && resource == "image_edit" {
|
||||
resourceConfig, _ = config["image"].(map[string]any)
|
||||
}
|
||||
if weights, ok := resourceConfig["dynamicWeight"].(map[string]any); ok {
|
||||
if value := floatFromAny(weights[name]); value > 0 {
|
||||
return value
|
||||
if value := dynamicWeight(resourceConfig["dynamicWeight"], keys, names); value != 1 {
|
||||
return value
|
||||
}
|
||||
for _, candidateKey := range keys {
|
||||
if weights, ok := resourceConfig[candidateKey].(map[string]any); ok {
|
||||
for _, candidateName := range names {
|
||||
if value := floatFromAny(weights[candidateName]); value > 0 {
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if weights, ok := resourceConfig[key].(map[string]any); ok {
|
||||
if value := floatFromAny(weights[name]); value > 0 {
|
||||
return 1
|
||||
}
|
||||
|
||||
func dynamicWeight(value any, keys []string, names []string) float64 {
|
||||
if len(names) == 0 {
|
||||
return 1
|
||||
}
|
||||
weights, _ := value.(map[string]any)
|
||||
if len(weights) == 0 {
|
||||
return 1
|
||||
}
|
||||
for _, name := range names {
|
||||
if direct := floatFromAny(weights[name]); direct > 0 {
|
||||
return direct
|
||||
}
|
||||
}
|
||||
for _, key := range keys {
|
||||
if nested, ok := weights[key].(map[string]any); ok {
|
||||
for _, name := range names {
|
||||
if nestedValue := floatFromAny(nested[name]); nestedValue > 0 {
|
||||
return nestedValue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func weightKeyAliases(key string) []string {
|
||||
switch key {
|
||||
case "qualityWeights":
|
||||
return []string{"qualityWeights", "qualityFactors"}
|
||||
case "resolutionWeights":
|
||||
return []string{"resolutionWeights", "resolutionFactors"}
|
||||
case "audioWeights":
|
||||
return []string{"audioWeights", "audioFactors"}
|
||||
case "referenceVideoWeights":
|
||||
return []string{"referenceVideoWeights", "referenceVideoFactors"}
|
||||
case "voiceSpecifiedWeights":
|
||||
return []string{"voiceSpecifiedWeights", "voiceSpecifiedFactors"}
|
||||
default:
|
||||
return []string{key}
|
||||
}
|
||||
}
|
||||
|
||||
func weightValueAliases(key string, name string) []string {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
switch key {
|
||||
case "audioWeights":
|
||||
return []string{name, "audio-" + name}
|
||||
case "referenceVideoWeights":
|
||||
return []string{name, "reference-video-" + name}
|
||||
case "voiceSpecifiedWeights":
|
||||
return []string{name, "voice-specified-" + name}
|
||||
default:
|
||||
return []string{name}
|
||||
}
|
||||
}
|
||||
|
||||
func requestOutputCount(body map[string]any) int {
|
||||
for _, key := range []string{"n", "count", "batch_size", "batchSize"} {
|
||||
if value := int(math.Ceil(floatFromAny(body[key]))); value > 0 {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func requestDurationSeconds(body map[string]any) float64 {
|
||||
for _, key := range []string{"duration", "durationSeconds", "duration_seconds"} {
|
||||
if value := floatFromAny(body[key]); value > 0 {
|
||||
return value
|
||||
}
|
||||
}
|
||||
for _, value := range body {
|
||||
items, ok := value.([]any)
|
||||
if !ok || len(items) == 0 {
|
||||
continue
|
||||
}
|
||||
total := 0.0
|
||||
allDurationItems := true
|
||||
for _, item := range items {
|
||||
record, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
allDurationItems = false
|
||||
break
|
||||
}
|
||||
duration := floatFromAny(record["duration"])
|
||||
if duration <= 0 {
|
||||
allDurationItems = false
|
||||
break
|
||||
}
|
||||
total += duration
|
||||
}
|
||||
if allDurationItems && total > 0 {
|
||||
return total
|
||||
}
|
||||
}
|
||||
return 5
|
||||
}
|
||||
|
||||
func requestHasReferenceVideo(body map[string]any) bool {
|
||||
if hasNonEmptyArray(body["video_list"]) || hasNonEmptyArray(body["videoList"]) {
|
||||
return true
|
||||
}
|
||||
if firstNonEmptyStringValue(body, "video", "video_url", "videoUrl", "reference_video", "referenceVideo") != "" {
|
||||
return true
|
||||
}
|
||||
content, _ := body["content"].([]any)
|
||||
for _, item := range content {
|
||||
record, _ := item.(map[string]any)
|
||||
if len(record) == 0 {
|
||||
continue
|
||||
}
|
||||
itemType := strings.TrimSpace(stringFromAny(record["type"]))
|
||||
role := strings.TrimSpace(stringFromAny(record["role"]))
|
||||
if itemType == "video_url" || role == "video_feature" || role == "video_base" || role == "reference_video" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func requestHasVoiceID(body map[string]any) bool {
|
||||
return boolishValue(body["audio"]) && firstNonEmptyStringValue(body, "voice_id", "voiceId") != ""
|
||||
}
|
||||
|
||||
func boolWeightKey(value bool) string {
|
||||
if value {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
|
||||
func boolishValue(value any) bool {
|
||||
switch typed := value.(type) {
|
||||
case bool:
|
||||
return typed
|
||||
case string:
|
||||
switch strings.ToLower(strings.TrimSpace(typed)) {
|
||||
case "true", "1", "yes", "on":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
case int:
|
||||
return typed != 0
|
||||
case int64:
|
||||
return typed != 0
|
||||
case float64:
|
||||
return typed != 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func hasNonEmptyArray(value any) bool {
|
||||
items, ok := value.([]any)
|
||||
return ok && len(items) > 0
|
||||
}
|
||||
|
||||
func totalBillingAmount(items []any) float64 {
|
||||
total := 0.0
|
||||
for _, raw := range items {
|
||||
line, _ := raw.(map[string]any)
|
||||
total += floatFromAny(line["amount"])
|
||||
}
|
||||
return roundPrice(total)
|
||||
}
|
||||
|
||||
func billingCurrency(items []any) string {
|
||||
for _, raw := range items {
|
||||
line, _ := raw.(map[string]any)
|
||||
if currency := stringFromAny(line["currency"]); currency != "" {
|
||||
return currency
|
||||
}
|
||||
}
|
||||
return "resource"
|
||||
}
|
||||
|
||||
func firstNonEmptyString(values ...string) string {
|
||||
for _, value := range values {
|
||||
if strings.TrimSpace(value) != "" {
|
||||
|
||||
132
apps/api/internal/runner/pricing_test.go
Normal file
132
apps/api/internal/runner/pricing_test.go
Normal file
@ -0,0 +1,132 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
)
|
||||
|
||||
func TestImageBillingEstimateUsesCountResolutionAndQuality(t *testing.T) {
|
||||
service := &Service{}
|
||||
candidate := store.RuntimeModelCandidate{
|
||||
ModelName: "image-model",
|
||||
BaseBillingConfig: map[string]any{
|
||||
"image": map[string]any{
|
||||
"basePrice": 10,
|
||||
"dynamicWeight": map[string]any{
|
||||
"resolutionFactors": map[string]any{"2K": 1.5},
|
||||
"qualityFactors": map[string]any{"high": 1.5},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
items := service.billings(context.Background(), nil, "images.generations", map[string]any{
|
||||
"count": 2,
|
||||
"quality": "high",
|
||||
"resolution": "2K",
|
||||
}, candidate, clients.Response{}, true)
|
||||
|
||||
line := firstBillingLine(t, items)
|
||||
if got, want := floatFromAny(line["amount"]), 45.0; got != want {
|
||||
t.Fatalf("image estimated amount = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := line["quantity"], 2; got != want {
|
||||
t.Fatalf("image quantity = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVideoBillingEstimateUsesFiveSecondUnitsAndDynamicWeights(t *testing.T) {
|
||||
service := &Service{}
|
||||
candidate := store.RuntimeModelCandidate{
|
||||
ModelName: "video-model",
|
||||
BaseBillingConfig: map[string]any{
|
||||
"video": map[string]any{
|
||||
"basePrice": 100,
|
||||
"dynamicWeight": map[string]any{
|
||||
"resolutionWeights": map[string]any{"1080p": 1.5},
|
||||
"audioWeights": map[string]any{"true": 2},
|
||||
"referenceVideoWeights": map[string]any{"true": 1.5},
|
||||
"voiceSpecifiedWeights": map[string]any{"true": 1.2},
|
||||
"unusedCompatibilityField": map[string]any{"true": 99},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
items := service.billings(context.Background(), nil, "videos.generations", map[string]any{
|
||||
"audio": true,
|
||||
"duration": 12,
|
||||
"resolution": "1080p",
|
||||
"voice_id": "voice-a",
|
||||
"content": []any{
|
||||
map[string]any{"type": "video_url", "video_url": map[string]any{"url": "https://example.com/reference.mp4"}},
|
||||
},
|
||||
}, candidate, clients.Response{}, true)
|
||||
|
||||
line := firstBillingLine(t, items)
|
||||
if got, want := floatFromAny(line["amount"]), 1620.0; got != want {
|
||||
t.Fatalf("video estimated amount = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := floatFromAny(line["durationUnitCount"]), 3.0; got != want {
|
||||
t.Fatalf("video duration units = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := line["quantity"], 3; got != want {
|
||||
t.Fatalf("video quantity = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVideoBillingEstimateSupportsServerMainStyleDynamicKeys(t *testing.T) {
|
||||
service := &Service{}
|
||||
candidate := store.RuntimeModelCandidate{
|
||||
ModelName: "legacy-video-model",
|
||||
BaseBillingConfig: map[string]any{
|
||||
"videoBase": 100,
|
||||
"video": map[string]any{
|
||||
"dynamicWeight": map[string]any{
|
||||
"720p": 1.25,
|
||||
"audio-true": 2,
|
||||
"reference-video-true": 1.5,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
items := service.billings(context.Background(), nil, "videos.generations", map[string]any{
|
||||
"audio": "true",
|
||||
"duration": 5,
|
||||
"resolution": "720p",
|
||||
"video_list": []any{map[string]any{"url": "https://example.com/reference.mp4"}},
|
||||
}, candidate, clients.Response{}, true)
|
||||
|
||||
line := firstBillingLine(t, items)
|
||||
if got, want := floatFromAny(line["amount"]), 375.0; got != want {
|
||||
t.Fatalf("legacy video estimated amount = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVideoDurationEstimateSumsMultiShotDurations(t *testing.T) {
|
||||
duration := requestDurationSeconds(map[string]any{
|
||||
"multi_prompt": []any{
|
||||
map[string]any{"prompt": "shot 1", "duration": 3},
|
||||
map[string]any{"prompt": "shot 2", "duration": 7},
|
||||
},
|
||||
})
|
||||
if duration != 10 {
|
||||
t.Fatalf("multi-shot duration = %v, want 10", duration)
|
||||
}
|
||||
}
|
||||
|
||||
func firstBillingLine(t *testing.T, items []any) map[string]any {
|
||||
t.Helper()
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("items length = %d, want 1: %+v", len(items), items)
|
||||
}
|
||||
line, ok := items[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("item type = %T, want map[string]any", items[0])
|
||||
}
|
||||
return line
|
||||
}
|
||||
@ -31,9 +31,18 @@ func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID
|
||||
}
|
||||
if reservation.Metric == "" || reservation.Amount > reservation.Limit {
|
||||
return RateLimitResult{}, &RateLimitExceededError{
|
||||
Metric: reservation.Metric,
|
||||
Message: fmt.Sprintf("rate limit exceeded: %s request amount %.0f is greater than limit %.0f", reservation.Metric, reservation.Amount, reservation.Limit),
|
||||
Retryable: false,
|
||||
ScopeType: reservation.ScopeType,
|
||||
ScopeKey: reservation.ScopeKey,
|
||||
ScopeName: reservation.ScopeName,
|
||||
ScopeMetadata: reservation.ScopeMetadata,
|
||||
Metric: reservation.Metric,
|
||||
Limit: reservation.Limit,
|
||||
Amount: reservation.Amount,
|
||||
Projected: reservation.Amount,
|
||||
WindowSeconds: reservation.WindowSeconds,
|
||||
Policy: reservation.Policy,
|
||||
Message: fmt.Sprintf("rate limit exceeded: %s request amount %.0f is greater than limit %.0f", reservation.Metric, reservation.Amount, reservation.Limit),
|
||||
Retryable: false,
|
||||
}
|
||||
}
|
||||
if reservation.WindowSeconds <= 0 {
|
||||
@ -78,10 +87,22 @@ WHERE scope_type = $1
|
||||
}
|
||||
if active+reservation.Amount > reservation.Limit {
|
||||
return "", &RateLimitExceededError{
|
||||
Metric: reservation.Metric,
|
||||
Message: fmt.Sprintf("rate limit exceeded: concurrent active %.0f plus request %.0f is greater than limit %.0f", active, reservation.Amount, reservation.Limit),
|
||||
RetryAfter: concurrencyRetryAfter(nextAvailableAt),
|
||||
Retryable: true,
|
||||
ScopeType: reservation.ScopeType,
|
||||
ScopeKey: reservation.ScopeKey,
|
||||
ScopeName: reservation.ScopeName,
|
||||
ScopeMetadata: reservation.ScopeMetadata,
|
||||
Metric: reservation.Metric,
|
||||
Limit: reservation.Limit,
|
||||
Amount: reservation.Amount,
|
||||
Current: active,
|
||||
Used: active,
|
||||
Projected: active + reservation.Amount,
|
||||
WindowSeconds: reservation.WindowSeconds,
|
||||
ResetAt: nextAvailableAt,
|
||||
Policy: reservation.Policy,
|
||||
Message: fmt.Sprintf("rate limit exceeded: concurrent active %.0f plus request %.0f is greater than limit %.0f", active, reservation.Amount, reservation.Limit),
|
||||
RetryAfter: concurrencyRetryAfter(nextAvailableAt),
|
||||
Retryable: true,
|
||||
}
|
||||
}
|
||||
var leaseID string
|
||||
@ -135,11 +156,13 @@ RETURNING window_start`,
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
resetAt := time.Now().Add(time.Duration(reservation.WindowSeconds) * time.Second)
|
||||
currentUsed := 0.0
|
||||
currentReserved := 0.0
|
||||
_ = tx.QueryRow(ctx, `
|
||||
WITH bounds AS (
|
||||
SELECT to_timestamp(floor(extract(epoch FROM now()) / $4::int) * $4::int) AS window_start
|
||||
)
|
||||
SELECT counters.reset_at
|
||||
SELECT counters.used_value::float8, counters.reserved_value::float8, counters.reset_at
|
||||
FROM gateway_rate_limit_counters counters
|
||||
JOIN bounds ON counters.window_start = bounds.window_start
|
||||
WHERE scope_type = $1
|
||||
@ -149,12 +172,26 @@ WHERE scope_type = $1
|
||||
reservation.ScopeKey,
|
||||
reservation.Metric,
|
||||
reservation.WindowSeconds,
|
||||
).Scan(&resetAt)
|
||||
).Scan(¤tUsed, ¤tReserved, &resetAt)
|
||||
current := currentUsed + currentReserved
|
||||
return RateLimitReservation{}, &RateLimitExceededError{
|
||||
Metric: reservation.Metric,
|
||||
Message: fmt.Sprintf("rate limit exceeded: %s window has no remaining capacity", reservation.Metric),
|
||||
RetryAfter: retryAfterUntil(resetAt),
|
||||
Retryable: true,
|
||||
ScopeType: reservation.ScopeType,
|
||||
ScopeKey: reservation.ScopeKey,
|
||||
ScopeName: reservation.ScopeName,
|
||||
ScopeMetadata: reservation.ScopeMetadata,
|
||||
Metric: reservation.Metric,
|
||||
Limit: reservation.Limit,
|
||||
Amount: reservation.Amount,
|
||||
Current: current,
|
||||
Used: currentUsed,
|
||||
Reserved: currentReserved,
|
||||
Projected: current + reservation.Amount,
|
||||
WindowSeconds: reservation.WindowSeconds,
|
||||
ResetAt: resetAt,
|
||||
Policy: reservation.Policy,
|
||||
Message: fmt.Sprintf("rate limit exceeded: %s window has no remaining capacity", reservation.Metric),
|
||||
RetryAfter: retryAfterUntil(resetAt),
|
||||
Retryable: true,
|
||||
}
|
||||
}
|
||||
return RateLimitReservation{}, err
|
||||
|
||||
@ -33,10 +33,23 @@ func ModelCandidateErrorCode(err error) string {
|
||||
}
|
||||
|
||||
type RateLimitExceededError struct {
|
||||
Metric string
|
||||
Message string
|
||||
RetryAfter time.Duration
|
||||
Retryable bool
|
||||
ScopeType string
|
||||
ScopeKey string
|
||||
ScopeName string
|
||||
ScopeMetadata map[string]any
|
||||
Metric string
|
||||
Limit float64
|
||||
Amount float64
|
||||
Current float64
|
||||
Used float64
|
||||
Reserved float64
|
||||
Projected float64
|
||||
WindowSeconds int
|
||||
ResetAt time.Time
|
||||
Policy map[string]any
|
||||
Message string
|
||||
RetryAfter time.Duration
|
||||
Retryable bool
|
||||
}
|
||||
|
||||
func (e *RateLimitExceededError) Error() string {
|
||||
@ -166,12 +179,15 @@ type RateLimitReservation struct {
|
||||
ReservationID string
|
||||
ScopeType string
|
||||
ScopeKey string
|
||||
ScopeName string
|
||||
ScopeMetadata map[string]any
|
||||
Metric string
|
||||
Limit float64
|
||||
Amount float64
|
||||
WindowSeconds int
|
||||
LeaseTTLSeconds int
|
||||
WindowStart time.Time
|
||||
Policy map[string]any
|
||||
}
|
||||
|
||||
type RateLimitResult struct {
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
type UserGroupPolicy struct {
|
||||
ID string
|
||||
GroupKey string
|
||||
Name string
|
||||
RateLimitPolicy map[string]any
|
||||
BillingDiscountPolicy map[string]any
|
||||
}
|
||||
@ -23,12 +24,12 @@ func (s *Store) ResolveUserGroupPolicy(ctx context.Context, user *auth.User) (Us
|
||||
var rateLimit []byte
|
||||
var billing []byte
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT id::text, group_key, rate_limit_policy, billing_discount_policy
|
||||
SELECT id::text, group_key, name, rate_limit_policy, billing_discount_policy
|
||||
FROM gateway_user_groups
|
||||
WHERE status = 'active'
|
||||
AND (($1 <> '' AND id = NULLIF($1, '')::uuid) OR ($1 = '' AND group_key = 'default'))
|
||||
ORDER BY CASE WHEN id::text = $1 THEN 0 ELSE 1 END, priority ASC
|
||||
LIMIT 1`, userGroupID).Scan(&item.ID, &item.GroupKey, &rateLimit, &billing)
|
||||
LIMIT 1`, userGroupID).Scan(&item.ID, &item.GroupKey, &item.Name, &rateLimit, &billing)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return UserGroupPolicy{}, nil
|
||||
|
||||
@ -19,6 +19,7 @@ import type {
|
||||
GatewayTenant,
|
||||
GatewayTenantUpsertRequest,
|
||||
GatewayNetworkProxyConfig,
|
||||
GatewayPricingEstimate,
|
||||
GatewayTask,
|
||||
GatewayTaskParamPreprocessingLog,
|
||||
GatewayUser,
|
||||
@ -786,8 +787,8 @@ export async function uploadFileToStorage(
|
||||
export async function estimatePricing(
|
||||
token: string,
|
||||
input: Record<string, unknown>,
|
||||
): Promise<{ items: unknown[]; resolver: string }> {
|
||||
return request<{ items: unknown[]; resolver: string }>('/api/v1/pricing/estimate', {
|
||||
): Promise<GatewayPricingEstimate> {
|
||||
return request<GatewayPricingEstimate>('/api/v1/pricing/estimate', {
|
||||
body: input,
|
||||
method: 'POST',
|
||||
token,
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import { useEffect, useMemo, useRef, useState } from 'react';
|
||||
import type { GatewayApiKey, GatewayTask, PlatformModel } from '@easyai-ai-gateway/contracts';
|
||||
import type { GatewayApiKey, GatewayPricingEstimate, GatewayTask, PlatformModel } from '@easyai-ai-gateway/contracts';
|
||||
import { ArrowUp, ChevronDown, MessageSquarePlus, Settings2, Sparkles } from 'lucide-react';
|
||||
import { Badge, Button, FormDialog, Select, Textarea } from '../components/ui';
|
||||
import { createImageEditTask, createImageGenerationTask, createVideoGenerationTask, pollTaskUntilSettled, resolveApiAssetUrl, taskIsPending } from '../api';
|
||||
import { createImageEditTask, createImageGenerationTask, createVideoGenerationTask, estimatePricing, pollTaskUntilSettled, resolveApiAssetUrl, taskIsPending } from '../api';
|
||||
import type { PlaygroundMode } from '../types';
|
||||
import {
|
||||
PlaygroundPromptMentionInput,
|
||||
@ -57,6 +57,14 @@ import {
|
||||
const MEDIA_RUNS_STORAGE_KEY = 'easyai:playground:media-runs:v1';
|
||||
const MEDIA_RUNS_STORAGE_LIMIT = 50;
|
||||
|
||||
type MediaEstimateState = {
|
||||
amount?: number;
|
||||
currency?: string;
|
||||
error?: string;
|
||||
resolver?: string;
|
||||
status: 'idle' | 'loading' | 'ready' | 'error';
|
||||
};
|
||||
|
||||
const publicWorks = [
|
||||
{ title: '雨夜霓虹街区', type: '图像生成', image: 'https://picsum.photos/seed/easyai-neon-city/720/960' },
|
||||
{ title: '玻璃温室晨光', type: '图像生成', image: 'https://picsum.photos/seed/easyai-glasshouse/720/540' },
|
||||
@ -94,6 +102,7 @@ export function PlaygroundPage(props: {
|
||||
const [mediaMessage, setMediaMessage] = useState('');
|
||||
const [mediaUploadMessage, setMediaUploadMessage] = useState('');
|
||||
const [mediaUploads, setMediaUploads] = useState<PlaygroundUpload[]>([]);
|
||||
const [mediaEstimate, setMediaEstimate] = useState<MediaEstimateState>({ status: 'idle' });
|
||||
const [mediaUploading, setMediaUploading] = useState(false);
|
||||
const [settingsOpen, setSettingsOpen] = useState(false);
|
||||
const isMountedRef = useRef(false);
|
||||
@ -118,6 +127,13 @@ export function PlaygroundPage(props: {
|
||||
: undefined,
|
||||
[activeModelOption, mediaSettings.resolution, props.mode, videoMode],
|
||||
);
|
||||
const mediaEstimatePayload = useMemo(() => {
|
||||
if (props.mode === 'chat' || !selectedModel) return null;
|
||||
const normalizedSettings = mediaCapabilities
|
||||
? normalizeMediaSettingsForCapabilities(mediaSettings, mediaCapabilities, props.mode)
|
||||
: mediaSettings;
|
||||
return buildMediaEstimatePayload(props.mode, selectedModel, prompt, normalizedSettings, mediaUploads, videoMode);
|
||||
}, [mediaCapabilities, mediaSettings, mediaUploads, prompt, props.mode, selectedModel, videoMode]);
|
||||
|
||||
useEffect(() => {
|
||||
setSelectedModel((current) => {
|
||||
@ -155,6 +171,34 @@ export function PlaygroundPage(props: {
|
||||
setPrompt((current) => removeInvalidPlaygroundResourceTokens(current, mediaUploads));
|
||||
}, [mediaUploadSignature, props.mode]);
|
||||
|
||||
useEffect(() => {
|
||||
const credential = activeApiKeySecret || props.token;
|
||||
if (props.mode === 'chat' || !credential || !mediaEstimatePayload) {
|
||||
setMediaEstimate({ status: 'idle' });
|
||||
return;
|
||||
}
|
||||
let cancelled = false;
|
||||
setMediaEstimate((current) => ({ ...current, error: '', status: 'loading' }));
|
||||
const timer = window.setTimeout(() => {
|
||||
estimatePricing(credential, mediaEstimatePayload)
|
||||
.then((estimate) => {
|
||||
if (cancelled) return;
|
||||
setMediaEstimate(mediaEstimateFromResponse(estimate));
|
||||
})
|
||||
.catch((err) => {
|
||||
if (cancelled) return;
|
||||
setMediaEstimate({
|
||||
error: err instanceof Error ? err.message : '预计扣费计算失败',
|
||||
status: 'error',
|
||||
});
|
||||
});
|
||||
}, 260);
|
||||
return () => {
|
||||
cancelled = true;
|
||||
window.clearTimeout(timer);
|
||||
};
|
||||
}, [activeApiKeySecret, mediaEstimatePayload, props.mode, props.token]);
|
||||
|
||||
useEffect(() => {
|
||||
if (props.mode === 'image') {
|
||||
setMediaUploads((current) => current.some((item) => item.kind !== 'image') ? current.filter((item) => item.kind === 'image') : current);
|
||||
@ -404,6 +448,7 @@ export function PlaygroundPage(props: {
|
||||
selectedModel={selectedModel}
|
||||
imageHasReference={effectiveImageHasReference}
|
||||
mediaSettings={mediaSettings}
|
||||
mediaEstimate={mediaEstimate}
|
||||
mediaCapabilities={mediaCapabilities}
|
||||
uploadAccept={mediaUploadAcceptValue}
|
||||
uploadMessage={mediaUploadMessage}
|
||||
@ -631,6 +676,7 @@ function Composer(props: {
|
||||
compact?: boolean;
|
||||
imageHasReference?: boolean;
|
||||
mediaCapabilities?: MediaModelCapabilities;
|
||||
mediaEstimate?: MediaEstimateState;
|
||||
mediaSettings?: MediaGenerationSettings;
|
||||
mode: PlaygroundMode;
|
||||
modelOptions: ModelOption[];
|
||||
@ -727,10 +773,17 @@ function Composer(props: {
|
||||
onChange={props.onMediaSettingsChange}
|
||||
/>
|
||||
)}
|
||||
<span className="composerEstimatedCharge" aria-label="预计扣费 1 / 张">
|
||||
<Sparkles size={14} />
|
||||
<span>1 / 张</span>
|
||||
</span>
|
||||
{props.mode !== 'chat' && props.mediaEstimate && (
|
||||
<span
|
||||
className="composerEstimatedCharge"
|
||||
data-state={props.mediaEstimate.status}
|
||||
title={mediaEstimateHint(props.mediaEstimate)}
|
||||
aria-label={mediaEstimateAriaLabel(props.mediaEstimate)}
|
||||
>
|
||||
<Sparkles size={14} />
|
||||
<span>{mediaEstimateText(props.mediaEstimate)}</span>
|
||||
</span>
|
||||
)}
|
||||
<Button type="button" size="icon" className="composerMediaSendButton" aria-label="发送测试" onClick={props.onSubmit}>
|
||||
<ArrowUp size={24} />
|
||||
</Button>
|
||||
@ -739,6 +792,74 @@ function Composer(props: {
|
||||
);
|
||||
}
|
||||
|
||||
function buildMediaEstimatePayload(
|
||||
mode: Exclude<PlaygroundMode, 'chat'>,
|
||||
model: string,
|
||||
prompt: string,
|
||||
settings: MediaGenerationSettings,
|
||||
uploads: PlaygroundUpload[],
|
||||
videoMode: VideoCreateMode,
|
||||
): Record<string, unknown> {
|
||||
const requestPrompt = replacePlaygroundResourceTokens(prompt.trim(), uploads, mode);
|
||||
if (mode === 'video') {
|
||||
return {
|
||||
kind: 'videos.generations',
|
||||
model,
|
||||
content: sharedVideoGenerationContentFromPromptAndUploads(requestPrompt, uploads, videoMode),
|
||||
...mediaRequestPayload(settings, 'video'),
|
||||
};
|
||||
}
|
||||
|
||||
const uploadPayload = sharedMediaUploadRequestPayload(uploads, 'image');
|
||||
return {
|
||||
kind: uploads.some((item) => item.kind === 'image') ? 'images.edits' : 'images.generations',
|
||||
model,
|
||||
prompt: requestPrompt,
|
||||
...mediaRequestPayload(settings, 'image'),
|
||||
...uploadPayload,
|
||||
};
|
||||
}
|
||||
|
||||
function mediaEstimateFromResponse(response: GatewayPricingEstimate): MediaEstimateState {
|
||||
return {
|
||||
amount: numericFromUnknown(response.totalAmount) ?? estimateItemsTotal(response.items),
|
||||
currency: response.currency || estimateItemsCurrency(response.items),
|
||||
resolver: response.resolver,
|
||||
status: 'ready',
|
||||
};
|
||||
}
|
||||
|
||||
function estimateItemsTotal(items: GatewayPricingEstimate['items']) {
|
||||
const total = items.reduce((sum, item) => sum + (numericFromUnknown(item.amount) ?? 0), 0);
|
||||
return Math.round(total * 1_000_000) / 1_000_000;
|
||||
}
|
||||
|
||||
function estimateItemsCurrency(items: GatewayPricingEstimate['items']) {
|
||||
return items.find((item) => stringFromUnknown(item.currency))?.currency || 'resource';
|
||||
}
|
||||
|
||||
function mediaEstimateText(estimate: MediaEstimateState) {
|
||||
if (estimate.amount === undefined) return '--';
|
||||
return formatEstimateAmount(estimate.amount);
|
||||
}
|
||||
|
||||
function mediaEstimateHint(estimate: MediaEstimateState) {
|
||||
if (estimate.status === 'error' && estimate.error) return `预计扣费计算失败:${estimate.error}`;
|
||||
return '扣费为预计,实际扣费以账单为准';
|
||||
}
|
||||
|
||||
function mediaEstimateAriaLabel(estimate: MediaEstimateState) {
|
||||
if (estimate.status === 'error') return estimate.error ? `预计扣费计算失败:${estimate.error}` : '预计扣费计算失败';
|
||||
if (estimate.amount === undefined) return '预计扣费估算中';
|
||||
return `预计扣费 ${formatEstimateAmount(estimate.amount)} ${estimate.currency || 'resource'}`;
|
||||
}
|
||||
|
||||
function formatEstimateAmount(value: number) {
|
||||
if (!Number.isFinite(value)) return '--';
|
||||
const digits = Math.abs(value) > 0 && Math.abs(value) < 0.01 ? 6 : 2;
|
||||
return value.toFixed(digits).replace(/\.?0+$/, '');
|
||||
}
|
||||
|
||||
function mediaPromptPlaceholder(mode: PlaygroundMode) {
|
||||
if (mode === 'image') return '输入画面描述,可用 @ 或 @资产 快速引用图片资源,例如:让 @图像 1 保持人物一致...';
|
||||
if (mode === 'video') return '输入镜头、运动和风格,可用 @ 或 @资产 引用图片、视频或音频资源...';
|
||||
|
||||
@ -78,8 +78,10 @@ type UserGroupForm = {
|
||||
description: string;
|
||||
source: string;
|
||||
priority: string;
|
||||
rechargeDiscountPolicyJson: string;
|
||||
billingDiscountPolicyJson: string;
|
||||
rechargeDiscountFactor: string;
|
||||
rechargeDiscountPolicy: Record<string, unknown>;
|
||||
billingDiscountFactor: string;
|
||||
billingDiscountPolicy: Record<string, unknown>;
|
||||
rateLimitPolicyJson: string;
|
||||
quotaPolicyJson: string;
|
||||
metadataJson: string;
|
||||
@ -516,8 +518,8 @@ export function UserGroupsPanel(props: IdentityPanelProps) {
|
||||
<Label>状态<Select size="sm" value={form.status} onChange={(event) => setForm({ ...form, status: event.target.value })}>{userGroupStatuses.map(option)}</Select></Label>
|
||||
<Label>优先级<Input size="sm" value={form.priority} inputMode="numeric" onChange={(event) => setForm({ ...form, priority: event.target.value })} /></Label>
|
||||
<Label className="spanTwo">描述<Input size="sm" value={form.description} onChange={(event) => setForm({ ...form, description: event.target.value })} /></Label>
|
||||
<JsonField label="充值折扣策略 JSON" value={form.rechargeDiscountPolicyJson} onChange={(value) => setForm({ ...form, rechargeDiscountPolicyJson: value })} />
|
||||
<JsonField label="计费折扣策略 JSON" value={form.billingDiscountPolicyJson} onChange={(value) => setForm({ ...form, billingDiscountPolicyJson: value })} />
|
||||
<Label>充值折扣系数<Input size="sm" value={form.rechargeDiscountFactor} inputMode="decimal" placeholder="1 = 不打折,0.95 = 95 折" onChange={(event) => setForm({ ...form, rechargeDiscountFactor: event.target.value })} /></Label>
|
||||
<Label>计费折扣系数<Input size="sm" value={form.billingDiscountFactor} inputMode="decimal" placeholder="1 = 不打折,0.95 = 95 折" onChange={(event) => setForm({ ...form, billingDiscountFactor: event.target.value })} /></Label>
|
||||
<JsonField label="限流策略 JSON" value={form.rateLimitPolicyJson} onChange={(value) => setForm({ ...form, rateLimitPolicyJson: value })} />
|
||||
<JsonField label="额度策略 JSON" value={form.quotaPolicyJson} onChange={(value) => setForm({ ...form, quotaPolicyJson: value })} />
|
||||
<JsonField label="元数据 JSON" value={form.metadataJson} onChange={(value) => setForm({ ...form, metadataJson: value })} />
|
||||
@ -769,8 +771,10 @@ function defaultUserGroupForm(): UserGroupForm {
|
||||
description: '',
|
||||
source: 'gateway',
|
||||
priority: '100',
|
||||
rechargeDiscountPolicyJson: '{}',
|
||||
billingDiscountPolicyJson: '{}',
|
||||
rechargeDiscountFactor: '1',
|
||||
rechargeDiscountPolicy: {},
|
||||
billingDiscountFactor: '1',
|
||||
billingDiscountPolicy: {},
|
||||
rateLimitPolicyJson: '{"rules":[]}',
|
||||
quotaPolicyJson: '{}',
|
||||
metadataJson: '{}',
|
||||
@ -785,8 +789,10 @@ function userGroupToForm(group: UserGroup): UserGroupForm {
|
||||
description: group.description ?? '',
|
||||
source: group.source,
|
||||
priority: String(group.priority),
|
||||
rechargeDiscountPolicyJson: stringifyJson(group.rechargeDiscountPolicy),
|
||||
billingDiscountPolicyJson: stringifyJson(group.billingDiscountPolicy),
|
||||
rechargeDiscountFactor: discountFactorText(group.rechargeDiscountPolicy),
|
||||
rechargeDiscountPolicy: group.rechargeDiscountPolicy ?? {},
|
||||
billingDiscountFactor: discountFactorText(group.billingDiscountPolicy),
|
||||
billingDiscountPolicy: group.billingDiscountPolicy ?? {},
|
||||
rateLimitPolicyJson: stringifyJson(group.rateLimitPolicy),
|
||||
quotaPolicyJson: stringifyJson(group.quotaPolicy),
|
||||
metadataJson: stringifyJson(group.metadata),
|
||||
@ -801,8 +807,8 @@ function formToUserGroupPayload(form: UserGroupForm): UserGroupUpsertRequest {
|
||||
description: form.description.trim() || undefined,
|
||||
source: form.source,
|
||||
priority: Number(form.priority) || 100,
|
||||
rechargeDiscountPolicy: parseJsonObject(form.rechargeDiscountPolicyJson, '充值折扣策略 JSON'),
|
||||
billingDiscountPolicy: parseJsonObject(form.billingDiscountPolicyJson, '计费折扣策略 JSON'),
|
||||
rechargeDiscountPolicy: discountPolicyPayload(form.rechargeDiscountPolicy, form.rechargeDiscountFactor, '充值折扣系数'),
|
||||
billingDiscountPolicy: discountPolicyPayload(form.billingDiscountPolicy, form.billingDiscountFactor, '计费折扣系数'),
|
||||
rateLimitPolicy: parseJsonObject(form.rateLimitPolicyJson, '限流策略 JSON'),
|
||||
quotaPolicy: parseJsonObject(form.quotaPolicyJson, '额度策略 JSON'),
|
||||
metadata: parseJsonObject(form.metadataJson, '元数据 JSON'),
|
||||
@ -854,14 +860,57 @@ function newIdempotencyKey() {
|
||||
}
|
||||
|
||||
function discountSummary(group: UserGroup) {
|
||||
const billing = group.billingDiscountPolicy?.discountFactor ?? group.billingDiscountPolicy?.factor;
|
||||
const recharge = group.rechargeDiscountPolicy?.discountFactor ?? group.rechargeDiscountPolicy?.factor;
|
||||
const billing = discountFactorFromPolicy(group.billingDiscountPolicy);
|
||||
const recharge = discountFactorFromPolicy(group.rechargeDiscountPolicy);
|
||||
const parts = [];
|
||||
if (billing) parts.push(`计费 ${billing}`);
|
||||
if (recharge) parts.push(`充值 ${recharge}`);
|
||||
if (billing) parts.push(`计费 ${trimNumber(billing)}`);
|
||||
if (recharge) parts.push(`充值 ${trimNumber(recharge)}`);
|
||||
return parts.join(' / ') || '未设置';
|
||||
}
|
||||
|
||||
function discountFactorText(policy?: Record<string, unknown>) {
|
||||
const value = discountFactorFromPolicy(policy);
|
||||
return value ? trimNumber(value) : '1';
|
||||
}
|
||||
|
||||
function discountFactorFromPolicy(policy?: Record<string, unknown>) {
|
||||
return numberFromUnknown(policy?.discountFactor) ?? numberFromUnknown(policy?.factor);
|
||||
}
|
||||
|
||||
function discountPolicyPayload(basePolicy: Record<string, unknown>, discountText: string, label: string) {
|
||||
const policy = { ...basePolicy };
|
||||
delete policy.discountFactor;
|
||||
delete policy.factor;
|
||||
const discount = optionalPositiveNumber(discountText, label);
|
||||
if (discount && discount !== 1) {
|
||||
policy.discountFactor = discount;
|
||||
}
|
||||
return Object.keys(policy).length ? policy : undefined;
|
||||
}
|
||||
|
||||
function optionalPositiveNumber(value: string, label: string) {
|
||||
const text = value.trim();
|
||||
if (!text) return undefined;
|
||||
const parsed = Number(text);
|
||||
if (!Number.isFinite(parsed) || parsed <= 0) {
|
||||
throw new Error(`${label} 必须是大于 0 的数字`);
|
||||
}
|
||||
return parsed;
|
||||
}
|
||||
|
||||
function numberFromUnknown(value: unknown) {
|
||||
if (typeof value === 'number' && Number.isFinite(value) && value > 0) return value;
|
||||
if (typeof value === 'string' && value.trim()) {
|
||||
const parsed = Number(value);
|
||||
if (Number.isFinite(parsed) && parsed > 0) return parsed;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function trimNumber(value: number) {
|
||||
return value.toFixed(6).replace(/\.?0+$/, '');
|
||||
}
|
||||
|
||||
function policyKeys(value?: Record<string, unknown>) {
|
||||
if (!value) return [];
|
||||
return Object.keys(value).slice(0, 3);
|
||||
|
||||
@ -630,6 +630,15 @@
|
||||
color: #526170;
|
||||
}
|
||||
|
||||
.composerEstimatedCharge[data-state="loading"] {
|
||||
color: #6f7c8a;
|
||||
}
|
||||
|
||||
.composerEstimatedCharge[data-state="error"],
|
||||
.composerEstimatedCharge[data-state="error"] svg {
|
||||
color: #b45309;
|
||||
}
|
||||
|
||||
.composerMediaSendButton.shButton {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
|
||||
@ -506,6 +506,32 @@ export interface PlayableGatewayApiKey extends GatewayApiKey {
|
||||
secret: string;
|
||||
}
|
||||
|
||||
export interface GatewayPricingEstimateItem {
|
||||
amount?: number;
|
||||
currency?: string;
|
||||
discountFactor?: number;
|
||||
durationSeconds?: number;
|
||||
durationUnit?: string;
|
||||
durationUnitCount?: number;
|
||||
model?: string;
|
||||
modelAlias?: string;
|
||||
platformId?: string;
|
||||
platformModelId?: string;
|
||||
provider?: string;
|
||||
quantity?: number | string;
|
||||
resourceType?: string;
|
||||
simulated?: boolean;
|
||||
unit?: string;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export interface GatewayPricingEstimate {
|
||||
items: GatewayPricingEstimateItem[];
|
||||
resolver: string;
|
||||
totalAmount?: number;
|
||||
currency?: string;
|
||||
}
|
||||
|
||||
export interface GatewayWalletAccount {
|
||||
id: string;
|
||||
gatewayTenantId?: string;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user