149 lines
4.8 KiB
Go
149 lines
4.8 KiB
Go
package runner
|
|
|
|
import (
|
|
"context"
|
|
"math"
|
|
"strings"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
)
|
|
|
|
type EstimateResult struct {
|
|
Items []any `json:"items"`
|
|
Resolver string `json:"resolver"`
|
|
}
|
|
|
|
func (s *Service) Estimate(ctx context.Context, kind string, model string, body map[string]any, user *auth.User) (EstimateResult, error) {
|
|
candidates, err := s.store.ListModelCandidates(ctx, model, modelTypeFromKind(kind))
|
|
if err != nil {
|
|
return EstimateResult{}, err
|
|
}
|
|
candidate := candidates[0]
|
|
usage := clients.Usage{InputTokens: estimateRequestTokens(body), OutputTokens: int(floatFromAny(body["max_tokens"]))}
|
|
if usage.OutputTokens == 0 {
|
|
usage.OutputTokens = 64
|
|
}
|
|
usage.TotalTokens = usage.InputTokens + usage.OutputTokens
|
|
response := clients.Response{Usage: usage, Result: map[string]any{"usage": map[string]any{
|
|
"prompt_tokens": usage.InputTokens,
|
|
"completion_tokens": usage.OutputTokens,
|
|
"total_tokens": usage.TotalTokens,
|
|
}}}
|
|
return EstimateResult{
|
|
Items: s.billings(ctx, user, kind, body, candidate, response, true),
|
|
Resolver: "effective-pricing-v1",
|
|
}, nil
|
|
}
|
|
|
|
func (s *Service) billings(ctx context.Context, user *auth.User, kind string, body map[string]any, candidate store.RuntimeModelCandidate, response clients.Response, simulated bool) []any {
|
|
config := effectiveBillingConfig(candidate)
|
|
discount := effectiveDiscount(ctx, s.store, user, candidate)
|
|
if modelTypeFromKind(kind) == "chat" {
|
|
inputTokens := response.Usage.InputTokens
|
|
outputTokens := response.Usage.OutputTokens
|
|
if inputTokens == 0 && outputTokens == 0 {
|
|
inputTokens = estimateRequestTokens(body)
|
|
outputTokens = 1
|
|
}
|
|
inputAmount := roundPrice(float64(inputTokens) / 1000 * price(config, "textInputPer1k") * discount)
|
|
outputAmount := roundPrice(float64(outputTokens) / 1000 * price(config, "textOutputPer1k") * discount)
|
|
return []any{
|
|
billingLine(candidate, "text_input", "1k_tokens", inputTokens, inputAmount, discount, simulated),
|
|
billingLine(candidate, "text_output", "1k_tokens", outputTokens, outputAmount, discount, simulated),
|
|
}
|
|
}
|
|
count := int(floatFromAny(body["n"]))
|
|
if count <= 0 {
|
|
count = 1
|
|
}
|
|
resource := "image"
|
|
baseKey := "imageBase"
|
|
if kind == "images.edits" {
|
|
resource = "image_edit"
|
|
baseKey = "editBase"
|
|
}
|
|
amount := float64(count) * price(config, baseKey) * weighted(config, "qualityWeights", stringFromMap(body, "quality")) * weighted(config, "sizeWeights", stringFromMap(body, "size")) * discount
|
|
return []any{billingLine(candidate, resource, "image", count, roundPrice(amount), discount, simulated)}
|
|
}
|
|
|
|
func effectiveBillingConfig(candidate store.RuntimeModelCandidate) map[string]any {
|
|
base := candidate.BaseBillingConfig
|
|
if len(candidate.BillingConfig) > 0 {
|
|
base = candidate.BillingConfig
|
|
}
|
|
if len(candidate.BillingConfigOverride) > 0 {
|
|
base = mergeMap(base, candidate.BillingConfigOverride)
|
|
}
|
|
return base
|
|
}
|
|
|
|
func effectiveDiscount(ctx context.Context, db *store.Store, user *auth.User, candidate store.RuntimeModelCandidate) float64 {
|
|
discount := candidate.DefaultDiscountFactor
|
|
if candidate.DiscountFactor > 0 {
|
|
discount = candidate.DiscountFactor
|
|
}
|
|
if discount <= 0 {
|
|
discount = 1
|
|
}
|
|
if group, err := db.ResolveUserGroupPolicy(ctx, user); err == nil {
|
|
groupDiscount := floatFromAny(group.BillingDiscountPolicy["discountFactor"])
|
|
if groupDiscount > 0 {
|
|
discount *= groupDiscount
|
|
}
|
|
}
|
|
return discount
|
|
}
|
|
|
|
func billingLine(candidate store.RuntimeModelCandidate, resourceType string, unit string, quantity any, amount float64, discount float64, simulated bool) map[string]any {
|
|
return map[string]any{
|
|
"model": candidate.ModelName,
|
|
"modelAlias": candidate.ModelAlias,
|
|
"provider": candidate.Provider,
|
|
"platformId": candidate.PlatformID,
|
|
"platformModelId": candidate.PlatformModelID,
|
|
"resourceType": resourceType,
|
|
"unit": unit,
|
|
"quantity": quantity,
|
|
"amount": amount,
|
|
"currency": "resource",
|
|
"discountFactor": discount,
|
|
"simulated": simulated,
|
|
}
|
|
}
|
|
|
|
func price(config map[string]any, key string) float64 {
|
|
value := floatFromAny(config[key])
|
|
if value > 0 {
|
|
return value
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func weighted(config map[string]any, key string, name string) float64 {
|
|
if strings.TrimSpace(name) == "" {
|
|
return 1
|
|
}
|
|
weights, _ := config[key].(map[string]any)
|
|
if value := floatFromAny(weights[name]); value > 0 {
|
|
return value
|
|
}
|
|
return 1
|
|
}
|
|
|
|
func roundPrice(value float64) float64 {
|
|
return math.Round(value*1000000) / 1000000
|
|
}
|
|
|
|
func mergeMap(base map[string]any, override map[string]any) map[string]any {
|
|
out := map[string]any{}
|
|
for key, value := range base {
|
|
out[key] = value
|
|
}
|
|
for key, value := range override {
|
|
out[key] = value
|
|
}
|
|
return out
|
|
}
|