easyai-ai-gateway/apps/api/internal/runner/pricing.go

445 lines
13 KiB
Go

package runner
import (
	"context"
	"fmt"
	"math"
	"strings"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
	"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
	"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// EstimateResult is the price estimate returned by Service.Estimate.
// Field order matters: encoding/json emits fields in declaration order.
type EstimateResult struct {
	// Items holds the billing lines (map[string]any values built by
	// billingLine / billingLineWithDetails) for the estimated request.
	Items []any `json:"items"`
	// Resolver identifies the pricing resolver used; Estimate sets
	// "effective-pricing-v1".
	Resolver string `json:"resolver"`
	// TotalAmount is the rounded sum of every line's "amount".
	TotalAmount float64 `json:"totalAmount"`
	// Currency is the first non-empty "currency" found among Items
	// (Estimate falls back to "resource").
	Currency string `json:"currency"`
}
// Estimate prices a request without executing it.
//
// It normalizes body with normalizeRequest and lists the model candidates
// visible to user. It takes the first candidate returned by the store
// (presumably the preferred one — confirm against ListModelCandidates) and
// applies preprocessRequest. It then prices simulated usage via
// estimatedBillings.
//
// It returns the store's error if the lookup fails. It also returns an error
// when no candidate is available for model, instead of panicking on an
// empty candidate list.
func (s *Service) Estimate(ctx context.Context, kind string, model string, body map[string]any, user *auth.User) (EstimateResult, error) {
	body = normalizeRequest(kind, body)
	candidates, err := s.store.ListModelCandidates(ctx, model, modelTypeFromKind(kind, body), user)
	if err != nil {
		return EstimateResult{}, err
	}
	// An empty result with a nil error would otherwise panic on candidates[0].
	if len(candidates) == 0 {
		return EstimateResult{}, fmt.Errorf("no model candidates available for %q", model)
	}
	candidate := candidates[0]
	body = preprocessRequest(kind, body, candidate)
	items := s.estimatedBillings(ctx, user, kind, body, candidate)
	return EstimateResult{
		Items:       items,
		Resolver:    "effective-pricing-v1",
		TotalAmount: totalBillingAmount(items),
		Currency:    billingCurrency(items),
	}, nil
}
func (s *Service) estimatedBillings(ctx context.Context, user *auth.User, kind string, body map[string]any, candidate store.RuntimeModelCandidate) []any {
usage := clients.Usage{InputTokens: estimateRequestTokens(body), OutputTokens: int(floatFromAny(body["max_tokens"]))}
if usage.OutputTokens == 0 {
usage.OutputTokens = 64
}
usage.TotalTokens = usage.InputTokens + usage.OutputTokens
response := clients.Response{Usage: usage, Result: map[string]any{"usage": map[string]any{
"prompt_tokens": usage.InputTokens,
"completion_tokens": usage.OutputTokens,
"total_tokens": usage.TotalTokens,
}}}
return s.billings(ctx, user, kind, body, candidate, response, true)
}
func (s *Service) billings(ctx context.Context, user *auth.User, kind string, body map[string]any, candidate store.RuntimeModelCandidate, response clients.Response, simulated bool) []any {
config := s.effectiveBillingConfig(ctx, candidate)
discount := effectiveDiscount(ctx, s.store, user, candidate)
if isTextGenerationKind(kind) {
inputTokens := response.Usage.InputTokens
outputTokens := response.Usage.OutputTokens
if inputTokens == 0 && outputTokens == 0 {
inputTokens = estimateRequestTokens(body)
outputTokens = 1
}
inputAmount := roundPrice(float64(inputTokens) / 1000 * resourcePrice(config, "text", "textInputPer1k", "inputTokenPrice", "basePrice") * discount)
outputAmount := roundPrice(float64(outputTokens) / 1000 * resourcePrice(config, "text", "textOutputPer1k", "outputTokenPrice", "basePrice") * discount)
return []any{
billingLine(candidate, "text_input", "1k_tokens", inputTokens, inputAmount, discount, simulated),
billingLine(candidate, "text_output", "1k_tokens", outputTokens, outputAmount, discount, simulated),
}
}
count := requestOutputCount(body)
resource := "image"
unit := "image"
baseKey := "imageBase"
if kind == "images.edits" {
resource = "image_edit"
baseKey = "editBase"
}
if kind == "videos.generations" {
resource = "video"
unit = "5s_video"
baseKey = "videoBase"
duration := requestDurationSeconds(body)
durationUnits := math.Max(1, math.Ceil(duration/5))
amount := float64(count) *
durationUnits *
resourcePrice(config, resource, baseKey, "basePrice") *
resourceWeight(config, resource, "resolutionWeights", firstNonEmptyString(stringFromMap(body, "resolution"), stringFromMap(body, "size"))) *
resourceWeight(config, resource, "audioWeights", boolWeightKey(boolishValue(body["audio"]))) *
resourceWeight(config, resource, "referenceVideoWeights", boolWeightKey(requestHasReferenceVideo(body))) *
resourceWeight(config, resource, "voiceSpecifiedWeights", boolWeightKey(requestHasVoiceID(body))) *
discount
return []any{billingLineWithDetails(candidate, resource, unit, count*int(durationUnits), roundPrice(amount), discount, simulated, map[string]any{
"count": count,
"durationSeconds": duration,
"durationUnit": "5s",
"durationUnitCount": durationUnits,
})}
}
amount := float64(count) * resourcePrice(config, resource, baseKey, "basePrice") * resourceWeight(config, resource, "qualityWeights", stringFromMap(body, "quality")) * resourceWeight(config, resource, "sizeWeights", stringFromMap(body, "size")) * resourceWeight(config, resource, "resolutionWeights", firstNonEmptyString(stringFromMap(body, "resolution"), stringFromMap(body, "size"))) * discount
return []any{billingLine(candidate, resource, unit, count, roundPrice(amount), discount, simulated)}
}
// effectiveBillingConfig resolves the billing config for candidate by
// layering sources from lowest to highest precedence. Each later source that
// is present replaces the result wholesale:
//
//  1. candidate.BaseBillingConfig
//  2. the rule set named by BasePricingRuleSetID, else PlatformPricingRuleSetID
//  3. candidate.BillingConfig
//  4. the rule set named by ModelPricingRuleSetID
//
// A non-empty BillingConfigOverride is then shallow-merged on top via
// mergeMap. Rule-set lookups that fail or return an empty config are
// ignored, so the previous layer is kept.
func (s *Service) effectiveBillingConfig(ctx context.Context, candidate store.RuntimeModelCandidate) map[string]any {
	fromRuleSet := func(id string) (map[string]any, bool) {
		if id == "" {
			return nil, false
		}
		cfg, err := s.store.PricingRuleSetBillingConfig(ctx, id)
		if err != nil || len(cfg) == 0 {
			// Best effort: a missing or unreadable rule set keeps the lower layer.
			return nil, false
		}
		return cfg, true
	}

	resolved := candidate.BaseBillingConfig
	if cfg, ok := fromRuleSet(firstNonEmptyString(candidate.BasePricingRuleSetID, candidate.PlatformPricingRuleSetID)); ok {
		resolved = cfg
	}
	if len(candidate.BillingConfig) > 0 {
		resolved = candidate.BillingConfig
	}
	if cfg, ok := fromRuleSet(candidate.ModelPricingRuleSetID); ok {
		resolved = cfg
	}
	if len(candidate.BillingConfigOverride) == 0 {
		return resolved
	}
	return mergeMap(resolved, candidate.BillingConfigOverride)
}
// effectiveDiscount returns the price multiplier for candidate and user.
//
// The model-level factor is candidate.DiscountFactor when it is positive,
// otherwise candidate.DefaultDiscountFactor, replaced by 1 when that is not
// positive. When db is non-nil and the user's group policy resolves without
// error, a positive BillingDiscountPolicy["discountFactor"] is multiplied
// in. Group lookup errors are ignored.
func effectiveDiscount(ctx context.Context, db *store.Store, user *auth.User, candidate store.RuntimeModelCandidate) float64 {
	factor := candidate.DefaultDiscountFactor
	if modelFactor := candidate.DiscountFactor; modelFactor > 0 {
		factor = modelFactor
	}
	if factor <= 0 {
		factor = 1
	}

	if db == nil {
		return factor
	}
	group, err := db.ResolveUserGroupPolicy(ctx, user)
	if err != nil {
		return factor
	}
	if groupFactor := floatFromAny(group.BillingDiscountPolicy["discountFactor"]); groupFactor > 0 {
		factor *= groupFactor
	}
	return factor
}
// billingLine builds a billing line with no extra detail fields. See
// billingLineWithDetails for the keys it contains.
func billingLine(candidate store.RuntimeModelCandidate, resourceType string, unit string, quantity any, amount float64, discount float64, simulated bool) map[string]any {
	var noDetails map[string]any
	return billingLineWithDetails(candidate, resourceType, unit, quantity, amount, discount, simulated, noDetails)
}
// billingLineWithDetails builds one billing line as a map.
//
// The fixed keys describe the candidate (model, modelAlias, provider,
// platformId, platformModelId) and the charge (resourceType, unit, quantity,
// amount, currency, discountFactor, simulated). currency is always
// "resource". Entries from details are copied in last, so they overwrite
// any fixed key with the same name.
func billingLineWithDetails(candidate store.RuntimeModelCandidate, resourceType string, unit string, quantity any, amount float64, discount float64, simulated bool, details map[string]any) map[string]any {
	const fixedKeys = 12
	entry := make(map[string]any, fixedKeys+len(details))

	entry["model"] = candidate.ModelName
	entry["modelAlias"] = candidate.ModelAlias
	entry["provider"] = candidate.Provider
	entry["platformId"] = candidate.PlatformID
	entry["platformModelId"] = candidate.PlatformModelID

	entry["resourceType"] = resourceType
	entry["unit"] = unit
	entry["quantity"] = quantity
	entry["amount"] = amount
	entry["currency"] = "resource"
	entry["discountFactor"] = discount
	entry["simulated"] = simulated

	for k, v := range details {
		entry[k] = v
	}
	return entry
}
// price returns floatFromAny(config[key]) when it is strictly positive.
// Any other result (missing key, zero, negative) yields 0.
func price(config map[string]any, key string) float64 {
	if amount := floatFromAny(config[key]); amount > 0 {
		return amount
	}
	return 0
}
// resourcePrice looks up a unit price and returns the first strictly
// positive value it finds, trying in order:
//
//  1. each of keys at the top level of config;
//  2. each of keys inside config[resource] (when it is a map[string]any),
//     then that map's "basePrice";
//  3. for "image_edit" only, the same lookup again against "image".
//
// It returns 0 when no price is found.
func resourcePrice(config map[string]any, resource string, keys ...string) float64 {
	for {
		for _, k := range keys {
			if amount := price(config, k); amount > 0 {
				return amount
			}
		}
		if section, ok := config[resource].(map[string]any); ok {
			for _, k := range keys {
				if amount := floatFromAny(section[k]); amount > 0 {
					return amount
				}
			}
			if amount := floatFromAny(section["basePrice"]); amount > 0 {
				return amount
			}
		}
		if resource != "image_edit" {
			return 0
		}
		// Image edits fall back to plain image pricing (runs at most once).
		resource = "image"
	}
}
func weighted(config map[string]any, key string, name string) float64 {
if strings.TrimSpace(name) == "" {
return 1
}
weights, _ := config[key].(map[string]any)
if value := floatFromAny(weights[name]); value > 0 {
return value
}
return 1
}
// resourceWeight returns the multiplier for option name in the weight
// table key (e.g. "qualityWeights") for resource. Every key alias (see
// weightKeyAliases) and value alias (see weightValueAliases) is tried.
// Lookup order, first hit wins:
//
//  1. top-level static tables config[keyAlias][nameAlias];
//  2. the top-level config["dynamicWeight"] table;
//  3. the resource-level config[resource]["dynamicWeight"] table. "image_edit"
//     uses the "image" section when its own section is missing or empty;
//  4. resource-level static tables config[resource][keyAlias][nameAlias].
//
// It returns 1 when nothing matches or when name is blank.
//
// NOTE(review): steps 1-3 treat a result of exactly 1 as "not found".
// An explicitly configured weight of 1 at a higher-precedence level
// therefore does not stop the search, and a different lower-precedence
// value can win. Step 4 returns its first positive value, including 1.
// Confirm this is intended.
func resourceWeight(config map[string]any, resource string, key string, name string) float64 {
	keys := weightKeyAliases(key)
	names := weightValueAliases(key, name)
	// Step 1: top-level static tables. weighted returns 1 on a miss, so any
	// other value is treated as a hit.
	for _, candidateKey := range keys {
		for _, candidateName := range names {
			if value := weighted(config, candidateKey, candidateName); value != 1 {
				return value
			}
		}
	}
	// Step 2: top-level dynamicWeight table (1 means no match).
	if value := dynamicWeight(config["dynamicWeight"], keys, names); value != 1 {
		return value
	}
	// A blank name has no value aliases, so nothing below could match.
	if strings.TrimSpace(name) == "" {
		return 1
	}
	resourceConfig, _ := config[resource].(map[string]any)
	// image_edit reuses the image section when it has no non-empty section of its own.
	if len(resourceConfig) == 0 && resource == "image_edit" {
		resourceConfig, _ = config["image"].(map[string]any)
	}
	// Step 3: resource-level dynamicWeight table.
	if value := dynamicWeight(resourceConfig["dynamicWeight"], keys, names); value != 1 {
		return value
	}
	// Step 4: resource-level static tables. Unlike the steps above, any
	// positive value (including 1) is returned.
	for _, candidateKey := range keys {
		if weights, ok := resourceConfig[candidateKey].(map[string]any); ok {
			for _, candidateName := range names {
				if value := floatFromAny(weights[candidateName]); value > 0 {
					return value
				}
			}
		}
	}
	return 1
}
// dynamicWeight searches a dynamic weight table and returns the first
// strictly positive weight. value must be a map[string]any to be searched.
// Flat entries keyed directly by one of names are tried first. Then, for
// each key in keys, the nested table value[key] is searched by names. It
// returns 1 when names is empty, the table is missing or empty, or nothing
// positive is found.
func dynamicWeight(value any, keys []string, names []string) float64 {
	table, _ := value.(map[string]any)
	if len(names) == 0 || len(table) == 0 {
		return 1
	}
	// lookup returns the first positive weight in m under any of names.
	lookup := func(m map[string]any) (float64, bool) {
		for _, n := range names {
			if w := floatFromAny(m[n]); w > 0 {
				return w, true
			}
		}
		return 0, false
	}
	if w, found := lookup(table); found {
		return w
	}
	for _, k := range keys {
		nested, isMap := table[k].(map[string]any)
		if !isMap {
			continue
		}
		if w, found := lookup(nested); found {
			return w
		}
	}
	return 1
}
// weightKeyAliases lists the config keys that may hold the weight table
// named key. The five known weight tables also accept a "...Factors"
// spelling (e.g. "qualityWeights" also matches "qualityFactors"). Any other
// key maps only to itself.
func weightKeyAliases(key string) []string {
	switch key {
	case "qualityWeights", "resolutionWeights", "audioWeights", "referenceVideoWeights", "voiceSpecifiedWeights":
		return []string{key, strings.TrimSuffix(key, "Weights") + "Factors"}
	}
	return []string{key}
}
func weightValueAliases(key string, name string) []string {
name = strings.TrimSpace(name)
if name == "" {
return nil
}
switch key {
case "audioWeights":
return []string{name, "audio-" + name}
case "referenceVideoWeights":
return []string{name, "reference-video-" + name}
case "voiceSpecifiedWeights":
return []string{name, "voice-specified-" + name}
default:
return []string{name}
}
}
// requestOutputCount returns how many outputs the request asks for. It uses
// the first of "n", "count", "batch_size", "batchSize" whose value, rounded
// up, is positive, and defaults to 1.
func requestOutputCount(body map[string]any) int {
	countKeys := [...]string{"n", "count", "batch_size", "batchSize"}
	for i := 0; i < len(countKeys); i++ {
		requested := math.Ceil(floatFromAny(body[countKeys[i]]))
		if n := int(requested); n > 0 {
			return n
		}
	}
	return 1
}
// requestDurationSeconds returns the requested output duration in seconds.
//
// It prefers the first positive "duration", "durationSeconds" or
// "duration_seconds" field. Otherwise it looks for an array field whose
// items are all objects with a positive "duration" and returns their sum.
// When several array fields qualify, the one with the lexicographically
// smallest key is used, so the result does not depend on Go's randomized
// map iteration order. It defaults to 5 seconds.
func requestDurationSeconds(body map[string]any) float64 {
	for _, key := range []string{"duration", "durationSeconds", "duration_seconds"} {
		if value := floatFromAny(body[key]); value > 0 {
			return value
		}
	}

	// summedDuration totals the "duration" of value's items. It reports
	// false unless value is a non-empty []any whose items are all maps with
	// a positive "duration".
	summedDuration := func(value any) (float64, bool) {
		items, ok := value.([]any)
		if !ok || len(items) == 0 {
			return 0, false
		}
		total := 0.0
		for _, item := range items {
			record, ok := item.(map[string]any)
			if !ok {
				return 0, false
			}
			duration := floatFromAny(record["duration"])
			if duration <= 0 {
				return 0, false
			}
			total += duration
		}
		return total, total > 0
	}

	// Map iteration order is random; pick the smallest qualifying key so
	// the billed duration is deterministic.
	bestKey, bestTotal, found := "", 0.0, false
	for key, value := range body {
		if found && key >= bestKey {
			continue
		}
		if total, ok := summedDuration(value); ok {
			bestKey, bestTotal, found = key, total, true
		}
	}
	if found {
		return bestTotal
	}
	return 5
}
// requestHasReferenceVideo reports whether the request carries a reference
// video. That is the case for any of:
//   - a non-empty video_list / videoList array;
//   - a non-empty video, video_url, videoUrl, reference_video or
//     referenceVideo field (per firstNonEmptyStringValue);
//   - a content part whose trimmed type is "video_url" or whose trimmed role
//     is "video_feature", "video_base" or "reference_video".
func requestHasReferenceVideo(body map[string]any) bool {
	switch {
	case hasNonEmptyArray(body["video_list"]), hasNonEmptyArray(body["videoList"]):
		return true
	case firstNonEmptyStringValue(body, "video", "video_url", "videoUrl", "reference_video", "referenceVideo") != "":
		return true
	}
	parts, _ := body["content"].([]any)
	for _, part := range parts {
		record, _ := part.(map[string]any)
		if len(record) == 0 {
			continue
		}
		if strings.TrimSpace(stringFromAny(record["type"])) == "video_url" {
			return true
		}
		switch strings.TrimSpace(stringFromAny(record["role"])) {
		case "video_feature", "video_base", "reference_video":
			return true
		}
	}
	return false
}
// requestHasVoiceID reports whether audio is enabled (per boolishValue on
// body["audio"]) and a non-empty voice_id / voiceId is supplied.
func requestHasVoiceID(body map[string]any) bool {
	if !boolishValue(body["audio"]) {
		return false
	}
	voice := firstNonEmptyStringValue(body, "voice_id", "voiceId")
	return voice != ""
}
// boolWeightKey renders a flag as the weight-table entry name "true" or "false".
func boolWeightKey(value bool) string {
	key := "false"
	if value {
		key = "true"
	}
	return key
}
// boolishValue interprets a loosely typed request value as a boolean.
//
//   - bool values are returned as-is.
//   - strings are true for "true", "1", "yes" or "on", compared
//     case-insensitively after trimming surrounding whitespace.
//   - any built-in Go integer or floating-point type is true when non-zero.
//   - everything else, including nil, is false.
//
// Each numeric type needs its own case: in a multi-type case the switch
// variable stays an interface, and comparing it with 0 would always fail.
func boolishValue(value any) bool {
	switch typed := value.(type) {
	case bool:
		return typed
	case string:
		switch strings.ToLower(strings.TrimSpace(typed)) {
		case "true", "1", "yes", "on":
			return true
		default:
			return false
		}
	case int:
		return typed != 0
	case int8:
		return typed != 0
	case int16:
		return typed != 0
	case int32:
		return typed != 0
	case int64:
		return typed != 0
	case uint:
		return typed != 0
	case uint8:
		return typed != 0
	case uint16:
		return typed != 0
	case uint32:
		return typed != 0
	case uint64:
		return typed != 0
	case float32:
		return typed != 0
	case float64:
		return typed != 0
	default:
		return false
	}
}
// hasNonEmptyArray reports whether value is a []any with at least one
// element. Other slice types (e.g. []string) are not recognized.
func hasNonEmptyArray(value any) bool {
	if items, isSlice := value.([]any); isSlice {
		return len(items) != 0
	}
	return false
}
// totalBillingAmount sums the "amount" of every billing line in items and
// rounds the total with roundPrice. Items that are not map[string]any are
// read as an empty line (floatFromAny(nil)).
func totalBillingAmount(items []any) float64 {
	var sum float64
	for i := range items {
		entry, _ := items[i].(map[string]any)
		sum += floatFromAny(entry["amount"])
	}
	return roundPrice(sum)
}
// billingCurrency returns the first non-empty "currency" among the billing
// lines in items, or "resource" when none has one. Items that are not
// map[string]any are read as an empty line.
func billingCurrency(items []any) string {
	for i := range items {
		entry, _ := items[i].(map[string]any)
		currency := stringFromAny(entry["currency"])
		if currency == "" {
			continue
		}
		return currency
	}
	return "resource"
}
// firstNonEmptyString returns the first value that is not blank, with
// surrounding whitespace trimmed. It returns "" when every value is blank
// or none is given.
func firstNonEmptyString(values ...string) string {
	for _, v := range values {
		if trimmed := strings.TrimSpace(v); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
func roundPrice(value float64) float64 {
return math.Round(value*1000000) / 1000000
}
// mergeMap returns a new map holding base's entries overlaid with override's
// (override wins on key conflicts). The merge is shallow, neither input is
// modified, and the result is never nil.
func mergeMap(base map[string]any, override map[string]any) map[string]any {
	merged := make(map[string]any, len(base)+len(override))
	for _, layer := range []map[string]any{base, override} {
		for k, v := range layer {
			merged[k] = v
		}
	}
	return merged
}