package store

import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
)

// ListModelCandidates returns the platform models that can currently serve
// model for modelType, ordered by routing preference.
//
// A platform model matches when its platform is enabled and not deleted, the
// model is enabled, its model_type array contains modelType, neither the
// platform nor the model is inside a cooldown window, and either its alias
// equals model or, when it has no alias, its model name, base canonical key or
// base provider model name equals model. Each candidate is annotated with load
// readings (active concurrency leases, queued async tasks, rpm/tpm counters and
// runtime client state) and then filtered by the caller's access rules.
//
// When no row matches it returns a *ModelCandidateUnavailableError if an active
// cooldown explains the miss, otherwise ErrNoModelCandidate. It also returns
// ErrNoModelCandidate when the access rules remove every candidate.
func (s *Store) ListModelCandidates(ctx context.Context, model string, modelType string, user *auth.User) ([]RuntimeModelCandidate, error) {
	// Join overview:
	//   cp     - provider catalog entry (at most one) used to resolve spec_type.
	//   rp     - runtime policy set of the model, falling back to its base model.
	//   s      - runtime client state keyed by "<platform_key>:<modelType>:<upstream model>".
	//   con    - sum of unreleased, unexpired concurrency leases per platform model.
	//   queued - distinct queued async tasks per platform model: tasks without
	//            attempts are attributed via their queue_key, tasks with attempts
	//            via the platform model of their latest attempt.
	//   rpm    - most recent unexpired rpm counter window.
	//   tpm    - sum of all unexpired counters whose metric starts with "tpm".
	rows, err := s.pool.Query(ctx, `
		SELECT
			p.id::text,
			p.platform_key,
			p.name,
			p.provider,
			COALESCE(NULLIF(p.config->>'specType', ''), NULLIF(cp.provider_type, ''), NULLIF(p.config->>'sourceSpecType', ''), p.provider) AS spec_type,
			COALESCE(p.base_url, ''),
			p.auth_type,
			p.credentials,
			p.config,
			p.default_pricing_mode,
			p.default_discount_factor::float8,
			COALESCE(p.pricing_rule_set_id::text, ''),
			p.retry_policy,
			p.rate_limit_policy,
			COALESCE(p.dynamic_priority, p.priority) AS effective_priority,
			m.id::text,
			COALESCE(m.base_model_id::text, ''),
			COALESCE(b.canonical_model_key, ''),
			COALESCE(NULLIF(m.provider_model_name, ''), m.model_name),
			m.model_name,
			COALESCE(m.model_alias, ''),
			$2::text AS requested_model_type,
			m.display_name,
			m.capabilities,
			m.capability_override,
			COALESCE(b.base_billing_config, '{}'::jsonb),
			m.billing_config,
			m.billing_config_override,
			m.pricing_mode,
			COALESCE(m.discount_factor, 0)::float8,
			COALESCE(m.pricing_rule_set_id::text, ''),
			COALESCE(b.pricing_rule_set_id::text, ''),
			m.permission_config,
			m.retry_policy,
			m.rate_limit_policy,
			COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')),
			COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb),
			COALESCE(rp.retry_policy, '{}'::jsonb),
			COALESCE(rp.rate_limit_policy, '{}'::jsonb),
			COALESCE(rp.auto_disable_policy, '{}'::jsonb),
			COALESCE(rp.degrade_policy, '{}'::jsonb),
			COALESCE(con.active, 0)::float8,
			COALESCE(queued.waiting, 0)::float8,
			COALESCE(rpm.used_value, 0)::float8,
			COALESCE(rpm.reserved_value, 0)::float8,
			COALESCE(tpm.used_value, 0)::float8,
			COALESCE(tpm.reserved_value, 0)::float8,
			COALESCE(s.running_count, 0)::float8,
			COALESCE(s.waiting_count, 0)::float8,
			COALESCE(s.limiter_ratio, 0)::float8,
			COALESCE(EXTRACT(EPOCH FROM s.last_assigned_at), 0)::float8
		FROM platform_models m
		JOIN integration_platforms p ON p.id = m.platform_id
		LEFT JOIN LATERAL (
			SELECT provider_type
			FROM model_catalog_providers
			WHERE provider_key = p.provider OR provider_code = p.provider
			ORDER BY CASE WHEN provider_key = p.provider THEN 0 ELSE 1 END, provider_key
			LIMIT 1
		) cp ON TRUE
		LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
		LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id)
		LEFT JOIN runtime_client_states s ON s.client_id = p.platform_key || ':' || $2::text || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name)
		LEFT JOIN (
			SELECT scope_key, SUM(lease_value) AS active
			FROM gateway_concurrency_leases
			WHERE scope_type = 'platform_model' AND released_at IS NULL AND expires_at > now()
			GROUP BY scope_key
		) con ON con.scope_key = m.id::text
		LEFT JOIN (
			SELECT queued_sources.platform_model_id, COUNT(DISTINCT queued_sources.task_id) AS waiting
			FROM (
				SELECT t.id::text AS task_id, qm.id::text AS platform_model_id
				FROM gateway_tasks t
				JOIN integration_platforms qp ON TRUE
				JOIN platform_models qm ON qm.platform_id = qp.id
				WHERE t.async_mode = true
					AND t.status = 'queued'
					AND NULLIF(t.model_type, '') IS NOT NULL
					AND qm.model_type @> jsonb_build_array(t.model_type)
					AND t.queue_key = qp.platform_key || ':' || t.model_type || ':' || COALESCE(NULLIF(qm.provider_model_name, ''), qm.model_name)
					AND NOT EXISTS (SELECT 1 FROM gateway_task_attempts existing_attempt WHERE existing_attempt.task_id = t.id)
				UNION ALL
				SELECT latest.task_id::text AS task_id, latest.platform_model_id
				FROM (
					SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id
					FROM gateway_tasks t
					JOIN gateway_task_attempts a ON a.task_id = t.id
					WHERE t.async_mode = true AND t.status = 'queued' AND a.platform_model_id IS NOT NULL
					ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC
				) latest
			) queued_sources
			GROUP BY queued_sources.platform_model_id
		) queued ON queued.platform_model_id = m.id::text
		LEFT JOIN (
			SELECT DISTINCT ON (scope_key) scope_key, used_value, reserved_value
			FROM gateway_rate_limit_counters
			WHERE scope_type = 'platform_model' AND metric = 'rpm' AND reset_at > now()
			ORDER BY scope_key, window_start DESC
		) rpm ON rpm.scope_key = m.id::text
		LEFT JOIN (
			SELECT scope_key, SUM(used_value) AS used_value, SUM(reserved_value) AS reserved_value
			FROM gateway_rate_limit_counters
			WHERE scope_type = 'platform_model' AND metric LIKE 'tpm%' AND reset_at > now()
			GROUP BY scope_key
		) tpm ON tpm.scope_key = m.id::text
		WHERE p.status = 'enabled'
			AND p.deleted_at IS NULL
			AND m.enabled = true
			AND m.model_type @> jsonb_build_array($2::text)
			AND (p.cooldown_until IS NULL OR p.cooldown_until <= now())
			AND (m.cooldown_until IS NULL OR m.cooldown_until <= now())
			AND (
				(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
				OR (
					COALESCE(m.model_alias, '') = ''
					AND (m.model_name = $1::text OR b.canonical_model_key = $1::text OR b.provider_model_name = $1::text)
				)
			)
		ORDER BY
			effective_priority ASC,
			COALESCE(s.running_count, 0) ASC,
			COALESCE(s.waiting_count, 0) ASC,
			COALESCE(s.last_assigned_at, to_timestamp(0)) ASC,
			m.created_at ASC`, model, modelType)
	if err != nil {
		return nil, fmt.Errorf("query model candidates: %w", err)
	}
	defer rows.Close()

	items := make([]RuntimeModelCandidate, 0)
	for rows.Next() {
		item, err := scanRuntimeModelCandidate(rows)
		if err != nil {
			return nil, fmt.Errorf("scan model candidate: %w", err)
		}
		items = append(items, item)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate model candidates: %w", err)
	}

	if len(items) == 0 {
		// Nothing is routable right now; if the only obstacle is a cooldown,
		// surface a descriptive error instead of the generic sentinel.
		unavailableErr, err := s.modelCandidateCooldownError(ctx, model, modelType)
		if err != nil {
			return nil, err
		}
		if unavailableErr != nil {
			return nil, unavailableErr
		}
		return nil, ErrNoModelCandidate
	}

	items, err = s.filterCandidatesByAccessRules(ctx, user, items)
	if err != nil {
		return nil, err
	}
	if len(items) == 0 {
		return nil, ErrNoModelCandidate
	}

	// The SQL ORDER BY provides the initial order; the stable in-memory sort
	// refines it with load information and keeps SQL order for ties.
	sortRuntimeModelCandidates(items)
	return items, nil
}

