445 lines
16 KiB
Go
445 lines
16 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
)
|
|
|
|
func (s *Store) ListModelCandidates(ctx context.Context, model string, modelType string, user *auth.User) ([]RuntimeModelCandidate, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT p.id::text, p.platform_key, p.name, p.provider,
|
|
COALESCE(NULLIF(p.config->>'specType', ''), NULLIF(cp.provider_type, ''), NULLIF(p.config->>'sourceSpecType', ''), p.provider) AS spec_type,
|
|
COALESCE(p.base_url, ''),
|
|
p.auth_type, p.credentials, p.config, p.default_pricing_mode,
|
|
p.default_discount_factor::float8, COALESCE(p.pricing_rule_set_id::text, ''),
|
|
p.retry_policy, p.rate_limit_policy,
|
|
COALESCE(p.dynamic_priority, p.priority) AS effective_priority,
|
|
m.id::text, COALESCE(m.base_model_id::text, ''), COALESCE(b.canonical_model_key, ''),
|
|
COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), m.model_name, COALESCE(m.model_alias, ''),
|
|
$2::text AS requested_model_type, m.display_name, m.capabilities, m.capability_override,
|
|
COALESCE(b.base_billing_config, '{}'::jsonb), m.billing_config, m.billing_config_override,
|
|
m.pricing_mode, COALESCE(m.discount_factor, 0)::float8, COALESCE(m.pricing_rule_set_id::text, ''),
|
|
COALESCE(b.pricing_rule_set_id::text, ''),
|
|
m.permission_config, m.retry_policy, m.rate_limit_policy, COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')),
|
|
COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb),
|
|
COALESCE(rp.retry_policy, '{}'::jsonb), COALESCE(rp.rate_limit_policy, '{}'::jsonb),
|
|
COALESCE(rp.auto_disable_policy, '{}'::jsonb), COALESCE(rp.degrade_policy, '{}'::jsonb),
|
|
COALESCE(con.active, 0)::float8,
|
|
COALESCE(queued.waiting, 0)::float8,
|
|
COALESCE(rpm.used_value, 0)::float8, COALESCE(rpm.reserved_value, 0)::float8,
|
|
COALESCE(tpm.used_value, 0)::float8, COALESCE(tpm.reserved_value, 0)::float8,
|
|
COALESCE(s.running_count, 0)::float8,
|
|
COALESCE(s.waiting_count, 0)::float8,
|
|
COALESCE(s.limiter_ratio, 0)::float8,
|
|
COALESCE(EXTRACT(EPOCH FROM s.last_assigned_at), 0)::float8
|
|
FROM platform_models m
|
|
JOIN integration_platforms p ON p.id = m.platform_id
|
|
LEFT JOIN model_catalog_providers cp ON cp.provider_key = p.provider OR cp.provider_code = p.provider
|
|
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
|
|
LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id)
|
|
LEFT JOIN runtime_client_states s
|
|
ON s.client_id = p.platform_key || ':' || $2::text || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name)
|
|
LEFT JOIN (
|
|
SELECT scope_key, SUM(lease_value) AS active
|
|
FROM gateway_concurrency_leases
|
|
WHERE scope_type = 'platform_model'
|
|
AND released_at IS NULL
|
|
AND expires_at > now()
|
|
GROUP BY scope_key
|
|
) con ON con.scope_key = m.id::text
|
|
LEFT JOIN (
|
|
SELECT queued_sources.platform_model_id, COUNT(DISTINCT queued_sources.task_id) AS waiting
|
|
FROM (
|
|
SELECT t.id::text AS task_id, qm.id::text AS platform_model_id
|
|
FROM gateway_tasks t
|
|
JOIN integration_platforms qp ON TRUE
|
|
JOIN platform_models qm ON qm.platform_id = qp.id
|
|
WHERE t.async_mode = true
|
|
AND t.status = 'queued'
|
|
AND NULLIF(t.model_type, '') IS NOT NULL
|
|
AND qm.model_type @> jsonb_build_array(t.model_type)
|
|
AND t.queue_key = qp.platform_key || ':' || t.model_type || ':' || COALESCE(NULLIF(qm.provider_model_name, ''), qm.model_name)
|
|
AND NOT EXISTS (SELECT 1 FROM gateway_task_attempts existing_attempt WHERE existing_attempt.task_id = t.id)
|
|
UNION ALL
|
|
SELECT latest.task_id::text AS task_id, latest.platform_model_id
|
|
FROM (
|
|
SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id
|
|
FROM gateway_tasks t
|
|
JOIN gateway_task_attempts a ON a.task_id = t.id
|
|
WHERE t.async_mode = true
|
|
AND t.status = 'queued'
|
|
AND a.platform_model_id IS NOT NULL
|
|
ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC
|
|
) latest
|
|
) queued_sources
|
|
GROUP BY queued_sources.platform_model_id
|
|
) queued ON queued.platform_model_id = m.id::text
|
|
LEFT JOIN (
|
|
SELECT DISTINCT ON (scope_key) scope_key, used_value, reserved_value
|
|
FROM gateway_rate_limit_counters
|
|
WHERE scope_type = 'platform_model'
|
|
AND metric = 'rpm'
|
|
AND reset_at > now()
|
|
ORDER BY scope_key, window_start DESC
|
|
) rpm ON rpm.scope_key = m.id::text
|
|
LEFT JOIN (
|
|
SELECT scope_key, SUM(used_value) AS used_value, SUM(reserved_value) AS reserved_value
|
|
FROM gateway_rate_limit_counters
|
|
WHERE scope_type = 'platform_model'
|
|
AND metric LIKE 'tpm%'
|
|
AND reset_at > now()
|
|
GROUP BY scope_key
|
|
) tpm ON tpm.scope_key = m.id::text
|
|
WHERE p.status = 'enabled'
|
|
AND p.deleted_at IS NULL
|
|
AND m.enabled = true
|
|
AND m.model_type @> jsonb_build_array($2::text)
|
|
AND (p.cooldown_until IS NULL OR p.cooldown_until <= now())
|
|
AND (m.cooldown_until IS NULL OR m.cooldown_until <= now())
|
|
AND (
|
|
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
|
|
OR (
|
|
COALESCE(m.model_alias, '') = ''
|
|
AND (
|
|
m.model_name = $1::text
|
|
OR b.canonical_model_key = $1::text
|
|
OR b.provider_model_name = $1::text
|
|
)
|
|
)
|
|
)
|
|
ORDER BY effective_priority ASC,
|
|
COALESCE(s.running_count, 0) ASC,
|
|
COALESCE(s.waiting_count, 0) ASC,
|
|
COALESCE(s.last_assigned_at, to_timestamp(0)) ASC,
|
|
m.created_at ASC`, model, modelType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
items := make([]RuntimeModelCandidate, 0)
|
|
for rows.Next() {
|
|
var item RuntimeModelCandidate
|
|
var credentials []byte
|
|
var platformConfig []byte
|
|
var platformRetryPolicy []byte
|
|
var platformRateLimitPolicy []byte
|
|
var capabilities []byte
|
|
var capabilityOverride []byte
|
|
var baseBilling []byte
|
|
var billing []byte
|
|
var billingOverride []byte
|
|
var permissionConfig []byte
|
|
var modelRetryPolicy []byte
|
|
var modelRateLimitPolicy []byte
|
|
var runtimePolicyOverride []byte
|
|
var runtimeRetryPolicy []byte
|
|
var runtimeRateLimitPolicy []byte
|
|
var autoDisablePolicy []byte
|
|
var degradePolicy []byte
|
|
var concurrentActive float64
|
|
var queuedWaiting float64
|
|
var rpmUsed float64
|
|
var rpmReserved float64
|
|
var tpmUsed float64
|
|
var tpmReserved float64
|
|
var stateRunningCount float64
|
|
var stateWaitingCount float64
|
|
var stateLimiterRatio float64
|
|
var lastAssignedUnix float64
|
|
if err := rows.Scan(
|
|
&item.PlatformID,
|
|
&item.PlatformKey,
|
|
&item.PlatformName,
|
|
&item.Provider,
|
|
&item.SpecType,
|
|
&item.BaseURL,
|
|
&item.AuthType,
|
|
&credentials,
|
|
&platformConfig,
|
|
&item.DefaultPricingMode,
|
|
&item.DefaultDiscountFactor,
|
|
&item.PlatformPricingRuleSetID,
|
|
&platformRetryPolicy,
|
|
&platformRateLimitPolicy,
|
|
&item.PlatformPriority,
|
|
&item.PlatformModelID,
|
|
&item.BaseModelID,
|
|
&item.CanonicalModelKey,
|
|
&item.ProviderModelName,
|
|
&item.ModelName,
|
|
&item.ModelAlias,
|
|
&item.ModelType,
|
|
&item.DisplayName,
|
|
&capabilities,
|
|
&capabilityOverride,
|
|
&baseBilling,
|
|
&billing,
|
|
&billingOverride,
|
|
&item.PricingMode,
|
|
&item.DiscountFactor,
|
|
&item.ModelPricingRuleSetID,
|
|
&item.BasePricingRuleSetID,
|
|
&permissionConfig,
|
|
&modelRetryPolicy,
|
|
&modelRateLimitPolicy,
|
|
&item.RuntimePolicySetID,
|
|
&runtimePolicyOverride,
|
|
&runtimeRetryPolicy,
|
|
&runtimeRateLimitPolicy,
|
|
&autoDisablePolicy,
|
|
°radePolicy,
|
|
&concurrentActive,
|
|
&queuedWaiting,
|
|
&rpmUsed,
|
|
&rpmReserved,
|
|
&tpmUsed,
|
|
&tpmReserved,
|
|
&stateRunningCount,
|
|
&stateWaitingCount,
|
|
&stateLimiterRatio,
|
|
&lastAssignedUnix,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
item.Credentials = decodeObject(credentials)
|
|
item.PlatformConfig = decodeObject(platformConfig)
|
|
item.PlatformRetryPolicy = decodeObject(platformRetryPolicy)
|
|
item.PlatformRateLimitPolicy = decodeObject(platformRateLimitPolicy)
|
|
item.Capabilities = decodeObject(capabilities)
|
|
item.CapabilityOverride = decodeObject(capabilityOverride)
|
|
item.BaseBillingConfig = decodeObject(baseBilling)
|
|
item.BillingConfig = decodeObject(billing)
|
|
item.BillingConfigOverride = decodeObject(billingOverride)
|
|
item.PermissionConfig = decodeObject(permissionConfig)
|
|
item.ModelRetryPolicy = decodeObject(modelRetryPolicy)
|
|
item.ModelRateLimitPolicy = decodeObject(modelRateLimitPolicy)
|
|
item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride)
|
|
item.RuntimeRetryPolicy = decodeObject(runtimeRetryPolicy)
|
|
item.RuntimeRateLimitPolicy = decodeObject(runtimeRateLimitPolicy)
|
|
item.AutoDisablePolicy = decodeObject(autoDisablePolicy)
|
|
item.DegradePolicy = decodeObject(degradePolicy)
|
|
upstreamModelName := firstNonEmpty(item.ProviderModelName, item.ModelName)
|
|
item.ClientID = fmt.Sprintf("%s:%s:%s", item.PlatformKey, item.ModelType, upstreamModelName)
|
|
item.QueueKey = item.ClientID
|
|
item.RunningCount = stateRunningCount
|
|
item.WaitingCount = maxFloat(queuedWaiting, stateWaitingCount)
|
|
item.LastAssignedUnix = lastAssignedUnix
|
|
applyRuntimeCandidateLoad(&item, runtimeCandidateLoadInput{
|
|
Policy: effectiveModelRateLimitPolicy(item.PlatformRateLimitPolicy, item.RuntimeRateLimitPolicy, item.RuntimePolicyOverride, item.ModelRateLimitPolicy),
|
|
ConcurrentActive: concurrentActive,
|
|
QueuedWaiting: queuedWaiting,
|
|
RPMUsed: rpmUsed,
|
|
RPMReserved: rpmReserved,
|
|
TPMUsed: tpmUsed,
|
|
TPMReserved: tpmReserved,
|
|
StateRunningCount: stateRunningCount,
|
|
StateWaitingCount: stateWaitingCount,
|
|
StateLimiterRatio: stateLimiterRatio,
|
|
})
|
|
items = append(items, item)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
if len(items) == 0 {
|
|
if unavailableErr, err := s.modelCandidateCooldownError(ctx, model, modelType); err != nil {
|
|
return nil, err
|
|
} else if unavailableErr != nil {
|
|
return nil, unavailableErr
|
|
}
|
|
return nil, ErrNoModelCandidate
|
|
}
|
|
items, err = s.filterCandidatesByAccessRules(ctx, user, items)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(items) == 0 {
|
|
return nil, ErrNoModelCandidate
|
|
}
|
|
sortRuntimeModelCandidates(items)
|
|
return items, nil
|
|
}
|
|
|
|
// runtimeCandidateLoadInput bundles the effective rate-limit policy and the
// live counters that applyRuntimeCandidateLoad turns into load metrics.
type runtimeCandidateLoadInput struct {
	// Policy is the effective rate-limit policy the rpm/tpm/concurrent limits
	// are read from.
	Policy map[string]any

	// ConcurrentActive is the summed lease value of live concurrency leases;
	// QueuedWaiting is the number of queued async tasks routed to the model.
	ConcurrentActive, QueuedWaiting float64

	// Requests-per-minute counter: used and reserved amounts.
	RPMUsed, RPMReserved float64

	// Tokens-per-minute counters: used and reserved amounts.
	TPMUsed, TPMReserved float64

	// Snapshot values taken from the runtime client state.
	StateRunningCount, StateWaitingCount, StateLimiterRatio float64
}
|
|
|
|
func applyRuntimeCandidateLoad(candidate *RuntimeModelCandidate, input runtimeCandidateLoadInput) {
|
|
rpmLimit := rateLimitForMetric(input.Policy, "rpm")
|
|
tpmLimitValue := tpmLimit(input.Policy)
|
|
concurrentLimit := rateLimitForMetric(input.Policy, "concurrent")
|
|
rpmCurrent := input.RPMUsed + input.RPMReserved
|
|
tpmCurrent := input.TPMUsed + input.TPMReserved
|
|
concurrentCurrent := input.ConcurrentActive + input.QueuedWaiting
|
|
metrics := RuntimeCandidateLoadMetrics{
|
|
RPMCurrent: rpmCurrent,
|
|
RPMLimit: rpmLimit,
|
|
RPMRatio: ratioIfLimited(rpmCurrent, rpmLimit),
|
|
TPMCurrent: tpmCurrent,
|
|
TPMLimit: tpmLimitValue,
|
|
TPMRatio: ratioIfLimited(tpmCurrent, tpmLimitValue),
|
|
ConcurrentCurrent: concurrentCurrent,
|
|
ConcurrentLimit: concurrentLimit,
|
|
ConcurrentRatio: ratioIfLimited(concurrentCurrent, concurrentLimit),
|
|
QueuedCount: input.QueuedWaiting,
|
|
StateRunningCount: input.StateRunningCount,
|
|
StateWaitingCount: input.StateWaitingCount,
|
|
StateLimiterRatio: input.StateLimiterRatio,
|
|
}
|
|
candidate.LoadMetrics = metrics
|
|
candidate.LoadLimited = rpmLimit > 0 || tpmLimitValue > 0 || concurrentLimit > 0
|
|
candidate.LoadRatio = maxFloat(metrics.RPMRatio, metrics.TPMRatio, metrics.ConcurrentRatio)
|
|
}
|
|
|
|
// ratioIfLimited returns current divided by limit. A non-positive limit is
// treated as "no limit", and the ratio is reported as 0.
func ratioIfLimited(current float64, limit float64) float64 {
	switch {
	case limit <= 0:
		return 0
	default:
		return current / limit
	}
}
|
|
|
|
func sortRuntimeModelCandidates(items []RuntimeModelCandidate) {
|
|
hasFull := false
|
|
hasNonFull := false
|
|
for index := range items {
|
|
items[index].LoadAvoided = false
|
|
if runtimeCandidateFull(items[index]) {
|
|
hasFull = true
|
|
} else {
|
|
hasNonFull = true
|
|
}
|
|
}
|
|
if hasFull && hasNonFull {
|
|
for index := range items {
|
|
items[index].LoadAvoided = runtimeCandidateFull(items[index])
|
|
}
|
|
}
|
|
sort.SliceStable(items, func(i, j int) bool {
|
|
aFull := runtimeCandidateFull(items[i])
|
|
bFull := runtimeCandidateFull(items[j])
|
|
if aFull != bFull {
|
|
return !aFull
|
|
}
|
|
if items[i].PlatformPriority != items[j].PlatformPriority {
|
|
return items[i].PlatformPriority < items[j].PlatformPriority
|
|
}
|
|
if items[i].LoadRatio != items[j].LoadRatio {
|
|
return items[i].LoadRatio < items[j].LoadRatio
|
|
}
|
|
if items[i].RunningCount != items[j].RunningCount {
|
|
return items[i].RunningCount < items[j].RunningCount
|
|
}
|
|
if items[i].WaitingCount != items[j].WaitingCount {
|
|
return items[i].WaitingCount < items[j].WaitingCount
|
|
}
|
|
if items[i].LastAssignedUnix != items[j].LastAssignedUnix {
|
|
return items[i].LastAssignedUnix < items[j].LastAssignedUnix
|
|
}
|
|
return false
|
|
})
|
|
}
|
|
|
|
func runtimeCandidateFull(candidate RuntimeModelCandidate) bool {
|
|
return candidate.LoadLimited && candidate.LoadRatio >= 1
|
|
}
|
|
|
|
func (s *Store) modelCandidateCooldownError(ctx context.Context, model string, modelType string) (error, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT p.name,
|
|
COALESCE(NULLIF(m.display_name, ''), NULLIF(m.model_alias, ''), m.model_name),
|
|
COALESCE(to_char(p.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
|
|
GREATEST(COALESCE(EXTRACT(EPOCH FROM p.cooldown_until - now()), 0), 0)::float8,
|
|
COALESCE(to_char(m.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
|
|
GREATEST(COALESCE(EXTRACT(EPOCH FROM m.cooldown_until - now()), 0), 0)::float8
|
|
FROM platform_models m
|
|
JOIN integration_platforms p ON p.id = m.platform_id
|
|
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
|
|
WHERE p.status = 'enabled'
|
|
AND p.deleted_at IS NULL
|
|
AND m.enabled = true
|
|
AND m.model_type @> jsonb_build_array($2::text)
|
|
AND (
|
|
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
|
|
OR (
|
|
COALESCE(m.model_alias, '') = ''
|
|
AND (
|
|
m.model_name = $1::text
|
|
OR b.canonical_model_key = $1::text
|
|
OR b.provider_model_name = $1::text
|
|
)
|
|
)
|
|
)
|
|
ORDER BY GREATEST(COALESCE(p.cooldown_until, to_timestamp(0)), COALESCE(m.cooldown_until, to_timestamp(0))) DESC,
|
|
p.priority ASC,
|
|
m.created_at ASC`, model, modelType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var platformName string
|
|
var displayName string
|
|
var platformCooldownUntil string
|
|
var platformRemainingSeconds float64
|
|
var modelCooldownUntil string
|
|
var modelRemainingSeconds float64
|
|
if err := rows.Scan(
|
|
&platformName,
|
|
&displayName,
|
|
&platformCooldownUntil,
|
|
&platformRemainingSeconds,
|
|
&modelCooldownUntil,
|
|
&modelRemainingSeconds,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
if modelRemainingSeconds > 0 {
|
|
return &ModelCandidateUnavailableError{
|
|
Code: "model_cooling_down",
|
|
Message: cooldownErrorMessage("模型", displayName, modelRemainingSeconds, modelCooldownUntil),
|
|
}, nil
|
|
}
|
|
if platformRemainingSeconds > 0 {
|
|
return &ModelCandidateUnavailableError{
|
|
Code: "platform_cooling_down",
|
|
Message: cooldownErrorMessage("平台", platformName, platformRemainingSeconds, platformCooldownUntil),
|
|
}, nil
|
|
}
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// cooldownErrorMessage formats the (Chinese) cooldown notice for a scope
// label such as a platform or model.
//
// A blank name is replaced by "候选". The remaining time is shown in minutes
// with one decimal, with a floor of 0.1. The estimated recovery time is
// appended only when cooldownUntil is not blank.
func cooldownErrorMessage(scope string, name string, remainingSeconds float64, cooldownUntil string) string {
	label := strings.TrimSpace(name)
	if label == "" {
		label = "候选"
	}

	minutes := remainingSeconds / 60
	if minutes < 0.1 {
		minutes = 0.1
	}

	var b strings.Builder
	fmt.Fprintf(&b, "%s %s 冷却中,剩余 %.1f 分钟", scope, label, minutes)
	if strings.TrimSpace(cooldownUntil) != "" {
		b.WriteString(",预计恢复时间 ")
		b.WriteString(cooldownUntil)
	}
	return b.String()
}
|