easyai-ai-gateway/apps/api/internal/store/candidates.go

258 lines
9.1 KiB
Go

package store
import (
"context"
"fmt"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
)
// ListModelCandidates returns the runtime candidates able to serve the
// requested model for the given model type.
//
// The query reads platform_models joined to integration_platforms and keeps
// rows where the platform is enabled and not deleted, the model is enabled,
// its model_type array contains modelType, and neither the platform nor the
// model has a cooldown_until in the future. A row matches model through its
// alias when an alias is set; otherwise through model_name or through the
// base catalog's canonical_model_key / provider_model_name. Rows are ordered
// by effective priority (dynamic_priority, falling back to priority), then by
// the runtime_client_states columns (limiter ratio, running count, waiting
// count, last assignment time), and finally by model creation time.
//
// When the query yields no rows, modelCandidateCooldownError is consulted and
// its unavailable-error is returned if it produces one; otherwise
// ErrNoModelCandidate is returned. Rows that were found are then passed, along
// with user, through filterCandidatesByAccessRules, and ErrNoModelCandidate is
// returned when nothing survives that filter.
func (s *Store) ListModelCandidates(ctx context.Context, model string, modelType string, user *auth.User) ([]RuntimeModelCandidate, error) {
// $1 is the requested model identifier, $2 the requested model type.
rows, err := s.pool.Query(ctx, `
SELECT p.id::text, p.platform_key, p.name, p.provider,
COALESCE(NULLIF(p.config->>'specType', ''), NULLIF(cp.provider_type, ''), NULLIF(p.config->>'sourceSpecType', ''), p.provider) AS spec_type,
COALESCE(p.base_url, ''),
p.auth_type, p.credentials, p.config, p.default_pricing_mode,
p.default_discount_factor::float8, COALESCE(p.pricing_rule_set_id::text, ''),
p.retry_policy, p.rate_limit_policy,
COALESCE(p.dynamic_priority, p.priority) AS effective_priority,
m.id::text, COALESCE(m.base_model_id::text, ''), COALESCE(b.canonical_model_key, ''),
COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), m.model_name, COALESCE(m.model_alias, ''),
$2::text AS requested_model_type, m.display_name, m.capabilities, m.capability_override,
COALESCE(b.base_billing_config, '{}'::jsonb), m.billing_config, m.billing_config_override,
m.pricing_mode, COALESCE(m.discount_factor, 0)::float8, COALESCE(m.pricing_rule_set_id::text, ''),
COALESCE(b.pricing_rule_set_id::text, ''),
m.permission_config, m.retry_policy, m.rate_limit_policy, COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')),
COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb),
COALESCE(rp.retry_policy, '{}'::jsonb), COALESCE(rp.rate_limit_policy, '{}'::jsonb),
COALESCE(rp.auto_disable_policy, '{}'::jsonb), COALESCE(rp.degrade_policy, '{}'::jsonb)
FROM platform_models m
JOIN integration_platforms p ON p.id = m.platform_id
LEFT JOIN model_catalog_providers cp ON cp.provider_key = p.provider OR cp.provider_code = p.provider
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id)
LEFT JOIN runtime_client_states s
ON s.client_id = p.platform_key || ':' || $2::text || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name)
WHERE p.status = 'enabled'
AND p.deleted_at IS NULL
AND m.enabled = true
AND m.model_type @> jsonb_build_array($2::text)
AND (p.cooldown_until IS NULL OR p.cooldown_until <= now())
AND (m.cooldown_until IS NULL OR m.cooldown_until <= now())
AND (
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
OR (
COALESCE(m.model_alias, '') = ''
AND (
m.model_name = $1::text
OR b.canonical_model_key = $1::text
OR b.provider_model_name = $1::text
)
)
)
ORDER BY effective_priority ASC,
COALESCE(s.limiter_ratio, 0) ASC,
COALESCE(s.running_count, 0) ASC,
COALESCE(s.waiting_count, 0) ASC,
COALESCE(s.last_assigned_at, to_timestamp(0)) ASC,
m.created_at ASC`, model, modelType)
if err != nil {
return nil, err
}
defer rows.Close()
items := make([]RuntimeModelCandidate, 0)
for rows.Next() {
var item RuntimeModelCandidate
// Columns that are post-processed with decodeObject are first read as raw
// bytes into the buffers below.
var credentials []byte
var platformConfig []byte
var platformRetryPolicy []byte
var platformRateLimitPolicy []byte
var capabilities []byte
var capabilityOverride []byte
var baseBilling []byte
var billing []byte
var billingOverride []byte
var permissionConfig []byte
var modelRetryPolicy []byte
var modelRateLimitPolicy []byte
var runtimePolicyOverride []byte
var runtimeRetryPolicy []byte
var runtimeRateLimitPolicy []byte
var autoDisablePolicy []byte
var degradePolicy []byte
// Scan is positional: the destinations below must stay in exactly the same
// order as the SELECT list of the query above.
if err := rows.Scan(
&item.PlatformID,
&item.PlatformKey,
&item.PlatformName,
&item.Provider,
&item.SpecType,
&item.BaseURL,
&item.AuthType,
&credentials,
&platformConfig,
&item.DefaultPricingMode,
&item.DefaultDiscountFactor,
&item.PlatformPricingRuleSetID,
&platformRetryPolicy,
&platformRateLimitPolicy,
// effective_priority: dynamic_priority when set, otherwise priority.
&item.PlatformPriority,
&item.PlatformModelID,
&item.BaseModelID,
&item.CanonicalModelKey,
// provider_model_name, falling back to model_name when empty.
&item.ProviderModelName,
&item.ModelName,
&item.ModelAlias,
// requested_model_type: the modelType argument echoed back by the query.
&item.ModelType,
&item.DisplayName,
&capabilities,
&capabilityOverride,
&baseBilling,
&billing,
&billingOverride,
&item.PricingMode,
&item.DiscountFactor,
&item.ModelPricingRuleSetID,
&item.BasePricingRuleSetID,
&permissionConfig,
&modelRetryPolicy,
&modelRateLimitPolicy,
// Model-level runtime policy set id, else the base model's, else "".
&item.RuntimePolicySetID,
&runtimePolicyOverride,
&runtimeRetryPolicy,
&runtimeRateLimitPolicy,
&autoDisablePolicy,
&degradePolicy,
); err != nil {
return nil, err
}
// Convert each raw column buffer with decodeObject.
item.Credentials = decodeObject(credentials)
item.PlatformConfig = decodeObject(platformConfig)
item.PlatformRetryPolicy = decodeObject(platformRetryPolicy)
item.PlatformRateLimitPolicy = decodeObject(platformRateLimitPolicy)
item.Capabilities = decodeObject(capabilities)
item.CapabilityOverride = decodeObject(capabilityOverride)
item.BaseBillingConfig = decodeObject(baseBilling)
item.BillingConfig = decodeObject(billing)
item.BillingConfigOverride = decodeObject(billingOverride)
item.PermissionConfig = decodeObject(permissionConfig)
item.ModelRetryPolicy = decodeObject(modelRetryPolicy)
item.ModelRateLimitPolicy = decodeObject(modelRateLimitPolicy)
item.RuntimePolicyOverride = decodeObject(runtimePolicyOverride)
item.RuntimeRetryPolicy = decodeObject(runtimeRetryPolicy)
item.RuntimeRateLimitPolicy = decodeObject(runtimeRateLimitPolicy)
item.AutoDisablePolicy = decodeObject(autoDisablePolicy)
item.DegradePolicy = decodeObject(degradePolicy)
// ClientID is built as platform_key:model_type:model, the same layout as the
// client_id used to join runtime_client_states in the query.
// NOTE(review): firstNonEmpty presumably returns its first non-empty
// argument — confirm against its definition.
upstreamModelName := firstNonEmpty(item.ProviderModelName, item.ModelName)
item.ClientID = fmt.Sprintf("%s:%s:%s", item.PlatformKey, item.ModelType, upstreamModelName)
item.QueueKey = item.ClientID
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, err
}
if len(items) == 0 {
// Nothing selectable: prefer a cooldown-specific error over the generic
// sentinel when modelCandidateCooldownError produces one.
if unavailableErr, err := s.modelCandidateCooldownError(ctx, model, modelType); err != nil {
return nil, err
} else if unavailableErr != nil {
return nil, unavailableErr
}
return nil, ErrNoModelCandidate
}
// Access filtering runs only after the database lookup found candidates.
items, err = s.filterCandidatesByAccessRules(ctx, user, items)
if err != nil {
return nil, err
}
if len(items) == 0 {
return nil, ErrNoModelCandidate
}
return items, nil
}
// modelCandidateCooldownError looks for a cooldown that explains why no
// candidate is currently selectable for model / modelType.
//
// It queries enabled, non-deleted platforms and enabled models whose
// model_type contains modelType and which match model by alias, or (without
// an alias) by model_name or the base catalog's canonical_model_key /
// provider_model_name. Cooldown columns are not filtered; rows are ordered by
// the later of the platform and model cooldown_until, latest first.
//
// The first result is the error to surface to the caller: a
// *ModelCandidateUnavailableError with code "model_cooling_down" or
// "platform_cooling_down" for the first row that still has remaining cooldown
// time (the model cooldown is checked before the platform one), or nil when no
// row is cooling down. The second result reports a query, scan or iteration
// failure.
func (s *Store) modelCandidateCooldownError(ctx context.Context, model string, modelType string) (error, error) {
	// Remaining seconds are clamped to zero or above by the query itself.
	const cooldownQuery = `
SELECT p.name,
COALESCE(NULLIF(m.display_name, ''), NULLIF(m.model_alias, ''), m.model_name),
COALESCE(to_char(p.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
GREATEST(COALESCE(EXTRACT(EPOCH FROM p.cooldown_until - now()), 0), 0)::float8,
COALESCE(to_char(m.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
GREATEST(COALESCE(EXTRACT(EPOCH FROM m.cooldown_until - now()), 0), 0)::float8
FROM platform_models m
JOIN integration_platforms p ON p.id = m.platform_id
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
WHERE p.status = 'enabled'
AND p.deleted_at IS NULL
AND m.enabled = true
AND m.model_type @> jsonb_build_array($2::text)
AND (
(COALESCE(m.model_alias, '') <> '' AND m.model_alias = $1::text)
OR (
COALESCE(m.model_alias, '') = ''
AND (
m.model_name = $1::text
OR b.canonical_model_key = $1::text
OR b.provider_model_name = $1::text
)
)
)
ORDER BY GREATEST(COALESCE(p.cooldown_until, to_timestamp(0)), COALESCE(m.cooldown_until, to_timestamp(0))) DESC,
p.priority ASC,
m.created_at ASC`

	cursor, queryErr := s.pool.Query(ctx, cooldownQuery, model, modelType)
	if queryErr != nil {
		return nil, queryErr
	}
	defer cursor.Close()

	for cursor.Next() {
		var (
			platformLabel string
			modelLabel    string
			platformUntil string
			platformLeft  float64
			modelUntil    string
			modelLeft     float64
		)
		// Destinations follow the SELECT list order of cooldownQuery.
		scanErr := cursor.Scan(
			&platformLabel,
			&modelLabel,
			&platformUntil,
			&platformLeft,
			&modelUntil,
			&modelLeft,
		)
		if scanErr != nil {
			return nil, scanErr
		}

		switch {
		case modelLeft > 0:
			return &ModelCandidateUnavailableError{
				Code:    "model_cooling_down",
				Message: cooldownErrorMessage("模型", modelLabel, modelLeft, modelUntil),
			}, nil
		case platformLeft > 0:
			return &ModelCandidateUnavailableError{
				Code:    "platform_cooling_down",
				Message: cooldownErrorMessage("平台", platformLabel, platformLeft, platformUntil),
			}, nil
		}
	}

	// No cooling row was found; report an iteration failure if there was one.
	return nil, cursor.Err()
}
// cooldownErrorMessage renders the user-facing cooldown notice for a model or
// platform.
//
// scope is printed verbatim as the leading label. name is trimmed and replaced
// by a generic placeholder when blank. remainingSeconds is shown in minutes
// with one decimal and never below 0.1. cooldownUntil is appended as the
// expected recovery time unless it is blank after trimming; when appended it
// is written as given, without trimming.
func cooldownErrorMessage(scope string, name string, remainingSeconds float64, cooldownUntil string) string {
	label := strings.TrimSpace(name)
	if label == "" {
		label = "候选"
	}

	// Floor the displayed value so a nearly expired cooldown never reads 0.0.
	minutes := remainingSeconds / 60
	if minutes < 0.1 {
		minutes = 0.1
	}

	var out strings.Builder
	fmt.Fprintf(&out, "%s %s 冷却中,剩余 %.1f 分钟", scope, label, minutes)
	if strings.TrimSpace(cooldownUntil) == "" {
		return out.String()
	}
	out.WriteString(",预计恢复时间 ")
	out.WriteString(cooldownUntil)
	return out.String()
}