diff --git a/apps/api/internal/store/candidates.go b/apps/api/internal/store/candidates.go index 1ec4e28..27342ea 100644 --- a/apps/api/internal/store/candidates.go +++ b/apps/api/internal/store/candidates.go @@ -5,11 +5,14 @@ import ( "fmt" "sort" "strings" + "unicode" "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" ) func (s *Store) ListModelCandidates(ctx context.Context, model string, modelType string, user *auth.User) ([]RuntimeModelCandidate, error) { + exactModel := strings.TrimSpace(model) + modelMatchKey := normalizeModelMatchKey(exactModel) rows, err := s.pool.Query(ctx, ` SELECT p.id::text, p.platform_key, p.name, p.provider, COALESCE(NULLIF(p.config->>'specType', ''), NULLIF(cp.provider_type, ''), NULLIF(p.config->>'sourceSpecType', ''), p.provider) AS spec_type, @@ -101,13 +104,30 @@ WHERE p.status = 'enabled' AND (p.cooldown_until IS NULL OR p.cooldown_until <= now()) AND (m.cooldown_until IS NULL OR m.cooldown_until <= now()) AND ( - (COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text) + ( + COALESCE(m.model_alias, '') <> '' + AND ( + m.model_alias = $1::text + OR ( + NULLIF($3::text, '') IS NOT NULL + AND regexp_replace(COALESCE(m.model_alias, ''), '[[:space:]]+', '', 'g') = $3::text + ) + ) + ) OR ( COALESCE(m.model_alias, '') = '' AND ( m.model_name = $1::text OR b.canonical_model_key = $1::text OR b.provider_model_name = $1::text + OR ( + NULLIF($3::text, '') IS NOT NULL + AND ( + regexp_replace(COALESCE(m.model_name, ''), '[[:space:]]+', '', 'g') = $3::text + OR regexp_replace(COALESCE(b.canonical_model_key, ''), '[[:space:]]+', '', 'g') = $3::text + OR regexp_replace(COALESCE(b.provider_model_name, ''), '[[:space:]]+', '', 'g') = $3::text + ) + ) ) ) ) @@ -115,7 +135,7 @@ ORDER BY effective_priority ASC, COALESCE(s.running_count, 0) ASC, COALESCE(s.waiting_count, 0) ASC, COALESCE(s.last_assigned_at, to_timestamp(0)) ASC, - m.created_at ASC`, model, modelType) + m.created_at ASC`, exactModel, modelType, modelMatchKey) if err != nil { return nil, err } @@ -358,6 +378,8 @@ func runtimeCandidateFull(candidate RuntimeModelCandidate) bool { } func (s *Store) modelCandidateCooldownError(ctx context.Context, model string, modelType string) (error, error) { + exactModel := strings.TrimSpace(model) + modelMatchKey := normalizeModelMatchKey(exactModel) rows, err := s.pool.Query(ctx, ` SELECT p.name, COALESCE(NULLIF(m.display_name, ''), NULLIF(m.model_alias, ''), m.model_name), @@ -373,19 +395,36 @@ WHERE p.status = 'enabled' AND m.enabled = true AND m.model_type @> jsonb_build_array($2::text) AND ( - (COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text) + ( + COALESCE(m.model_alias, '') <> '' + AND ( + m.model_alias = $1::text + OR ( + NULLIF($3::text, '') IS NOT NULL + AND regexp_replace(COALESCE(m.model_alias, ''), '[[:space:]]+', '', 'g') = $3::text + ) + ) + ) OR ( COALESCE(m.model_alias, '') = '' AND ( m.model_name = $1::text OR b.canonical_model_key = $1::text OR b.provider_model_name = $1::text + OR ( + NULLIF($3::text, '') IS NOT NULL + AND ( + regexp_replace(COALESCE(m.model_name, ''), '[[:space:]]+', '', 'g') = $3::text + OR regexp_replace(COALESCE(b.canonical_model_key, ''), '[[:space:]]+', '', 'g') = $3::text + OR regexp_replace(COALESCE(b.provider_model_name, ''), '[[:space:]]+', '', 'g') = $3::text + ) + ) ) ) ) ORDER BY GREATEST(COALESCE(p.cooldown_until, to_timestamp(0)), COALESCE(m.cooldown_until, to_timestamp(0))) DESC, p.priority ASC, - m.created_at ASC`, model, modelType) + m.created_at ASC`, exactModel, modelType, modelMatchKey) if err != nil { return nil, err } @@ -442,3 +481,19 @@ func cooldownErrorMessage(scope string, name string, remainingSeconds float64, c } return message } + +func normalizeModelMatchKey(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + var builder strings.Builder + builder.Grow(len(value)) + for _, char := range value { + if unicode.IsSpace(char) { + continue + } + builder.WriteRune(char) + } + return builder.String() +} diff --git a/apps/api/internal/store/candidates_test.go b/apps/api/internal/store/candidates_test.go index 5b41f7c..6c9e54a 100644 --- a/apps/api/internal/store/candidates_test.go +++ b/apps/api/internal/store/candidates_test.go @@ -2,6 +2,52 @@ package store import "testing" +func TestNormalizeModelMatchKeyRemovesWhitespace(t *testing.T) { + got := normalizeModelMatchKey(" doubao-5.0 图像\t编辑\n") + if got != "doubao-5.0图像编辑" { + t.Fatalf("expected whitespace-insensitive model key, got %q", got) + } +} + +func TestTaskBillingModelIdentityPrefersSystemAlias(t *testing.T) { + identity := taskBillingModelIdentity(GatewayTask{ + Model: "doubao-5.0 图像编辑", + RequestedModel: "doubao-5.0 图像编辑", + ResolvedModel: "doubao-image-real", + Metrics: map[string]any{ + "modelAlias": "doubao-5.0图像编辑", + "modelName": "doubao-image-real", + "providerModel": "doubao-provider-image", + }, + }) + + if identity.Model != "doubao-5.0图像编辑" || identity.ResolvedModel != "doubao-5.0图像编辑" { + t.Fatalf("expected persisted model to use system alias, got %+v", identity) + } + if identity.RequestedModel != "doubao-5.0 图像编辑" { + t.Fatalf("expected requested model to preserve original request, got %+v", identity) + } + if identity.ModelName != "doubao-image-real" || identity.ProviderModelName != "doubao-provider-image" { + t.Fatalf("expected model name/provider model to stay available, got %+v", identity) + } +} + +func TestTaskBillingModelIdentityFallsBackToBillingLines(t *testing.T) { + identity := taskBillingModelIdentity(GatewayTask{ + Model: "front end alias", + Billings: []any{ + map[string]any{ + "model": "system-model-name", + "modelAlias": "System Model Alias", + }, + }, + }) + + if identity.Model != "System Model Alias" || identity.ModelName != "system-model-name" { + t.Fatalf("expected billing lines to provide system model identity, got %+v", identity) + } +} + func TestRuntimeCandidateLoadUsesMaxLimitedMetric(t *testing.T) { candidate := RuntimeModelCandidate{} applyRuntimeCandidateLoad(&candidate, runtimeCandidateLoadInput{ diff --git a/apps/api/internal/store/tasks_runtime.go b/apps/api/internal/store/tasks_runtime.go index 16bcaad..423d271 100644 --- a/apps/api/internal/store/tasks_runtime.go +++ b/apps/api/internal/store/tasks_runtime.go @@ -687,11 +687,16 @@ func (s *Store) SettleTaskBilling(ctx context.Context, task GatewayTask) error { if currency == "" || currency == "mixed" { currency = "resource" } + modelIdentity := taskBillingModelIdentity(task) metadataMap := map[string]any{ "taskId": task.ID, "kind": task.Kind, - "model": task.Model, - "resolvedModel": task.ResolvedModel, + "model": modelIdentity.Model, + "requestedModel": modelIdentity.RequestedModel, + "resolvedModel": modelIdentity.ResolvedModel, + "modelName": modelIdentity.ModelName, + "modelAlias": modelIdentity.ModelAlias, + "providerModel": modelIdentity.ProviderModelName, "billings": task.Billings, "billingSummary": task.BillingSummary, } @@ -814,7 +819,7 @@ func roundMoney(value float64) float64 { func taskBillingString(value any) string { if text, ok := value.(string); ok { - return text + return strings.TrimSpace(text) } return "" } diff --git a/apps/api/internal/store/wallet.go b/apps/api/internal/store/wallet.go index 6feb148..e3cadff 100644 --- a/apps/api/internal/store/wallet.go +++ b/apps/api/internal/store/wallet.go @@ -203,14 +203,27 @@ SET frozen_balance = $2, WHERE id = $1::uuid`, locked.ID, frozenAfter); err != nil { return err } + modelIdentity := taskBillingModelIdentity(GatewayTask{ + Kind: task.Kind, + Model: task.Model, + RequestedModel: task.RequestedModel, + ResolvedModel: task.ResolvedModel, + Billings: billings, + Metrics: task.Metrics, + }) metadata, _ := json.Marshal(map[string]any{ - "taskId": taskID, - "kind": task.Kind, - "model": task.Model, - "reserved": amount, - "balance": roundMoney(locked.Balance), - "frozenBefore": roundMoney(locked.FrozenBalance), - "frozenAfter": frozenAfter, + "taskId": taskID, + "kind": task.Kind, + "model": modelIdentity.Model, + "requestedModel": modelIdentity.RequestedModel, + "resolvedModel": modelIdentity.ResolvedModel, + "modelName": modelIdentity.ModelName, + "modelAlias": modelIdentity.ModelAlias, + "providerModel": modelIdentity.ProviderModelName, + "reserved": amount, + "balance": roundMoney(locked.Balance), + "frozenBefore": roundMoney(locked.FrozenBalance), + "frozenAfter": frozenAfter, }) if _, err := tx.Exec(ctx, ` INSERT INTO gateway_wallet_transactions ( @@ -484,9 +497,11 @@ SELECT t.id::text, t.account_id::text, a.currency, COALESCE(t.gateway_tenant_id: t.metadata || jsonb_strip_nulls(jsonb_build_object( 'taskId', task.id::text, 'kind', task.kind, - 'model', task.model, + 'model', COALESCE(NULLIF(platform_model.model_alias, ''), NULLIF(task.metrics->>'modelAlias', ''), NULLIF(task.resolved_model, ''), NULLIF(platform_model.model_name, ''), NULLIF(task.metrics->>'modelName', ''), task.model), 'requestedModel', task.requested_model, - 'resolvedModel', task.resolved_model, + 'resolvedModel', COALESCE(NULLIF(platform_model.model_alias, ''), NULLIF(task.metrics->>'modelAlias', ''), NULLIF(task.resolved_model, ''), NULLIF(platform_model.model_name, ''), NULLIF(task.metrics->>'modelName', '')), + 'modelName', COALESCE(NULLIF(platform_model.model_name, ''), NULLIF(task.metrics->>'modelName', ''), NULLIF(task.resolved_model, '')), + 'modelAlias', COALESCE(NULLIF(platform_model.model_alias, ''), NULLIF(task.metrics->>'modelAlias', '')), 'modelType', task.model_type, 'taskStatus', task.status, 'runMode', task.run_mode, @@ -811,6 +826,54 @@ func walletFloat(value any) float64 { } } +type billingModelIdentity struct { + Model string + RequestedModel string + ResolvedModel string + ModelName string + ModelAlias string + ProviderModelName string +} + +func taskBillingModelIdentity(task GatewayTask) billingModelIdentity { + modelAlias := firstNonEmpty( + taskBillingString(task.Metrics["modelAlias"]), + firstBillingLineString(task.Billings, "modelAlias"), + ) + modelName := firstNonEmpty( + taskBillingString(task.Metrics["modelName"]), + taskBillingString(task.Metrics["resolvedModel"]), + task.ResolvedModel, + firstBillingLineString(task.Billings, "model"), + ) + providerModelName := firstNonEmpty( + taskBillingString(task.Metrics["providerModel"]), + firstBillingLineString(task.Billings, "providerModel"), + ) + systemModel := firstNonEmpty(modelAlias, modelName, task.ResolvedModel, task.Model) + return billingModelIdentity{ + Model: systemModel, + RequestedModel: firstNonEmpty(task.RequestedModel, task.Model), + ResolvedModel: systemModel, + ModelName: modelName, + ModelAlias: modelAlias, + ProviderModelName: providerModelName, + } +} + +func firstBillingLineString(billings []any, key string) string { + for _, raw := range billings { + line, _ := raw.(map[string]any) + if line == nil { + continue + } + if value := strings.TrimSpace(taskBillingString(line[key])); value != "" { + return value + } + } + return "" +} + func normalizeWalletCurrency(currency string) string { currency = strings.TrimSpace(currency) if currency == "" {