// scanRuntimeModelCandidate decodes one row of the ListModelCandidates query
// into a candidate. It decodes the JSON columns, derives the client/queue key
// and computes load metrics. The Scan targets must stay in the same order as
// the query's SELECT list.
func scanRuntimeModelCandidate(row interface{ Scan(dest ...any) error }) (RuntimeModelCandidate, error) {
	var item RuntimeModelCandidate

	// Raw JSON columns, decoded into maps after the scan.
	var (
		credentials             []byte
		platformConfig          []byte
		platformRetryPolicy     []byte
		platformRateLimitPolicy []byte
		capabilities            []byte
		capabilityOverride      []byte
		baseBilling             []byte
		billing                 []byte
		billingOverride         []byte
		permissionConfig        []byte
		modelRetryPolicy        []byte
		modelRateLimitPolicy    []byte
		runtimePolicyOverride   []byte
		runtimeRetryPolicy      []byte
		runtimeRateLimitPolicy  []byte
		autoDisablePolicy       []byte
		degradePolicy           []byte
	)

	// Load readings; the SQL coalesces every one of them to 0.
	var (
		concurrentActive  float64
		queuedWaiting     float64
		rpmUsed           float64
		rpmReserved       float64
		tpmUsed           float64
		tpmReserved       float64
		stateRunningCount float64
		stateWaitingCount float64
		stateLimiterRatio float64
		lastAssignedUnix  float64
	)

	if err := row.Scan(
		// Platform columns.
		&item.PlatformID, &item.PlatformKey, &item.PlatformName, &item.Provider, &item.SpecType,
		&item.BaseURL, &item.AuthType, &credentials, &platformConfig,
		&item.DefaultPricingMode, &item.DefaultDiscountFactor, &item.PlatformPricingRuleSetID,
		&platformRetryPolicy, &platformRateLimitPolicy, &item.PlatformPriority,
		// Model columns (ModelType is the requested modelType echoed back by the query).
		&item.PlatformModelID, &item.BaseModelID, &item.CanonicalModelKey,
		&item.ProviderModelName, &item.ModelName, &item.ModelAlias, &item.ModelType, &item.DisplayName,
		&capabilities, &capabilityOverride, &baseBilling, &billing, &billingOverride,
		&item.PricingMode, &item.DiscountFactor, &item.ModelPricingRuleSetID, &item.BasePricingRuleSetID,
		&permissionConfig, &modelRetryPolicy, &modelRateLimitPolicy,
		// Runtime policy set columns.
		&item.RuntimePolicySetID, &runtimePolicyOverride,
		&runtimeRetryPolicy, &runtimeRateLimitPolicy, &autoDisablePolicy, &degradePolicy,
		// Load readings.
		&concurrentActive, &queuedWaiting, &rpmUsed, &rpmReserved, &tpmUsed, &tpmReserved,
		&stateRunningCount, &stateWaitingCount, &stateLimiterRatio, &lastAssignedUnix,
	); err != nil {
		return RuntimeModelCandidate{}, err
	}

	item.Credentials = decodeObject(credentials)
	item.PlatformConfig = decodeObject(platformConfig)
	item.PlatformRetryPolicy = decodeObject(platformRetryPolicy)
	item.PlatformRateLimitPolicy = decodeObject(platformRateLimitPolicy)
	item.Capabilities = decodeObject(capabilities)
	item.CapabilityOverride = decodeObject(capabilityOverride)
	item.BaseBillingConfig = decodeObject(baseBilling)
	item.BillingConfig = decodeObject(billing)
	item.BillingConfigOverride = decodeObject(billingOverride)
	item.PermissionConfig = decodeObject(permissionConfig)
	item.ModelRetryPolicy = decodeObject(modelRetryPolicy)
	item.ModelRateLimitPolicy = decodeObject(modelRateLimitPolicy)
	item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride)
	item.RuntimeRetryPolicy = decodeObject(runtimeRetryPolicy)
	item.RuntimeRateLimitPolicy = decodeObject(runtimeRateLimitPolicy)
	item.AutoDisablePolicy = decodeObject(autoDisablePolicy)
	item.DegradePolicy = decodeObject(degradePolicy)

	// Same "<platform_key>:<model type>:<upstream model>" format the query uses
	// to join runtime_client_states and to match queued tasks' queue_key.
	upstreamModelName := firstNonEmpty(item.ProviderModelName, item.ModelName)
	item.ClientID = fmt.Sprintf("%s:%s:%s", item.PlatformKey, item.ModelType, upstreamModelName)
	item.QueueKey = item.ClientID

	item.RunningCount = stateRunningCount
	item.WaitingCount = maxFloat(queuedWaiting, stateWaitingCount)
	item.LastAssignedUnix = lastAssignedUnix

	applyRuntimeCandidateLoad(&item, runtimeCandidateLoadInput{
		Policy:            effectiveModelRateLimitPolicy(item.PlatformRateLimitPolicy, item.RuntimeRateLimitPolicy, item.RuntimePolicyOverride, item.ModelRateLimitPolicy),
		ConcurrentActive:  concurrentActive,
		QueuedWaiting:     queuedWaiting,
		RPMUsed:           rpmUsed,
		RPMReserved:       rpmReserved,
		TPMUsed:           tpmUsed,
		TPMReserved:       tpmReserved,
		StateRunningCount: stateRunningCount,
		StateWaitingCount: stateWaitingCount,
		StateLimiterRatio: stateLimiterRatio,
	})
	return item, nil
}

