easyai-ai-gateway/apps/api/internal/store/rate_limit_status.go

544 lines
19 KiB
Go

package store
import (
"context"
"database/sql"
"sort"
"strconv"
"strings"
"time"
)
// RateLimitMetricStatus is the JSON snapshot of a single rate-limit metric
// (concurrent, rpm or tpm) for one platform model.
type RateLimitMetricStatus struct {
// CurrentValue is the figure compared against LimitValue. For rpm and tpm it
// is used + reserved; for concurrency it equals UsedValue.
CurrentValue float64 `json:"currentValue"`
// UsedValue is the already consumed part of the metric.
UsedValue float64 `json:"usedValue"`
// ReservedValue is the reserved part of the metric (always 0 for concurrency).
ReservedValue float64 `json:"reservedValue"`
// LimitValue is the limit taken from the effective policy; 0 when no rule applies.
LimitValue float64 `json:"limitValue"`
// Limited is true when LimitValue is greater than zero.
Limited bool `json:"limited"`
// Ratio is CurrentValue / LimitValue when Limited, otherwise 0.
Ratio float64 `json:"ratio"`
// ResetAt is the counter reset time as text; omitted from JSON when empty.
ResetAt string `json:"resetAt,omitempty"`
}
// ModelRateLimitStatus describes one platform model together with its
// platform, its effective rate limit policy and its current load. It is the
// element type returned by ListModelRateLimitStatuses.
type ModelRateLimitStatus struct {
PlatformModelID string `json:"platformModelId"`
PlatformID string `json:"platformId"`
PlatformName string `json:"platformName"`
Provider string `json:"provider"`
PlatformStatus string `json:"platformStatus"`
// PlatformDisabledReason is only populated when PlatformStatus is not "enabled".
PlatformDisabledReason *PlatformPolicyEvent `json:"platformDisabledReason,omitempty"`
PlatformPriority int `json:"platformPriority"`
// PlatformDynamicPriority mirrors the nullable dynamic_priority column
// (presumably nil when the column is NULL — see intPointerFromNull).
PlatformDynamicPriority *int `json:"platformDynamicPriority,omitempty"`
// PlatformEffectivePriority is the dynamic priority when set, otherwise the static priority.
PlatformEffectivePriority int `json:"platformEffectivePriority"`
ModelName string `json:"modelName"`
// ProviderModelName falls back to the model name when the provider-specific name is empty.
ProviderModelName string `json:"providerModelName,omitempty"`
ModelAlias string `json:"modelAlias,omitempty"`
DisplayName string `json:"displayName"`
ModelType []string `json:"modelType"`
Enabled bool `json:"enabled"`
// RateLimitPolicy is the merged effective policy; nil (and omitted) when it has no rules.
RateLimitPolicy map[string]any `json:"rateLimitPolicy,omitempty"`
// PlatformCooldownUntil and ModelCooldownUntil are UTC timestamps formatted as
// YYYY-MM-DDTHH:MM:SS.mmmZ, or empty when no cooldown is set.
PlatformCooldownUntil string `json:"platformCooldownUntil,omitempty"`
ModelCooldownUntil string `json:"modelCooldownUntil,omitempty"`
// Concurrent reports the summed value of unreleased, unexpired concurrency leases.
Concurrent RateLimitMetricStatus `json:"concurrent"`
// QueuedTasks is the number of distinct queued async tasks attributed to this model.
QueuedTasks float64 `json:"queuedTasks"`
RPM RateLimitMetricStatus `json:"rpm"`
TPM RateLimitMetricStatus `json:"tpm"`
// LoadRatio is the largest of the Concurrent, RPM and TPM ratios (never below 0).
LoadRatio float64 `json:"loadRatio"`
// RecentPriorityDemotions holds up to 10 of the platform's most recent
// priority demotions; the same list is attached to every model of a platform.
RecentPriorityDemotions []PriorityDemotionRecord `json:"recentPriorityDemotions,omitempty"`
}
// PriorityDemotionRecord is the JSON view of one
// 'task.policy.priority_demoted' task event. Apart from ID, TaskID and
// CreatedAt, every field is read from the event's JSON payload.
type PriorityDemotionRecord struct {
ID string `json:"id"`
TaskID string `json:"taskId"`
PlatformID string `json:"platformId"`
PlatformModelID string `json:"platformModelId,omitempty"`
Reason string `json:"reason,omitempty"`
// ErrorCode comes from payload "errorCode", falling back to payload "code".
ErrorCode string `json:"errorCode,omitempty"`
// ErrorMessage comes from payload "errorMessage", then payload "message",
// then the event's own message column.
ErrorMessage string `json:"errorMessage,omitempty"`
Category string `json:"category,omitempty"`
StatusCode int `json:"statusCode,omitempty"`
PolicySource string `json:"policySource,omitempty"`
Policy string `json:"policy,omitempty"`
PolicyRule string `json:"policyRule,omitempty"`
MatchedValue string `json:"matchedValue,omitempty"`
DynamicPriority int `json:"dynamicPriority,omitempty"`
CreatedAt time.Time `json:"createdAt"`
}
// PlatformPolicyEvent is the JSON view of a policy-driven task event that
// disabled a platform ('task.policy.failover_disabled' or
// 'task.policy.auto_disabled'). Apart from ID, TaskID, EventType and
// CreatedAt, every field is read from the event's JSON payload.
type PlatformPolicyEvent struct {
ID string `json:"id"`
TaskID string `json:"taskId"`
PlatformID string `json:"platformId"`
PlatformModelID string `json:"platformModelId,omitempty"`
EventType string `json:"eventType"`
Reason string `json:"reason,omitempty"`
// ErrorCode comes from payload "errorCode", falling back to payload "code".
ErrorCode string `json:"errorCode,omitempty"`
// ErrorMessage comes from payload "errorMessage", then payload "message",
// then the matching task attempt's error message, then the event message.
ErrorMessage string `json:"errorMessage,omitempty"`
Category string `json:"category,omitempty"`
StatusCode int `json:"statusCode,omitempty"`
PolicySource string `json:"policySource,omitempty"`
Policy string `json:"policy,omitempty"`
PolicyRule string `json:"policyRule,omitempty"`
MatchedValue string `json:"matchedValue,omitempty"`
CreatedAt time.Time `json:"createdAt"`
}
// ListModelRateLimitStatuses returns one status entry for every platform model
// whose platform is not soft-deleted (deleted_at IS NULL).
//
// Each entry carries the platform/model identity, the effective rate limit
// policy, cooldown timestamps, live concurrency, queued async tasks and RPM/TPM
// counters, plus the platform's recent priority demotions and, for platforms
// whose status is not "enabled", the latest disabling policy event.
//
// The result is ordered by LoadRatio descending; entries with equal LoadRatio
// are ordered by case-insensitive DisplayName. Any query or scan error is
// returned as-is with a nil slice.
func (s *Store) ListModelRateLimitStatuses(ctx context.Context) ([]ModelRateLimitStatus, error) {
// One row per platform model; the LEFT JOINed subqueries contribute the
// active lease total, the queued task count and the current rpm/tpm counters.
rows, err := s.pool.Query(ctx, `
SELECT m.id::text, m.platform_id::text, p.name, p.provider, p.status,
p.priority, p.dynamic_priority, COALESCE(p.dynamic_priority, p.priority),
m.model_name, COALESCE(NULLIF(m.provider_model_name, ''), m.model_name), COALESCE(m.model_alias, ''),
m.model_type, m.display_name, m.enabled,
p.rate_limit_policy, COALESCE(rp.rate_limit_policy, '{}'::jsonb), COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb), m.rate_limit_policy,
COALESCE(to_char(p.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
COALESCE(to_char(m.cooldown_until AT TIME ZONE 'UTC', 'YYYY-MM-DD"T"HH24:MI:SS.MS"Z"'), ''),
COALESCE(con.active, 0)::float8,
COALESCE(queued.waiting, 0)::float8,
COALESCE(rpm.used_value, 0)::float8, COALESCE(rpm.reserved_value, 0)::float8, COALESCE(rpm.reset_at::text, ''),
COALESCE(tpm.used_value, 0)::float8, COALESCE(tpm.reserved_value, 0)::float8, COALESCE(tpm.reset_at::text, '')
FROM platform_models m
JOIN integration_platforms p ON p.id = m.platform_id
LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id)
LEFT JOIN (
SELECT scope_key, SUM(lease_value) AS active
FROM gateway_concurrency_leases
WHERE scope_type = 'platform_model'
AND released_at IS NULL
AND expires_at > now()
GROUP BY scope_key
) con ON con.scope_key = m.id::text
LEFT JOIN (
SELECT queued_sources.platform_model_id, COUNT(DISTINCT queued_sources.task_id) AS waiting
FROM (
SELECT t.id::text AS task_id, qm.id::text AS platform_model_id
FROM gateway_tasks t
JOIN integration_platforms qp ON TRUE
JOIN platform_models qm ON qm.platform_id = qp.id
WHERE t.async_mode = true
AND t.status = 'queued'
AND NULLIF(t.model_type, '') IS NOT NULL
AND qm.model_type @> jsonb_build_array(t.model_type)
AND t.queue_key = qp.platform_key || ':' || t.model_type || ':' || COALESCE(NULLIF(qm.provider_model_name, ''), qm.model_name)
AND NOT EXISTS (SELECT 1 FROM gateway_task_attempts existing_attempt WHERE existing_attempt.task_id = t.id)
UNION ALL
SELECT latest.task_id::text AS task_id, latest.platform_model_id
FROM (
SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id
FROM gateway_tasks t
JOIN gateway_task_attempts a ON a.task_id = t.id
WHERE t.async_mode = true
AND t.status = 'queued'
AND a.platform_model_id IS NOT NULL
ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC
) latest
) queued_sources
GROUP BY queued_sources.platform_model_id
) queued ON queued.platform_model_id = m.id::text
LEFT JOIN (
SELECT DISTINCT ON (scope_key) scope_key, used_value, reserved_value, reset_at
FROM gateway_rate_limit_counters
WHERE scope_type = 'platform_model'
AND metric = 'rpm'
AND reset_at > now()
ORDER BY scope_key, window_start DESC
) rpm ON rpm.scope_key = m.id::text
LEFT JOIN (
SELECT scope_key, SUM(used_value) AS used_value, SUM(reserved_value) AS reserved_value, MAX(reset_at) AS reset_at
FROM gateway_rate_limit_counters
WHERE scope_type = 'platform_model'
AND metric LIKE 'tpm%'
AND reset_at > now()
GROUP BY scope_key
) tpm ON tpm.scope_key = m.id::text
WHERE p.deleted_at IS NULL
ORDER BY p.priority ASC, m.model_name ASC`)
if err != nil {
return nil, err
}
defer rows.Close()
// Start from an empty (non-nil) slice so the caller always gets a slice value.
items := make([]ModelRateLimitStatus, 0)
for rows.Next() {
var item ModelRateLimitStatus
var modelTypeBytes []byte
var platformPolicyBytes []byte
var runtimePolicyBytes []byte
var runtimeOverrideBytes []byte
var modelPolicyBytes []byte
var platformDynamicPriority sql.NullInt64
var platformCooldownUntil string
var modelCooldownUntil string
var concurrentCurrent float64
var queuedTasks float64
var rpmUsed float64
var rpmReserved float64
var rpmResetAt string
var tpmUsed float64
var tpmReserved float64
var tpmResetAt string
// Destination order must match the SELECT column order above.
if err := rows.Scan(
&item.PlatformModelID,
&item.PlatformID,
&item.PlatformName,
&item.Provider,
&item.PlatformStatus,
&item.PlatformPriority,
&platformDynamicPriority,
&item.PlatformEffectivePriority,
&item.ModelName,
&item.ProviderModelName,
&item.ModelAlias,
&modelTypeBytes,
&item.DisplayName,
&item.Enabled,
&platformPolicyBytes,
&runtimePolicyBytes,
&runtimeOverrideBytes,
&modelPolicyBytes,
&platformCooldownUntil,
&modelCooldownUntil,
&concurrentCurrent,
&queuedTasks,
&rpmUsed,
&rpmReserved,
&rpmResetAt,
&tpmUsed,
&tpmReserved,
&tpmResetAt,
); err != nil {
return nil, err
}
item.PlatformDynamicPriority = intPointerFromNull(platformDynamicPriority)
item.ModelType = decodeStringArray(modelTypeBytes)
// Layer the platform policy, runtime policy set, runtime override and model
// policy into the single effective policy used for the limits below.
policy := effectiveModelRateLimitPolicy(
decodeObject(platformPolicyBytes),
decodeObject(runtimePolicyBytes),
decodeObject(runtimeOverrideBytes),
decodeObject(modelPolicyBytes),
)
item.PlatformCooldownUntil = platformCooldownUntil
item.ModelCooldownUntil = modelCooldownUntil
item.RateLimitPolicy = policy
item.QueuedTasks = queuedTasks
// Concurrency is reported with used == current, zero reserved and no reset time.
item.Concurrent = metricStatus(concurrentCurrent, concurrentCurrent, 0, rateLimitForMetric(policy, "concurrent"), "")
// For rpm and tpm the current value is used + reserved.
item.RPM = metricStatus(rpmUsed+rpmReserved, rpmUsed, rpmReserved, rateLimitForMetric(policy, "rpm"), rpmResetAt)
item.TPM = metricStatus(tpmUsed+tpmReserved, tpmUsed, tpmReserved, tpmLimit(policy), tpmResetAt)
// LoadRatio is the largest of the three metric ratios.
item.LoadRatio = maxFloat(item.Concurrent.Ratio, item.RPM.Ratio, item.TPM.Ratio)
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, err
}
// Load up to 10 recent priority demotions per platform in a single query.
demotions, err := s.listRecentPriorityDemotionsByPlatform(ctx, items, 10)
if err != nil {
return nil, err
}
// Load the latest disabling event for every platform that is not "enabled".
disabledReasons, err := s.listLatestPlatformDisabledReasons(ctx, items)
if err != nil {
return nil, err
}
for index := range items {
items[index].RecentPriorityDemotions = demotions[items[index].PlatformID]
// The disable reason is only attached when the platform is not "enabled".
if items[index].PlatformStatus != "enabled" {
items[index].PlatformDisabledReason = disabledReasons[items[index].PlatformID]
}
}
// Most loaded models first; equal load ratios fall back to case-insensitive
// display name, and the stable sort keeps the SQL order for remaining ties.
sort.SliceStable(items, func(i, j int) bool {
if items[i].LoadRatio == items[j].LoadRatio {
return strings.ToLower(items[i].DisplayName) < strings.ToLower(items[j].DisplayName)
}
return items[i].LoadRatio > items[j].LoadRatio
})
return items, nil
}
// listRecentPriorityDemotionsByPlatform loads, for every distinct platform
// referenced by statuses, at most limit of its most recent
// 'task.policy.priority_demoted' events and groups them by platform ID.
//
// Within one platform the records are ordered newest first. A non-positive
// limit, an empty statuses slice, or statuses without any non-blank platform ID
// yield an empty map without touching the database. Events whose payload has no
// platformId are skipped.
func (s *Store) listRecentPriorityDemotionsByPlatform(ctx context.Context, statuses []ModelRateLimitStatus, limit int) (map[string][]PriorityDemotionRecord, error) {
	const demotionsQuery = `
SELECT id::text, task_id::text, COALESCE(message, ''), payload, created_at
FROM (
SELECT e.*,
row_number() OVER (
PARTITION BY e.payload->>'platformId'
ORDER BY e.created_at DESC, e.seq DESC
) AS demotion_rank
FROM gateway_task_events e
WHERE e.event_type = 'task.policy.priority_demoted'
AND e.payload->>'platformId' = ANY($1::text[])
) ranked
WHERE demotion_rank <= $2
ORDER BY payload->>'platformId' ASC, created_at DESC`

	grouped := make(map[string][]PriorityDemotionRecord)
	if limit <= 0 || len(statuses) == 0 {
		return grouped, nil
	}

	// Collect every non-blank platform ID once, in first-seen order.
	known := make(map[string]struct{}, len(statuses))
	platformIDs := make([]string, 0, len(statuses))
	for i := range statuses {
		candidate := strings.TrimSpace(statuses[i].PlatformID)
		if candidate == "" {
			continue
		}
		if _, duplicate := known[candidate]; duplicate {
			continue
		}
		known[candidate] = struct{}{}
		platformIDs = append(platformIDs, candidate)
	}
	if len(platformIDs) == 0 {
		return grouped, nil
	}

	rows, err := s.pool.Query(ctx, demotionsQuery, platformIDs, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var (
			eventID   string
			taskID    string
			message   string
			rawJSON   []byte
			createdAt time.Time
		)
		if scanErr := rows.Scan(&eventID, &taskID, &message, &rawJSON, &createdAt); scanErr != nil {
			return nil, scanErr
		}
		demotion := priorityDemotionRecordFromEventPayload(eventID, taskID, message, decodeObject(rawJSON), createdAt)
		// Rows are grouped under the platform ID found in the payload.
		if demotion.PlatformID != "" {
			grouped[demotion.PlatformID] = append(grouped[demotion.PlatformID], demotion)
		}
	}
	return grouped, rows.Err()
}
// priorityDemotionRecordFromEventPayload converts one priority-demotion task
// event (its row columns plus decoded JSON payload) into a
// PriorityDemotionRecord.
//
// The error code is payload["errorCode"], falling back to payload["code"].
// The error message is the first non-empty value among payload["errorMessage"],
// payload["message"] and the trimmed event message.
func priorityDemotionRecordFromEventPayload(id string, taskID string, message string, payload map[string]any, createdAt time.Time) PriorityDemotionRecord {
	var record PriorityDemotionRecord
	record.ID = id
	record.TaskID = taskID
	record.CreatedAt = createdAt

	record.PlatformID = stringValue(payload["platformId"])
	record.PlatformModelID = stringValue(payload["platformModelId"])
	record.Reason = stringValue(payload["reason"])
	record.Category = stringValue(payload["category"])
	record.StatusCode = intValue(payload["statusCode"])
	record.PolicySource = stringValue(payload["policySource"])
	record.Policy = stringValue(payload["policy"])
	record.PolicyRule = stringValue(payload["policyRule"])
	record.MatchedValue = stringValue(payload["matchedValue"])
	record.DynamicPriority = intValue(payload["dynamicPriority"])

	record.ErrorCode = stringValue(payload["errorCode"])
	if record.ErrorCode == "" {
		record.ErrorCode = stringValue(payload["code"])
	}

	// Candidates are listed from most to least specific; the first non-empty wins.
	messageCandidates := []string{
		stringValue(payload["errorMessage"]),
		stringValue(payload["message"]),
		strings.TrimSpace(message),
	}
	for _, candidate := range messageCandidates {
		if candidate != "" {
			record.ErrorMessage = candidate
			break
		}
	}
	return record
}
// listLatestPlatformDisabledReasons returns, keyed by platform ID, the most
// recent 'task.policy.failover_disabled' or 'task.policy.auto_disabled' event
// for every platform in statuses whose status is not "enabled".
//
// Each event is paired with the error message of the latest attempt of the
// same task on the same platform, which serves as a fallback error message.
// When no platform qualifies an empty map is returned without querying.
// Events whose payload has no platformId are skipped.
func (s *Store) listLatestPlatformDisabledReasons(ctx context.Context, statuses []ModelRateLimitStatus) (map[string]*PlatformPolicyEvent, error) {
	const disabledReasonsQuery = `
SELECT id::text, task_id::text, event_type, COALESCE(message, ''), payload, COALESCE(attempt_error_message, ''), created_at
FROM (
SELECT e.*,
a.error_message AS attempt_error_message,
row_number() OVER (
PARTITION BY e.payload->>'platformId'
ORDER BY e.created_at DESC, e.seq DESC
) AS disabled_rank
FROM gateway_task_events e
LEFT JOIN LATERAL (
SELECT error_message
FROM gateway_task_attempts attempt
WHERE attempt.task_id = e.task_id
AND attempt.platform_id::text = e.payload->>'platformId'
ORDER BY attempt.attempt_no DESC, attempt.started_at DESC
LIMIT 1
) a ON TRUE
WHERE e.event_type IN ('task.policy.failover_disabled', 'task.policy.auto_disabled')
AND e.payload->>'platformId' = ANY($1::text[])
) ranked
WHERE disabled_rank = 1`

	latest := make(map[string]*PlatformPolicyEvent)

	// Collect every non-blank ID of a platform that is not "enabled", once each.
	known := make(map[string]struct{}, len(statuses))
	platformIDs := make([]string, 0, len(statuses))
	for i := range statuses {
		if statuses[i].PlatformStatus == "enabled" {
			continue
		}
		candidate := strings.TrimSpace(statuses[i].PlatformID)
		if candidate == "" {
			continue
		}
		if _, duplicate := known[candidate]; duplicate {
			continue
		}
		known[candidate] = struct{}{}
		platformIDs = append(platformIDs, candidate)
	}
	if len(platformIDs) == 0 {
		return latest, nil
	}

	rows, err := s.pool.Query(ctx, disabledReasonsQuery, platformIDs)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var (
			eventID        string
			taskID         string
			eventType      string
			message        string
			rawJSON        []byte
			attemptMessage string
			createdAt      time.Time
		)
		if scanErr := rows.Scan(&eventID, &taskID, &eventType, &message, &rawJSON, &attemptMessage, &createdAt); scanErr != nil {
			return nil, scanErr
		}
		// event is declared per iteration, so each map entry points at its own value.
		event := platformPolicyEventFromPayload(eventID, taskID, eventType, message, attemptMessage, decodeObject(rawJSON), createdAt)
		if event.PlatformID != "" {
			latest[event.PlatformID] = &event
		}
	}
	return latest, rows.Err()
}
// platformPolicyEventFromPayload converts one platform-disabling task event
// (its row columns plus decoded JSON payload) into a PlatformPolicyEvent.
//
// The error code is payload["errorCode"], falling back to payload["code"].
// The error message is the first non-empty value among payload["errorMessage"],
// payload["message"], the trimmed attempt error message and the trimmed event
// message.
func platformPolicyEventFromPayload(id string, taskID string, eventType string, message string, attemptErrorMessage string, payload map[string]any, createdAt time.Time) PlatformPolicyEvent {
	var event PlatformPolicyEvent
	event.ID = id
	event.TaskID = taskID
	event.EventType = eventType
	event.CreatedAt = createdAt

	event.PlatformID = stringValue(payload["platformId"])
	event.PlatformModelID = stringValue(payload["platformModelId"])
	event.Reason = stringValue(payload["reason"])
	event.Category = stringValue(payload["category"])
	event.StatusCode = intValue(payload["statusCode"])
	event.PolicySource = stringValue(payload["policySource"])
	event.Policy = stringValue(payload["policy"])
	event.PolicyRule = stringValue(payload["policyRule"])
	event.MatchedValue = stringValue(payload["matchedValue"])

	event.ErrorCode = stringValue(payload["errorCode"])
	if event.ErrorCode == "" {
		event.ErrorCode = stringValue(payload["code"])
	}

	// Candidates are listed in priority order; the first non-empty one wins.
	messageCandidates := []string{
		stringValue(payload["errorMessage"]),
		stringValue(payload["message"]),
		strings.TrimSpace(attemptErrorMessage),
		strings.TrimSpace(message),
	}
	for _, candidate := range messageCandidates {
		if candidate != "" {
			event.ErrorMessage = candidate
			break
		}
	}
	return event
}
// effectiveModelRateLimitPolicy layers the rate limit policies that apply to a
// model and returns the merged result, or nil when the result has no rules.
//
// Layers are applied on top of platformPolicy in this order, each as a shallow
// top-level key merge in which the later layer wins:
//  1. runtimePolicy, when it has a non-empty "rules" list;
//  2. the "rateLimitPolicy" object nested in runtimeOverride, when non-empty;
//  3. modelPolicy, when it has a non-empty "rules" list.
//
// When no layer applies, platformPolicy itself (not a copy) is returned.
func effectiveModelRateLimitPolicy(platformPolicy map[string]any, runtimePolicy map[string]any, runtimeOverride map[string]any, modelPolicy map[string]any) map[string]any {
	layers := make([]map[string]any, 0, 3)
	if hasRateLimitRules(runtimePolicy) {
		layers = append(layers, runtimePolicy)
	}
	if override, isObject := runtimeOverride["rateLimitPolicy"].(map[string]any); isObject && len(override) > 0 {
		layers = append(layers, override)
	}
	if hasRateLimitRules(modelPolicy) {
		layers = append(layers, modelPolicy)
	}

	effective := platformPolicy
	for _, layer := range layers {
		effective = shallowMergeMap(effective, layer)
	}
	if !hasRateLimitRules(effective) {
		return nil
	}
	return effective
}
// hasRateLimitRules reports whether policy holds a non-empty "rules" list.
// A missing key, a nil policy, or a "rules" value that is not a []any all
// count as having no rules.
func hasRateLimitRules(policy map[string]any) bool {
	if rules, isList := policy["rules"].([]any); isList {
		return len(rules) != 0
	}
	return false
}
// shallowMergeMap returns a new map holding every top-level entry of base
// overlaid with every top-level entry of override; on a key collision the
// override value wins. Neither input is modified, nested values are shared
// rather than copied, and the result is never nil.
func shallowMergeMap(base map[string]any, override map[string]any) map[string]any {
	merged := make(map[string]any, len(base)+len(override))
	// Sources are applied in order, so entries from override overwrite base.
	for _, source := range [...]map[string]any{base, override} {
		for name := range source {
			merged[name] = source[name]
		}
	}
	return merged
}
// rateLimitForMetric returns the "limit" of the first entry in
// policy["rules"] whose trimmed "metric" equals metric. It returns 0 when the
// policy has no usable rules list or no rule matches.
func rateLimitForMetric(policy map[string]any, metric string) float64 {
	ruleList, _ := policy["rules"].([]any)
	for i := range ruleList {
		// A rule that is not a JSON object becomes a nil map, whose lookups
		// simply yield nil values.
		fields, _ := ruleList[i].(map[string]any)
		name := strings.TrimSpace(stringValue(fields["metric"]))
		if name != metric {
			continue
		}
		return floatValue(fields["limit"])
	}
	return 0
}
// tpmLimit returns the tokens-per-minute limit of policy: the "tpm_total"
// limit when it is positive, otherwise the sum of the "tpm_input" and
// "tpm_output" limits (0 when neither is configured).
func tpmLimit(policy map[string]any) float64 {
	total := rateLimitForMetric(policy, "tpm_total")
	if total > 0 {
		return total
	}
	input := rateLimitForMetric(policy, "tpm_input")
	output := rateLimitForMetric(policy, "tpm_output")
	return input + output
}
// metricStatus assembles a RateLimitMetricStatus from raw counter values.
// The metric counts as limited when limit is greater than zero; only then is
// Ratio set to current / limit, otherwise it stays 0.
func metricStatus(current float64, used float64, reserved float64, limit float64, resetAt string) RateLimitMetricStatus {
	limited := limit > 0

	var ratio float64
	if limited {
		// limit is strictly positive here, so the division is well defined.
		ratio = current / limit
	}

	return RateLimitMetricStatus{
		CurrentValue:  current,
		UsedValue:     used,
		ReservedValue: reserved,
		LimitValue:    limit,
		Limited:       limited,
		Ratio:         ratio,
		ResetAt:       resetAt,
	}
}
// maxFloat returns the largest of values, with a floor of 0: it yields 0 when
// called without arguments or when every value is zero or negative.
func maxFloat(values ...float64) float64 {
	var largest float64
	for i := 0; i < len(values); i++ {
		if candidate := values[i]; candidate > largest {
			largest = candidate
		}
	}
	return largest
}
// stringValue returns value with surrounding whitespace removed when value is
// a string, and "" for nil or any other dynamic type.
func stringValue(value any) string {
	if text, isString := value.(string); isString {
		return strings.TrimSpace(text)
	}
	return ""
}
// floatValue converts a loosely typed value (typically a decoded JSON field)
// to float64.
//
// int, int64 and float64 are converted directly. A string is trimmed and
// parsed as a decimal number, mirroring how intValue accepts string-encoded
// integers, so a rule such as {"limit": "100"} is not silently treated as
// "no limit". Strings that fail to parse, or that parse to NaN or ±Inf,
// yield 0 because such limits cannot be serialized to JSON. Every other
// type, including nil, yields 0.
func floatValue(value any) float64 {
	switch typed := value.(type) {
	case int:
		return float64(typed)
	case int64:
		return float64(typed)
	case float64:
		return typed
	case string:
		parsed, err := strconv.ParseFloat(strings.TrimSpace(typed), 64)
		// ParseFloat accepts "NaN" and "Inf"; those are the only values for
		// which x*0 != 0, so this rejects every non-finite result.
		if err != nil || parsed*0 != 0 {
			return 0
		}
		return parsed
	default:
		return 0
	}
}
// intValue converts a loosely typed value (typically a decoded JSON field)
// to int.
//
// int is returned as-is, int64 and float64 are converted with Go's numeric
// conversion (float64 is truncated toward zero), and a string is trimmed and
// parsed with strconv.Atoi. Every other type, including nil, yields 0.
func intValue(value any) int {
	var result int
	switch v := value.(type) {
	case int:
		result = v
	case int64:
		result = int(v)
	case float64:
		result = int(v)
	case string:
		// The parse error is deliberately ignored: whatever Atoi returns
		// alongside it (0 for malformed text) is used as the result.
		result, _ = strconv.Atoi(strings.TrimSpace(v))
	}
	return result
}