fix(gateway): normalize model alias billing identity
This commit is contained in:
parent
f5c69b9852
commit
2aeb47d6a5
@ -5,11 +5,14 @@ import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||
)
|
||||
|
||||
func (s *Store) ListModelCandidates(ctx context.Context, model string, modelType string, user *auth.User) ([]RuntimeModelCandidate, error) {
|
||||
exactModel := strings.TrimSpace(model)
|
||||
modelMatchKey := normalizeModelMatchKey(exactModel)
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT p.id::text, p.platform_key, p.name, p.provider,
|
||||
COALESCE(NULLIF(p.config->>'specType', ''), NULLIF(cp.provider_type, ''), NULLIF(p.config->>'sourceSpecType', ''), p.provider) AS spec_type,
|
||||
@ -101,13 +104,30 @@ WHERE p.status = 'enabled'
|
||||
AND (p.cooldown_until IS NULL OR p.cooldown_until <= now())
|
||||
AND (m.cooldown_until IS NULL OR m.cooldown_until <= now())
|
||||
AND (
|
||||
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
|
||||
(
|
||||
COALESCE(m.model_alias, '') <> ''
|
||||
AND (
|
||||
m.model_alias = $1::text
|
||||
OR (
|
||||
NULLIF($3::text, '') IS NOT NULL
|
||||
AND regexp_replace(COALESCE(m.model_alias, ''), '[[:space:]]+', '', 'g') = $3::text
|
||||
)
|
||||
)
|
||||
)
|
||||
OR (
|
||||
COALESCE(m.model_alias, '') = ''
|
||||
AND (
|
||||
m.model_name = $1::text
|
||||
OR b.canonical_model_key = $1::text
|
||||
OR b.provider_model_name = $1::text
|
||||
OR (
|
||||
NULLIF($3::text, '') IS NOT NULL
|
||||
AND (
|
||||
regexp_replace(COALESCE(m.model_name, ''), '[[:space:]]+', '', 'g') = $3::text
|
||||
OR regexp_replace(COALESCE(b.canonical_model_key, ''), '[[:space:]]+', '', 'g') = $3::text
|
||||
OR regexp_replace(COALESCE(b.provider_model_name, ''), '[[:space:]]+', '', 'g') = $3::text
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -115,7 +135,7 @@ ORDER BY effective_priority ASC,
|
||||
COALESCE(s.running_count, 0) ASC,
|
||||
COALESCE(s.waiting_count, 0) ASC,
|
||||
COALESCE(s.last_assigned_at, to_timestamp(0)) ASC,
|
||||
m.created_at ASC`, model, modelType)
|
||||
m.created_at ASC`, exactModel, modelType, modelMatchKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -358,6 +378,8 @@ func runtimeCandidateFull(candidate RuntimeModelCandidate) bool {
|
||||
}
|
||||
|
||||
func (s *Store) modelCandidateCooldownError(ctx context.Context, model string, modelType string) (error, error) {
|
||||
exactModel := strings.TrimSpace(model)
|
||||
modelMatchKey := normalizeModelMatchKey(exactModel)
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT p.name,
|
||||
COALESCE(NULLIF(m.display_name, ''), NULLIF(m.model_alias, ''), m.model_name),
|
||||
@ -373,19 +395,36 @@ WHERE p.status = 'enabled'
|
||||
AND m.enabled = true
|
||||
AND m.model_type @> jsonb_build_array($2::text)
|
||||
AND (
|
||||
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
|
||||
(
|
||||
COALESCE(m.model_alias, '') <> ''
|
||||
AND (
|
||||
m.model_alias = $1::text
|
||||
OR (
|
||||
NULLIF($3::text, '') IS NOT NULL
|
||||
AND regexp_replace(COALESCE(m.model_alias, ''), '[[:space:]]+', '', 'g') = $3::text
|
||||
)
|
||||
)
|
||||
)
|
||||
OR (
|
||||
COALESCE(m.model_alias, '') = ''
|
||||
AND (
|
||||
m.model_name = $1::text
|
||||
OR b.canonical_model_key = $1::text
|
||||
OR b.provider_model_name = $1::text
|
||||
OR (
|
||||
NULLIF($3::text, '') IS NOT NULL
|
||||
AND (
|
||||
regexp_replace(COALESCE(m.model_name, ''), '[[:space:]]+', '', 'g') = $3::text
|
||||
OR regexp_replace(COALESCE(b.canonical_model_key, ''), '[[:space:]]+', '', 'g') = $3::text
|
||||
OR regexp_replace(COALESCE(b.provider_model_name, ''), '[[:space:]]+', '', 'g') = $3::text
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
ORDER BY GREATEST(COALESCE(p.cooldown_until, to_timestamp(0)), COALESCE(m.cooldown_until, to_timestamp(0))) DESC,
|
||||
p.priority ASC,
|
||||
m.created_at ASC`, model, modelType)
|
||||
m.created_at ASC`, exactModel, modelType, modelMatchKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -442,3 +481,19 @@ func cooldownErrorMessage(scope string, name string, remainingSeconds float64, c
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func normalizeModelMatchKey(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
var builder strings.Builder
|
||||
builder.Grow(len(value))
|
||||
for _, char := range value {
|
||||
if unicode.IsSpace(char) {
|
||||
continue
|
||||
}
|
||||
builder.WriteRune(char)
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
@ -2,6 +2,52 @@ package store
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeModelMatchKeyRemovesWhitespace(t *testing.T) {
|
||||
got := normalizeModelMatchKey(" doubao-5.0 图像\t编辑\n")
|
||||
if got != "doubao-5.0图像编辑" {
|
||||
t.Fatalf("expected whitespace-insensitive model key, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskBillingModelIdentityPrefersSystemAlias(t *testing.T) {
|
||||
identity := taskBillingModelIdentity(GatewayTask{
|
||||
Model: "doubao-5.0 图像编辑",
|
||||
RequestedModel: "doubao-5.0 图像编辑",
|
||||
ResolvedModel: "doubao-image-real",
|
||||
Metrics: map[string]any{
|
||||
"modelAlias": "doubao-5.0图像编辑",
|
||||
"modelName": "doubao-image-real",
|
||||
"providerModel": "doubao-provider-image",
|
||||
},
|
||||
})
|
||||
|
||||
if identity.Model != "doubao-5.0图像编辑" || identity.ResolvedModel != "doubao-5.0图像编辑" {
|
||||
t.Fatalf("expected persisted model to use system alias, got %+v", identity)
|
||||
}
|
||||
if identity.RequestedModel != "doubao-5.0 图像编辑" {
|
||||
t.Fatalf("expected requested model to preserve original request, got %+v", identity)
|
||||
}
|
||||
if identity.ModelName != "doubao-image-real" || identity.ProviderModelName != "doubao-provider-image" {
|
||||
t.Fatalf("expected model name/provider model to stay available, got %+v", identity)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskBillingModelIdentityFallsBackToBillingLines(t *testing.T) {
|
||||
identity := taskBillingModelIdentity(GatewayTask{
|
||||
Model: "front end alias",
|
||||
Billings: []any{
|
||||
map[string]any{
|
||||
"model": "system-model-name",
|
||||
"modelAlias": "System Model Alias",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if identity.Model != "System Model Alias" || identity.ModelName != "system-model-name" {
|
||||
t.Fatalf("expected billing lines to provide system model identity, got %+v", identity)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeCandidateLoadUsesMaxLimitedMetric(t *testing.T) {
|
||||
candidate := RuntimeModelCandidate{}
|
||||
applyRuntimeCandidateLoad(&candidate, runtimeCandidateLoadInput{
|
||||
|
||||
@ -687,11 +687,16 @@ func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error {
|
||||
if currency == "" || currency == "mixed" {
|
||||
currency = "resource"
|
||||
}
|
||||
modelIdentity := taskBillingModelIdentity(task)
|
||||
metadataMap := map[string]any{
|
||||
"taskId": task.ID,
|
||||
"kind": task.Kind,
|
||||
"model": task.Model,
|
||||
"resolvedModel": task.ResolvedModel,
|
||||
"model": modelIdentity.Model,
|
||||
"requestedModel": modelIdentity.RequestedModel,
|
||||
"resolvedModel": modelIdentity.ResolvedModel,
|
||||
"modelName": modelIdentity.ModelName,
|
||||
"modelAlias": modelIdentity.ModelAlias,
|
||||
"providerModel": modelIdentity.ProviderModelName,
|
||||
"billings": task.Billings,
|
||||
"billingSummary": task.BillingSummary,
|
||||
}
|
||||
@ -814,7 +819,7 @@ func roundMoney(value float64) float64 {
|
||||
|
||||
func taskBillingString(value any) string {
|
||||
if text, ok := value.(string); ok {
|
||||
return text
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@ -203,14 +203,27 @@ SET frozen_balance = $2,
|
||||
WHERE id = $1::uuid`, locked.ID, frozenAfter); err != nil {
|
||||
return err
|
||||
}
|
||||
modelIdentity := taskBillingModelIdentity(GatewayTask{
|
||||
Kind: task.Kind,
|
||||
Model: task.Model,
|
||||
RequestedModel: task.RequestedModel,
|
||||
ResolvedModel: task.ResolvedModel,
|
||||
Billings: billings,
|
||||
Metrics: task.Metrics,
|
||||
})
|
||||
metadata, _ := json.Marshal(map[string]any{
|
||||
"taskId": taskID,
|
||||
"kind": task.Kind,
|
||||
"model": task.Model,
|
||||
"reserved": amount,
|
||||
"balance": roundMoney(locked.Balance),
|
||||
"frozenBefore": roundMoney(locked.FrozenBalance),
|
||||
"frozenAfter": frozenAfter,
|
||||
"taskId": taskID,
|
||||
"kind": task.Kind,
|
||||
"model": modelIdentity.Model,
|
||||
"requestedModel": modelIdentity.RequestedModel,
|
||||
"resolvedModel": modelIdentity.ResolvedModel,
|
||||
"modelName": modelIdentity.ModelName,
|
||||
"modelAlias": modelIdentity.ModelAlias,
|
||||
"providerModel": modelIdentity.ProviderModelName,
|
||||
"reserved": amount,
|
||||
"balance": roundMoney(locked.Balance),
|
||||
"frozenBefore": roundMoney(locked.FrozenBalance),
|
||||
"frozenAfter": frozenAfter,
|
||||
})
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO gateway_wallet_transactions (
|
||||
@ -484,9 +497,11 @@ SELECT t.id::text, t.account_id::text, a.currency, COALESCE(t.gateway_tenant_id:
|
||||
t.metadata || jsonb_strip_nulls(jsonb_build_object(
|
||||
'taskId', task.id::text,
|
||||
'kind', task.kind,
|
||||
'model', task.model,
|
||||
'model', COALESCE(NULLIF(platform_model.model_alias, ''), NULLIF(task.metrics->>'modelAlias', ''), NULLIF(task.resolved_model, ''), NULLIF(platform_model.model_name, ''), NULLIF(task.metrics->>'modelName', ''), task.model),
|
||||
'requestedModel', task.requested_model,
|
||||
'resolvedModel', task.resolved_model,
|
||||
'resolvedModel', COALESCE(NULLIF(platform_model.model_alias, ''), NULLIF(task.metrics->>'modelAlias', ''), NULLIF(task.resolved_model, ''), NULLIF(platform_model.model_name, ''), NULLIF(task.metrics->>'modelName', '')),
|
||||
'modelName', COALESCE(NULLIF(platform_model.model_name, ''), NULLIF(task.metrics->>'modelName', ''), NULLIF(task.resolved_model, '')),
|
||||
'modelAlias', COALESCE(NULLIF(platform_model.model_alias, ''), NULLIF(task.metrics->>'modelAlias', '')),
|
||||
'modelType', task.model_type,
|
||||
'taskStatus', task.status,
|
||||
'runMode', task.run_mode,
|
||||
@ -811,6 +826,54 @@ func walletFloat(value any) float64 {
|
||||
}
|
||||
}
|
||||
|
||||
type billingModelIdentity struct {
|
||||
Model string
|
||||
RequestedModel string
|
||||
ResolvedModel string
|
||||
ModelName string
|
||||
ModelAlias string
|
||||
ProviderModelName string
|
||||
}
|
||||
|
||||
func taskBillingModelIdentity(task GatewayTask) billingModelIdentity {
|
||||
modelAlias := firstNonEmpty(
|
||||
taskBillingString(task.Metrics["modelAlias"]),
|
||||
firstBillingLineString(task.Billings, "modelAlias"),
|
||||
)
|
||||
modelName := firstNonEmpty(
|
||||
taskBillingString(task.Metrics["modelName"]),
|
||||
taskBillingString(task.Metrics["resolvedModel"]),
|
||||
task.ResolvedModel,
|
||||
firstBillingLineString(task.Billings, "model"),
|
||||
)
|
||||
providerModelName := firstNonEmpty(
|
||||
taskBillingString(task.Metrics["providerModel"]),
|
||||
firstBillingLineString(task.Billings, "providerModel"),
|
||||
)
|
||||
systemModel := firstNonEmpty(modelAlias, modelName, task.ResolvedModel, task.Model)
|
||||
return billingModelIdentity{
|
||||
Model: systemModel,
|
||||
RequestedModel: firstNonEmpty(task.RequestedModel, task.Model),
|
||||
ResolvedModel: systemModel,
|
||||
ModelName: modelName,
|
||||
ModelAlias: modelAlias,
|
||||
ProviderModelName: providerModelName,
|
||||
}
|
||||
}
|
||||
|
||||
func firstBillingLineString(billings []any, key string) string {
|
||||
for _, raw := range billings {
|
||||
line, _ := raw.(map[string]any)
|
||||
if line == nil {
|
||||
continue
|
||||
}
|
||||
if value := strings.TrimSpace(taskBillingString(line[key])); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func normalizeWalletCurrency(currency string) string {
|
||||
currency = strings.TrimSpace(currency)
|
||||
if currency == "" {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user