// runtimeCandidateLoadInput carries the effective rate-limit policy and the
// raw load readings collected for a single candidate. applyRuntimeCandidateLoad
// turns them into load metrics.
type runtimeCandidateLoadInput struct {
	Policy            map[string]any
	ConcurrentActive  float64
	QueuedWaiting     float64
	RPMUsed           float64
	RPMReserved       float64
	TPMUsed           float64
	TPMReserved       float64
	StateRunningCount float64
	StateWaitingCount float64
	StateLimiterRatio float64
}

// applyRuntimeCandidateLoad computes rpm, tpm and concurrency utilisation for
// candidate from input and stores it in LoadMetrics, LoadLimited and LoadRatio.
//
// Current usage counts reserved as well as used quota. Concurrency usage counts
// queued tasks in addition to active leases. A ratio is 0 when its limit is not
// positive. LoadLimited reports whether any of the three limits is set, and
// LoadRatio combines the three ratios with maxFloat.
func applyRuntimeCandidateLoad(candidate *RuntimeModelCandidate, input runtimeCandidateLoadInput) {
	rpmLimit := rateLimitForMetric(input.Policy, "rpm")
	tpmLimitValue := tpmLimit(input.Policy)
	concurrentLimit := rateLimitForMetric(input.Policy, "concurrent")

	rpmCurrent := input.RPMUsed + input.RPMReserved
	tpmCurrent := input.TPMUsed + input.TPMReserved
	concurrentCurrent := input.ConcurrentActive + input.QueuedWaiting

	metrics := RuntimeCandidateLoadMetrics{
		RPMCurrent:        rpmCurrent,
		RPMLimit:          rpmLimit,
		RPMRatio:          ratioIfLimited(rpmCurrent, rpmLimit),
		TPMCurrent:        tpmCurrent,
		TPMLimit:          tpmLimitValue,
		TPMRatio:          ratioIfLimited(tpmCurrent, tpmLimitValue),
		ConcurrentCurrent: concurrentCurrent,
		ConcurrentLimit:   concurrentLimit,
		ConcurrentRatio:   ratioIfLimited(concurrentCurrent, concurrentLimit),
		QueuedCount:       input.QueuedWaiting,
		StateRunningCount: input.StateRunningCount,
		StateWaitingCount: input.StateWaitingCount,
		StateLimiterRatio: input.StateLimiterRatio,
	}
	candidate.LoadMetrics = metrics
	candidate.LoadLimited = rpmLimit > 0 || tpmLimitValue > 0 || concurrentLimit > 0
	candidate.LoadRatio = maxFloat(metrics.RPMRatio, metrics.TPMRatio, metrics.ConcurrentRatio)
}

