easyai-ai-gateway/apps/api/internal/store/candidates.go

445 lines
16 KiB
Go

package store
import (
"context"
"fmt"
"sort"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
)
// ListModelCandidates returns the runtime candidates (platform + platform model
// pairs) that can currently serve the requested model for the given model type.
//
// The query only considers platforms with status 'enabled' that are not
// soft-deleted, and enabled platform models whose model_type array contains
// modelType. Rows whose platform or model cooldown_until is still in the future
// are excluded. The model argument is matched against model_alias when the
// model has a non-empty alias; otherwise it is matched against model_name, the
// base catalog canonical_model_key, or the base catalog provider_model_name.
//
// Each row is enriched with load data gathered by the joined sub-queries:
// active concurrency leases, queued async tasks, rpm/tpm rate-limit counters
// and the runtime client state row for the candidate's client id.
//
// When the query yields no rows, modelCandidateCooldownError is consulted so a
// cooldown-specific error can be returned instead of ErrNoModelCandidate.
// Otherwise the rows are passed through filterCandidatesByAccessRules for user
// and ordered by sortRuntimeModelCandidates; ErrNoModelCandidate is returned
// when the access filter leaves nothing.
//
// NOTE(review): model_catalog_providers is joined on provider_key OR
// provider_code; if more than one catalog row can match a platform's provider
// the same candidate would be returned more than once — confirm uniqueness.
func (s *Store) ListModelCandidates(ctx context.Context, model string, modelType string, user *auth.User) ([]RuntimeModelCandidate, error) {
rows, err := s.pool.Query(ctx, `
SELECT p.id::text, p.platform_key, p.name, p.provider,
COALESCE(NULLIF(p.config->>'specType', ''), NULLIF(cp.provider_type, ''), NULLIF(p.config->>'sourceSpecType', ''), p.provider) AS spec_type,
COALESCE(p.base_url, ''),
p.auth_type, p.credentials, p.config, p.default_pricing_mode,
p.default_discount_factor::float8, COALESCE(p.pricing_rule_set_id::text, ''),
p.retry_policy, p.rate_limit_policy,
COALESCE(p.dynamic_priority, p.priority) AS effective_priority,
m.id::text, COALESCE(m.base_model_id::text, ''), COALESCE(b.canonical_model_key, ''),
COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), m.model_name, COALESCE(m.model_alias, ''),
$2::text AS requested_model_type, m.display_name, m.capabilities, m.capability_override,
COALESCE(b.base_billing_config, '{}'::jsonb), m.billing_config, m.billing_config_override,
m.pricing_mode, COALESCE(m.discount_factor, 0)::float8, COALESCE(m.pricing_rule_set_id::text, ''),
COALESCE(b.pricing_rule_set_id::text, ''),
m.permission_config, m.retry_policy, m.rate_limit_policy, COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')),
COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb),
COALESCE(rp.retry_policy, '{}'::jsonb), COALESCE(rp.rate_limit_policy, '{}'::jsonb),
COALESCE(rp.auto_disable_policy, '{}'::jsonb), COALESCE(rp.degrade_policy, '{}'::jsonb),
COALESCE(con.active, 0)::float8,
COALESCE(queued.waiting, 0)::float8,
COALESCE(rpm.used_value, 0)::float8, COALESCE(rpm.reserved_value, 0)::float8,
COALESCE(tpm.used_value, 0)::float8, COALESCE(tpm.reserved_value, 0)::float8,
COALESCE(s.running_count, 0)::float8,
COALESCE(s.waiting_count, 0)::float8,
COALESCE(s.limiter_ratio, 0)::float8,
COALESCE(EXTRACT(EPOCH FROM s.last_assigned_at), 0)::float8
FROM platform_models m
JOIN integration_platforms p ON p.id = m.platform_id
LEFT JOIN model_catalog_providers cp ON cp.provider_key = p.provider OR cp.provider_code = p.provider
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id)
LEFT JOIN runtime_client_states s
ON s.client_id = p.platform_key || ':' || $2::text || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name)
LEFT JOIN (
SELECT scope_key, SUM(lease_value) AS active
FROM gateway_concurrency_leases
WHERE scope_type = 'platform_model'
AND released_at IS NULL
AND expires_at > now()
GROUP BY scope_key
) con ON con.scope_key = m.id::text
LEFT JOIN (
SELECT queued_sources.platform_model_id, COUNT(DISTINCT queued_sources.task_id) AS waiting
FROM (
SELECT t.id::text AS task_id, qm.id::text AS platform_model_id
FROM gateway_tasks t
JOIN integration_platforms qp ON TRUE
JOIN platform_models qm ON qm.platform_id = qp.id
WHERE t.async_mode = true
AND t.status = 'queued'
AND NULLIF(t.model_type, '') IS NOT NULL
AND qm.model_type @> jsonb_build_array(t.model_type)
AND t.queue_key = qp.platform_key || ':' || t.model_type || ':' || COALESCE(NULLIF(qm.provider_model_name, ''), qm.model_name)
AND NOT EXISTS (SELECT 1 FROM gateway_task_attempts existing_attempt WHERE existing_attempt.task_id = t.id)
UNION ALL
SELECT latest.task_id::text AS task_id, latest.platform_model_id
FROM (
SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id
FROM gateway_tasks t
JOIN gateway_task_attempts a ON a.task_id = t.id
WHERE t.async_mode = true
AND t.status = 'queued'
AND a.platform_model_id IS NOT NULL
ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC
) latest
) queued_sources
GROUP BY queued_sources.platform_model_id
) queued ON queued.platform_model_id = m.id::text
LEFT JOIN (
SELECT DISTINCT ON (scope_key) scope_key, used_value, reserved_value
FROM gateway_rate_limit_counters
WHERE scope_type = 'platform_model'
AND metric = 'rpm'
AND reset_at > now()
ORDER BY scope_key, window_start DESC
) rpm ON rpm.scope_key = m.id::text
LEFT JOIN (
SELECT scope_key, SUM(used_value) AS used_value, SUM(reserved_value) AS reserved_value
FROM gateway_rate_limit_counters
WHERE scope_type = 'platform_model'
AND metric LIKE 'tpm%'
AND reset_at > now()
GROUP BY scope_key
) tpm ON tpm.scope_key = m.id::text
WHERE p.status = 'enabled'
AND p.deleted_at IS NULL
AND m.enabled = true
AND m.model_type @> jsonb_build_array($2::text)
AND (p.cooldown_until IS NULL OR p.cooldown_until <= now())
AND (m.cooldown_until IS NULL OR m.cooldown_until <= now())
AND (
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
OR (
COALESCE(m.model_alias, '') = ''
AND (
m.model_name = $1::text
OR b.canonical_model_key = $1::text
OR b.provider_model_name = $1::text
)
)
)
ORDER BY effective_priority ASC,
COALESCE(s.running_count, 0) ASC,
COALESCE(s.waiting_count, 0) ASC,
COALESCE(s.last_assigned_at, to_timestamp(0)) ASC,
m.created_at ASC`, model, modelType)
if err != nil {
return nil, err
}
defer rows.Close()
items := make([]RuntimeModelCandidate, 0)
for rows.Next() {
var item RuntimeModelCandidate
// Policy/config columns are scanned as raw bytes and turned into maps with
// decodeObject once the row has been read.
var credentials []byte
var platformConfig []byte
var platformRetryPolicy []byte
var platformRateLimitPolicy []byte
var capabilities []byte
var capabilityOverride []byte
var baseBilling []byte
var billing []byte
var billingOverride []byte
var permissionConfig []byte
var modelRetryPolicy []byte
var modelRateLimitPolicy []byte
var runtimePolicyOverride []byte
var runtimeRetryPolicy []byte
var runtimeRateLimitPolicy []byte
var autoDisablePolicy []byte
var degradePolicy []byte
// Load figures; the query COALESCEs each of them to 0 and casts to float8.
var concurrentActive float64
var queuedWaiting float64
var rpmUsed float64
var rpmReserved float64
var tpmUsed float64
var tpmReserved float64
var stateRunningCount float64
var stateWaitingCount float64
var stateLimiterRatio float64
var lastAssignedUnix float64
// Scan is positional: the destinations below must stay in the same order as
// the SELECT list above.
if err := rows.Scan(
&item.PlatformID,
&item.PlatformKey,
&item.PlatformName,
&item.Provider,
&item.SpecType,
&item.BaseURL,
&item.AuthType,
&credentials,
&platformConfig,
&item.DefaultPricingMode,
&item.DefaultDiscountFactor,
&item.PlatformPricingRuleSetID,
&platformRetryPolicy,
&platformRateLimitPolicy,
&item.PlatformPriority,
&item.PlatformModelID,
&item.BaseModelID,
&item.CanonicalModelKey,
&item.ProviderModelName,
&item.ModelName,
&item.ModelAlias,
&item.ModelType,
&item.DisplayName,
&capabilities,
&capabilityOverride,
&baseBilling,
&billing,
&billingOverride,
&item.PricingMode,
&item.DiscountFactor,
&item.ModelPricingRuleSetID,
&item.BasePricingRuleSetID,
&permissionConfig,
&modelRetryPolicy,
&modelRateLimitPolicy,
&item.RuntimePolicySetID,
&runtimePolicyOverride,
&runtimeRetryPolicy,
&runtimeRateLimitPolicy,
&autoDisablePolicy,
&degradePolicy,
&concurrentActive,
&queuedWaiting,
&rpmUsed,
&rpmReserved,
&tpmUsed,
&tpmReserved,
&stateRunningCount,
&stateWaitingCount,
&stateLimiterRatio,
&lastAssignedUnix,
); err != nil {
return nil, err
}
item.Credentials = decodeObject(credentials)
item.PlatformConfig = decodeObject(platformConfig)
item.PlatformRetryPolicy = decodeObject(platformRetryPolicy)
item.PlatformRateLimitPolicy = decodeObject(platformRateLimitPolicy)
item.Capabilities = decodeObject(capabilities)
item.CapabilityOverride = decodeObject(capabilityOverride)
item.BaseBillingConfig = decodeObject(baseBilling)
item.BillingConfig = decodeObject(billing)
item.BillingConfigOverride = decodeObject(billingOverride)
item.PermissionConfig = decodeObject(permissionConfig)
item.ModelRetryPolicy = decodeObject(modelRetryPolicy)
item.ModelRateLimitPolicy = decodeObject(modelRateLimitPolicy)
item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride)
item.RuntimeRetryPolicy = decodeObject(runtimeRetryPolicy)
item.RuntimeRateLimitPolicy = decodeObject(runtimeRateLimitPolicy)
item.AutoDisablePolicy = decodeObject(autoDisablePolicy)
item.DegradePolicy = decodeObject(degradePolicy)
// The client id uses the same "platform_key:model_type:upstream model name"
// shape as the runtime_client_states join key in the query; item.ModelType
// holds the requested model type ($2). The queue key is the same value.
upstreamModelName := firstNonEmpty(item.ProviderModelName, item.ModelName)
item.ClientID = fmt.Sprintf("%s:%s:%s", item.PlatformKey, item.ModelType, upstreamModelName)
item.QueueKey = item.ClientID
item.RunningCount = stateRunningCount
// Waiting count is taken from both the queued-task sub-query and the client
// state row, combined with maxFloat.
item.WaitingCount = maxFloat(queuedWaiting, stateWaitingCount)
item.LastAssignedUnix = lastAssignedUnix
// Derive LoadMetrics / LoadLimited / LoadRatio from the effective rate-limit
// policy and the usage figures read for this row.
applyRuntimeCandidateLoad(&item, runtimeCandidateLoadInput{
Policy: effectiveModelRateLimitPolicy(item.PlatformRateLimitPolicy, item.RuntimeRateLimitPolicy, item.RuntimePolicyOverride, item.ModelRateLimitPolicy),
ConcurrentActive: concurrentActive,
QueuedWaiting: queuedWaiting,
RPMUsed: rpmUsed,
RPMReserved: rpmReserved,
TPMUsed: tpmUsed,
TPMReserved: tpmReserved,
StateRunningCount: stateRunningCount,
StateWaitingCount: stateWaitingCount,
StateLimiterRatio: stateLimiterRatio,
})
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, err
}
// Nothing matched: check whether matching models exist but are cooling down so
// the caller gets a more specific error than ErrNoModelCandidate.
if len(items) == 0 {
if unavailableErr, err := s.modelCandidateCooldownError(ctx, model, modelType); err != nil {
return nil, err
} else if unavailableErr != nil {
return nil, unavailableErr
}
return nil, ErrNoModelCandidate
}
// Apply access rules for user before the final ordering.
items, err = s.filterCandidatesByAccessRules(ctx, user, items)
if err != nil {
return nil, err
}
if len(items) == 0 {
return nil, ErrNoModelCandidate
}
sortRuntimeModelCandidates(items)
return items, nil
}
// runtimeCandidateLoadInput bundles the rate-limit policy and the usage
// figures that applyRuntimeCandidateLoad turns into load metrics for one
// candidate.
type runtimeCandidateLoadInput struct {
	// Policy is the rate-limit policy the rpm / tpm / concurrent limits are
	// read from.
	Policy map[string]any

	// ConcurrentActive and QueuedWaiting are added together to form the
	// current concurrent load.
	ConcurrentActive, QueuedWaiting float64

	// RPMUsed and RPMReserved are added together to form the current rpm load.
	RPMUsed, RPMReserved float64

	// TPMUsed and TPMReserved are added together to form the current tpm load.
	TPMUsed, TPMReserved float64

	// State* values are copied into the resulting metrics unchanged.
	StateRunningCount, StateWaitingCount, StateLimiterRatio float64
}
// applyRuntimeCandidateLoad fills candidate.LoadMetrics, candidate.LoadLimited
// and candidate.LoadRatio from the policy and usage figures in input.
//
// Limits are read from input.Policy for "rpm", tpm (via tpmLimit) and
// "concurrent". Current usage is used+reserved for rpm and tpm, and active
// leases plus queued tasks for concurrency. A ratio is only computed for a
// metric whose limit is positive; the overall LoadRatio is the largest of the
// three, and LoadLimited reports whether any limit is positive.
func applyRuntimeCandidateLoad(candidate *RuntimeModelCandidate, input runtimeCandidateLoadInput) {
	rpmCap := rateLimitForMetric(input.Policy, "rpm")
	tpmCap := tpmLimit(input.Policy)
	concurrentCap := rateLimitForMetric(input.Policy, "concurrent")

	rpmInUse := input.RPMUsed + input.RPMReserved
	tpmInUse := input.TPMUsed + input.TPMReserved
	concurrentInUse := input.ConcurrentActive + input.QueuedWaiting

	rpmRatio := ratioIfLimited(rpmInUse, rpmCap)
	tpmRatio := ratioIfLimited(tpmInUse, tpmCap)
	concurrentRatio := ratioIfLimited(concurrentInUse, concurrentCap)

	candidate.LoadMetrics = RuntimeCandidateLoadMetrics{
		RPMCurrent:        rpmInUse,
		RPMLimit:          rpmCap,
		RPMRatio:          rpmRatio,
		TPMCurrent:        tpmInUse,
		TPMLimit:          tpmCap,
		TPMRatio:          tpmRatio,
		ConcurrentCurrent: concurrentInUse,
		ConcurrentLimit:   concurrentCap,
		ConcurrentRatio:   concurrentRatio,
		QueuedCount:       input.QueuedWaiting,
		StateRunningCount: input.StateRunningCount,
		StateWaitingCount: input.StateWaitingCount,
		StateLimiterRatio: input.StateLimiterRatio,
	}
	candidate.LoadLimited = rpmCap > 0 || tpmCap > 0 || concurrentCap > 0
	candidate.LoadRatio = maxFloat(rpmRatio, tpmRatio, concurrentRatio)
}
// ratioIfLimited returns current divided by limit, or 0 when limit is zero or
// negative (i.e. the metric has no effective limit).
func ratioIfLimited(current float64, limit float64) float64 {
	switch {
	case limit <= 0:
		return 0
	default:
		return current / limit
	}
}
// sortRuntimeModelCandidates orders items in place, preferred candidate first,
// and refreshes each item's LoadAvoided flag.
//
// LoadAvoided is set on a candidate only when it is full (see
// runtimeCandidateFull) and at least one other candidate is not full; when all
// candidates are full, or none are, the flag is false everywhere.
//
// Sort keys, in order: non-full before full, then ascending PlatformPriority,
// LoadRatio, RunningCount, WaitingCount and LastAssignedUnix. The sort is
// stable, so candidates equal on every key keep their incoming order.
func sortRuntimeModelCandidates(items []RuntimeModelCandidate) {
	fullCount := 0
	for i := range items {
		if runtimeCandidateFull(items[i]) {
			fullCount++
		}
	}

	// Flag full candidates only when there is a non-full alternative.
	mixed := fullCount > 0 && fullCount < len(items)
	for i := range items {
		items[i].LoadAvoided = mixed && runtimeCandidateFull(items[i])
	}

	sort.SliceStable(items, func(i, j int) bool {
		left, right := &items[i], &items[j]
		leftFull := runtimeCandidateFull(*left)
		rightFull := runtimeCandidateFull(*right)
		switch {
		case leftFull != rightFull:
			return !leftFull
		case left.PlatformPriority != right.PlatformPriority:
			return left.PlatformPriority < right.PlatformPriority
		case left.LoadRatio != right.LoadRatio:
			return left.LoadRatio < right.LoadRatio
		case left.RunningCount != right.RunningCount:
			return left.RunningCount < right.RunningCount
		case left.WaitingCount != right.WaitingCount:
			return left.WaitingCount < right.WaitingCount
		case left.LastAssignedUnix != right.LastAssignedUnix:
			return left.LastAssignedUnix < right.LastAssignedUnix
		default:
			return false
		}
	})
}
// runtimeCandidateFull reports whether candidate has at least one effective
// limit (LoadLimited) and its LoadRatio has reached 1 or more.
func runtimeCandidateFull(candidate RuntimeModelCandidate) bool {
	if !candidate.LoadLimited {
		return false
	}
	return candidate.LoadRatio >= 1
}
// modelCandidateCooldownError looks for enabled platform models that match
// model / modelType (same matching rules as ListModelCandidates, but without
// the cooldown filter) and reports whether one of them is cooling down.
//
// The first result is the cooldown error to surface to the caller: a
// *ModelCandidateUnavailableError with code "model_cooling_down" or
// "platform_cooling_down", or nil when no matching row has remaining cooldown
// time. The second result is a query/scan failure. Rows are visited with the
// latest cooldown deadline first, and within a row the model cooldown takes
// precedence over the platform cooldown.
func (s *Store) modelCandidateCooldownError(ctx context.Context, model string, modelType string) (error, error) {
	const cooldownQuery = `
SELECT p.name,
COALESCE(NULLIF(m.display_name, ''), NULLIF(m.model_alias, ''), m.model_name),
COALESCE(to_char(p.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
GREATEST(COALESCE(EXTRACT(EPOCH FROM p.cooldown_until - now()), 0), 0)::float8,
COALESCE(to_char(m.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
GREATEST(COALESCE(EXTRACT(EPOCH FROM m.cooldown_until - now()), 0), 0)::float8
FROM platform_models m
JOIN integration_platforms p ON p.id = m.platform_id
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
WHERE p.status = 'enabled'
AND p.deleted_at IS NULL
AND m.enabled = true
AND m.model_type @> jsonb_build_array($2::text)
AND (
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
OR (
COALESCE(m.model_alias, '') = ''
AND (
m.model_name = $1::text
OR b.canonical_model_key = $1::text
OR b.provider_model_name = $1::text
)
)
)
ORDER BY GREATEST(COALESCE(p.cooldown_until, to_timestamp(0)), COALESCE(m.cooldown_until, to_timestamp(0))) DESC,
p.priority ASC,
m.created_at ASC`

	rows, queryErr := s.pool.Query(ctx, cooldownQuery, model, modelType)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	for rows.Next() {
		var (
			platformLabel string
			modelLabel    string
			platformUntil string
			platformLeft  float64
			modelUntil    string
			modelLeft     float64
		)
		// Destination order mirrors the SELECT list.
		scanErr := rows.Scan(
			&platformLabel,
			&modelLabel,
			&platformUntil,
			&platformLeft,
			&modelUntil,
			&modelLeft,
		)
		if scanErr != nil {
			return nil, scanErr
		}

		switch {
		case modelLeft > 0:
			return &ModelCandidateUnavailableError{
				Code:    "model_cooling_down",
				Message: cooldownErrorMessage("模型", modelLabel, modelLeft, modelUntil),
			}, nil
		case platformLeft > 0:
			return &ModelCandidateUnavailableError{
				Code:    "platform_cooling_down",
				Message: cooldownErrorMessage("平台", platformLabel, platformLeft, platformUntil),
			}, nil
		}
	}

	// No row is cooling down; surface any iteration error, otherwise nil.
	return nil, rows.Err()
}
// cooldownErrorMessage builds the user-facing cooldown message for a model or
// platform.
//
// scope is the label placed in front of the name. name is trimmed and replaced
// by a generic placeholder when blank. remainingSeconds is shown in minutes
// with one decimal and never below 0.1. When cooldownUntil is not blank it is
// appended, untrimmed, as the expected recovery time.
func cooldownErrorMessage(scope string, name string, remainingSeconds float64, cooldownUntil string) string {
	const minMinutes = 0.1

	label := strings.TrimSpace(name)
	if label == "" {
		label = "候选"
	}

	minutes := remainingSeconds / 60
	if minutes < minMinutes {
		minutes = minMinutes
	}

	var recovery string
	if strings.TrimSpace(cooldownUntil) != "" {
		recovery = ",预计恢复时间 " + cooldownUntil
	}
	return fmt.Sprintf("%s %s 冷却中,剩余 %.1f 分钟", scope, label, minutes) + recovery
}