// ratioIfLimited returns current/limit, or 0 when limit is not positive
// (an unset limit never counts as load).
func ratioIfLimited(current float64, limit float64) float64 {
	if limit <= 0 {
		return 0
	}
	return current / limit
}

// sortRuntimeModelCandidates orders items in place. Candidates that are not
// full come first. Remaining ties are broken by ascending platform priority,
// load ratio, running count, waiting count and last-assignment time. The sort
// is stable, so equal candidates keep their incoming order.
//
// It also resets LoadAvoided on every item and sets it on the full ones, but
// only when at least one non-full candidate exists to route to instead.
func sortRuntimeModelCandidates(items []RuntimeModelCandidate) {
	hasFull := false
	hasNonFull := false
	for index := range items {
		items[index].LoadAvoided = false
		if runtimeCandidateFull(items[index]) {
			hasFull = true
		} else {
			hasNonFull = true
		}
	}
	if hasFull && hasNonFull {
		for index := range items {
			items[index].LoadAvoided = runtimeCandidateFull(items[index])
		}
	}
	sort.SliceStable(items, func(i, j int) bool {
		aFull := runtimeCandidateFull(items[i])
		bFull := runtimeCandidateFull(items[j])
		if aFull != bFull {
			return !aFull
		}
		if items[i].PlatformPriority != items[j].PlatformPriority {
			return items[i].PlatformPriority < items[j].PlatformPriority
		}
		if items[i].LoadRatio != items[j].LoadRatio {
			return items[i].LoadRatio < items[j].LoadRatio
		}
		if items[i].RunningCount != items[j].RunningCount {
			return items[i].RunningCount < items[j].RunningCount
		}
		if items[i].WaitingCount != items[j].WaitingCount {
			return items[i].WaitingCount < items[j].WaitingCount
		}
		if items[i].LastAssignedUnix != items[j].LastAssignedUnix {
			return items[i].LastAssignedUnix < items[j].LastAssignedUnix
		}
		return false
	})
}

// runtimeCandidateFull reports whether candidate has at least one load limit
// configured and its LoadRatio has reached 1.
func runtimeCandidateFull(candidate RuntimeModelCandidate) bool {
	return candidate.LoadLimited && candidate.LoadRatio >= 1
}

// modelCandidateCooldownError explains an empty ListModelCandidates result
// when the only obstacle is a cooldown.
//
// It re-runs the candidate match without the cooldown predicates and scans the
// rows latest cooldown first. For the first row with an active model or
// platform cooldown it returns a *ModelCandidateUnavailableError. When both are
// active, it reports the one that ends later, because a row only becomes
// routable once both have elapsed; ties go to the model.
//
// The first return value is nil when no matching row is cooling down. The
// second is a database error.
//
// NOTE(review): ordering by latest cooldown reports the candidate that stays
// unavailable longest, not the one that recovers first. Confirm that is the
// intended message.
func (s *Store) modelCandidateCooldownError(ctx context.Context, model string, modelType string) (error, error) {
	rows, err := s.pool.Query(ctx, `
		SELECT
			p.name,
			COALESCE(NULLIF(m.display_name, ''), NULLIF(m.model_alias, ''), m.model_name),
			COALESCE(to_char(p.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
			GREATEST(COALESCE(EXTRACT(EPOCH FROM p.cooldown_until - now()), 0), 0)::float8,
			COALESCE(to_char(m.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
			GREATEST(COALESCE(EXTRACT(EPOCH FROM m.cooldown_until - now()), 0), 0)::float8
		FROM platform_models m
		JOIN integration_platforms p ON p.id = m.platform_id
		LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
		WHERE p.status = 'enabled'
			AND p.deleted_at IS NULL
			AND m.enabled = true
			AND m.model_type @> jsonb_build_array($2::text)
			AND (
				(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
				OR (
					COALESCE(m.model_alias, '') = ''
					AND (m.model_name = $1::text OR b.canonical_model_key = $1::text OR b.provider_model_name = $1::text)
				)
			)
		ORDER BY
			GREATEST(COALESCE(p.cooldown_until, to_timestamp(0)), COALESCE(m.cooldown_until, to_timestamp(0))) DESC,
			COALESCE(p.dynamic_priority, p.priority) ASC,
			m.created_at ASC`, model, modelType)
	if err != nil {
		return nil, fmt.Errorf("query model cooldowns: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var (
			platformName             string
			displayName              string
			platformCooldownUntil    string
			platformRemainingSeconds float64
			modelCooldownUntil       string
			modelRemainingSeconds    float64
		)
		if err := rows.Scan(
			&platformName,
			&displayName,
			&platformCooldownUntil,
			&platformRemainingSeconds,
			&modelCooldownUntil,
			&modelRemainingSeconds,
		); err != nil {
			return nil, fmt.Errorf("scan model cooldown: %w", err)
		}
		switch {
		case modelRemainingSeconds > 0 && modelRemainingSeconds >= platformRemainingSeconds:
			return &ModelCandidateUnavailableError{
				Code:    "model_cooling_down",
				Message: cooldownErrorMessage("模型", displayName, modelRemainingSeconds, modelCooldownUntil),
			}, nil
		case platformRemainingSeconds > 0:
			return &ModelCandidateUnavailableError{
				Code:    "platform_cooling_down",
				Message: cooldownErrorMessage("平台", platformName, platformRemainingSeconds, platformCooldownUntil),
			}, nil
		}
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate model cooldowns: %w", err)
	}
	return nil, nil
}

// cooldownErrorMessage builds the user-facing cooldown message for scope (for
// example "模型" or "平台") and name.
//
// A blank name is replaced by "候选". The remaining time is shown in minutes
// with one decimal and a floor of 0.1 minutes. When cooldownUntil is not blank,
// the expected recovery time is appended.
func cooldownErrorMessage(scope string, name string, remainingSeconds float64, cooldownUntil string) string {
	name = strings.TrimSpace(name)
	if name == "" {
		name = "候选"
	}
	remainingMinutes := remainingSeconds / 60
	if remainingMinutes < 0.1 {
		// Avoid printing "0.0 分钟" for sub-6-second remainders.
		remainingMinutes = 0.1
	}
	message := fmt.Sprintf("%s %s 冷却中,剩余 %.1f 分钟", scope, name, remainingMinutes)
	if strings.TrimSpace(cooldownUntil) != "" {
		message += ",预计恢复时间 " + cooldownUntil
	}
	return message
}