feat(api): add load-aware client fallback

This commit is contained in:
wangbo 2026-05-12 16:59:51 +08:00
parent c5cede2359
commit c2696e7bbe
7 changed files with 558 additions and 5 deletions

View File

@ -1193,6 +1193,8 @@ WHERE m.platform_id = $1::uuid
t.Fatalf("failover events should include retrying event status=%d body=%s", resp.StatusCode, string(body))
}
assertLoadAvoidanceSimulatedRetryChain(t, ctx, testPool, server.URL, loginResponse.AccessToken, apiKeyResponse.Secret, suffixText)
var callbackRows int
if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_task_callback_outbox WHERE task_id = $1::uuid`, taskResponse.Task.ID).Scan(&callbackRows); err != nil {
t.Fatalf("read callback outbox: %v", err)
@ -1396,6 +1398,212 @@ func waitForTaskStatus(t *testing.T, baseURL string, token string, taskID string
return detail
}
// assertLoadAvoidanceSimulatedRetryChain provisions five simulated platforms
// for a throwaway model — two "normal" failing candidates plus three
// saturated ("full") candidates the scheduler should avoid — then fires a
// single simulated chat completion and asserts the failover chain:
//  1. the non-full candidates are attempted first,
//  2. the chain continues into the avoided full candidates instead of
//     terminating on the first hard stop,
//  3. every attempt exposes loadAvoided / loadRatio / loadMetrics details.
func assertLoadAvoidanceSimulatedRetryChain(t *testing.T, ctx context.Context, testPool *pgxpool.Pool, baseURL string, adminToken string, runtimeToken string, suffixText string) {
	t.Helper()
	model := "load-avoidance-smoke-" + suffixText
	// scenario describes one platform to provision; full=true attaches tight
	// rate-limit rules and seeds queued load so the candidate reads as
	// saturated before the request runs.
	type scenario struct {
		keySuffix string
		name      string
		failure   string
		priority  int
		full      bool
	}
	scenarios := []scenario{
		{keySuffix: "hard-stop", name: "Load Avoidance Hard Stop", failure: "fatal_failure", priority: 20},
		{keySuffix: "retryable", name: "Load Avoidance Retryable", failure: "retryable_failure", priority: 30},
		{keySuffix: "full-rate-limit", name: "Load Avoidance Full Rate Limit", failure: "rate_limit", priority: 1, full: true},
		{keySuffix: "full-overloaded", name: "Load Avoidance Full Overloaded", failure: "overloaded", priority: 2, full: true},
		{keySuffix: "full-fatal", name: "Load Avoidance Full Fatal", failure: "fatal_failure", priority: 3, full: true},
	}
	for _, item := range scenarios {
		var platform struct {
			ID          string `json:"id"`
			PlatformKey string `json:"platformKey"`
			Name        string `json:"name"`
		}
		credentials := map[string]any{"mode": "simulation"}
		if item.failure != "" {
			credentials["simulationFailure"] = item.failure
		}
		doJSON(t, baseURL, http.MethodPost, "/api/admin/platforms", adminToken, map[string]any{
			"provider":    "openai",
			"platformKey": "openai-load-" + item.keySuffix + "-" + suffixText,
			"name":        item.name,
			"baseUrl":     "https://api.openai.com/v1",
			"authType":    "bearer",
			"credentials": credentials,
			"priority":    item.priority,
		}, http.StatusCreated, &platform)
		if platform.ID == "" || platform.PlatformKey == "" {
			t.Fatalf("load avoidance platform was not created: %+v", platform)
		}
		var platformModel struct {
			ID string `json:"id"`
		}
		// maxAttempts=1 keeps each candidate to a single attempt so the chain
		// advances platform-by-platform.
		payload := map[string]any{
			"canonicalModelKey": "openai:gpt-4o-mini",
			"modelName":         model,
			"providerModelName": model,
			"modelAlias":        model,
			"modelType":         []string{"text_generate"},
			"displayName":       item.name,
			"retryPolicy":       map[string]any{"enabled": true, "maxAttempts": 1},
		}
		if item.full {
			// Tight limits let the seeded queued load push the candidate to a
			// >=1 load ratio (the concurrent limit of 1 saturates immediately).
			payload["rateLimitPolicy"] = map[string]any{
				"rules": []map[string]any{
					{"metric": "rpm", "limit": 10, "windowSeconds": 60},
					{"metric": "tpm_total", "limit": 200, "windowSeconds": 60},
					{"metric": "concurrent", "limit": 1, "leaseTtlSeconds": 120},
				},
			}
		}
		doJSON(t, baseURL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", adminToken, payload, http.StatusCreated, &platformModel)
		if platformModel.ID == "" {
			t.Fatalf("load avoidance platform model was not created for %s: %+v", item.name, platformModel)
		}
		if item.full {
			seedQueuedConcurrencyLoad(t, ctx, testPool, platform.PlatformKey, model, platformModel.ID)
		}
	}
	var taskResponse struct {
		Task struct {
			ID        string `json:"id"`
			Status    string `json:"status"`
			ErrorCode string `json:"errorCode"`
		} `json:"task"`
	}
	doJSON(t, baseURL, http.MethodPost, "/api/v1/chat/completions", runtimeToken, map[string]any{
		"model":                model,
		"runMode":              "simulation",
		"simulation":           true,
		"simulationDurationMs": 5,
		"messages":             []map[string]any{{"role": "user", "content": "load avoidance retry chain"}},
	}, http.StatusAccepted, &taskResponse)
	// The last candidate in the chain (full-fatal) fails with bad_request, so
	// that becomes the task's terminal error code.
	if taskResponse.Task.ID == "" || taskResponse.Task.Status != "failed" || taskResponse.Task.ErrorCode != "bad_request" {
		t.Fatalf("load avoidance task should only fail after avoided clients are retried, got %+v", taskResponse.Task)
	}
	var detail struct {
		ID        string `json:"id"`
		Status    string `json:"status"`
		ErrorCode string `json:"errorCode"`
		Attempts  []struct {
			AttemptNo    int            `json:"attemptNo"`
			PlatformName string         `json:"platformName"`
			Status       string         `json:"status"`
			Retryable    bool           `json:"retryable"`
			ErrorCode    string         `json:"errorCode"`
			Metrics      map[string]any `json:"metrics"`
		} `json:"attempts"`
		Metrics map[string]any `json:"metrics"`
	}
	doJSON(t, baseURL, http.MethodGet, "/api/v1/tasks/"+taskResponse.Task.ID, runtimeToken, nil, http.StatusOK, &detail)
	if detail.Status != "failed" || len(detail.Attempts) != len(scenarios) {
		t.Fatalf("load avoidance detail should expose every attempted client, got status=%s attempts=%+v", detail.Status, detail.Attempts)
	}
	// Expected attempt order: non-full candidates first (priority 20 then 30),
	// then the avoided full candidates (priorities 1, 2, 3 among themselves).
	expected := []struct {
		name      string
		code      string
		retryable bool
		avoided   bool
	}{
		{name: "Load Avoidance Hard Stop", code: "bad_request"},
		{name: "Load Avoidance Retryable", code: "server_error", retryable: true},
		{name: "Load Avoidance Full Rate Limit", code: "rate_limit", retryable: true, avoided: true},
		{name: "Load Avoidance Full Overloaded", code: "overloaded", retryable: true, avoided: true},
		{name: "Load Avoidance Full Fatal", code: "bad_request", avoided: true},
	}
	attemptSummary := make([]string, 0, len(detail.Attempts))
	for index, want := range expected {
		got := detail.Attempts[index]
		attemptSummary = append(attemptSummary, got.PlatformName+":"+got.Status+":"+got.ErrorCode)
		if got.AttemptNo != index+1 || got.PlatformName != want.name || got.Status != "failed" || got.ErrorCode != want.code || got.Retryable != want.retryable {
			t.Fatalf("unexpected load avoidance attempt %d: got %+v want %+v", index+1, got, want)
		}
		if boolFromTestMap(got.Metrics, "loadAvoided") != want.avoided {
			t.Fatalf("loadAvoided mismatch for %s metrics=%+v", got.PlatformName, got.Metrics)
		}
		// Non-avoided candidates were created without a rateLimitPolicy, so
		// their reported load ratio must stay zero.
		if !want.avoided && floatFromTestAny(got.Metrics["loadRatio"]) != 0 {
			t.Fatalf("non-full candidate should not carry load pressure, got %s metrics=%+v", got.PlatformName, got.Metrics)
		}
		if want.avoided {
			if ratio := floatFromTestAny(got.Metrics["loadRatio"]); ratio < 1 {
				t.Fatalf("avoided candidate should expose full load ratio, got %s ratio=%v metrics=%+v", got.PlatformName, ratio, got.Metrics)
			}
			loadMetrics, _ := got.Metrics["loadMetrics"].(map[string]any)
			concurrent, _ := loadMetrics["concurrent"].(map[string]any)
			if ratio := floatFromTestAny(concurrent["ratio"]); ratio < 1 {
				t.Fatalf("avoided candidate should expose concurrent saturation, got %s loadMetrics=%+v", got.PlatformName, loadMetrics)
			}
			if queued := floatFromTestAny(loadMetrics["queued"]); queued < 1 {
				t.Fatalf("avoided candidate should expose queued waiting load, got %s loadMetrics=%+v", got.PlatformName, loadMetrics)
			}
		}
	}
	// The first attempt is a hard stop; the chain may only continue because
	// a load_avoidance_fallback decision was recorded in its trace.
	if !attemptTraceHasReason(detail.Attempts[0].Metrics, "load_avoidance_fallback") {
		t.Fatalf("first hard-stop attempt should continue because avoided candidates remain, metrics=%+v", detail.Attempts[0].Metrics)
	}
	if summary, ok := detail.Metrics["attempts"].([]any); !ok || len(summary) != len(scenarios) {
		t.Fatalf("final task metrics should preserve load-avoidance attempt summary, got %+v", detail.Metrics)
	}
	t.Logf("load avoidance retry chain: %s", strings.Join(attemptSummary, " -> "))
}
// seedQueuedConcurrencyLoad inserts a single far-future queued async task for
// the given platform/model queue key so the candidate reports queued waiting
// load before any real request is dispatched.
func seedQueuedConcurrencyLoad(t *testing.T, ctx context.Context, testPool *pgxpool.Pool, platformKey string, model string, platformModelID string) {
	t.Helper()
	const insertQueuedTask = `
INSERT INTO gateway_tasks (
	kind, run_mode, user_id, model, requested_model, model_type,
	request, normalized_request, status, queue_key, priority, async_mode,
	next_run_at, result, billings
)
VALUES (
	'chat.completions', 'simulation', $1, $2, $2, 'text_generate',
	'{}'::jsonb, '{}'::jsonb, 'queued', $3, 999, true,
	now() + interval '1 hour', '{}'::jsonb, '[]'::jsonb
)`
	queueKey := platformKey + ":text_generate:" + model
	seedUser := "load-avoidance-seed-" + platformModelID
	_, err := testPool.Exec(ctx, insertQueuedTask, seedUser, model, queueKey)
	if err != nil {
		t.Fatalf("seed queued load for %s: %v", queueKey, err)
	}
}
// boolFromTestMap returns the boolean stored under key, or false when the key
// is absent or holds a non-bool value.
func boolFromTestMap(values map[string]any, key string) bool {
	if flag, ok := values[key].(bool); ok {
		return flag
	}
	return false
}
// floatFromTestAny coerces a decoded JSON value to float64, accepting the
// numeric representations produced by encoding/json (float64, json.Number)
// plus common Go numeric types; any other value yields 0.
func floatFromTestAny(value any) float64 {
	switch v := value.(type) {
	case float64:
		return v
	case float32:
		return float64(v)
	case int:
		return float64(v)
	case int64:
		return float64(v)
	case json.Number:
		parsed, _ := v.Float64()
		return parsed
	}
	return 0
}
// attemptTraceHasReason reports whether any entry of the metrics "trace" list
// carries the given reason string; missing or malformed traces report false.
func attemptTraceHasReason(metrics map[string]any, reason string) bool {
	entries, _ := metrics["trace"].([]any)
	for _, entry := range entries {
		if fields, ok := entry.(map[string]any); ok && fields["reason"] == reason {
			return true
		}
	}
	return false
}
func waitForRateLimitWindowHead(t *testing.T, windowSeconds int) {
t.Helper()
if windowSeconds <= 0 {

View File

@ -112,8 +112,30 @@ func attemptMetrics(candidate store.RuntimeModelCandidate, attemptNo int, simula
"platformModelId": candidate.PlatformModelID,
"clientId": candidate.ClientID,
"queueKey": candidate.QueueKey,
"loadRatio": candidate.LoadRatio,
"loadAvoided": candidate.LoadAvoided,
"simulated": simulated,
}
if candidate.LoadLimited {
metrics["loadMetrics"] = map[string]any{
"rpm": map[string]any{
"current": candidate.LoadMetrics.RPMCurrent,
"limit": candidate.LoadMetrics.RPMLimit,
"ratio": candidate.LoadMetrics.RPMRatio,
},
"tpm": map[string]any{
"current": candidate.LoadMetrics.TPMCurrent,
"limit": candidate.LoadMetrics.TPMLimit,
"ratio": candidate.LoadMetrics.TPMRatio,
},
"concurrent": map[string]any{
"current": candidate.LoadMetrics.ConcurrentCurrent,
"limit": candidate.LoadMetrics.ConcurrentLimit,
"ratio": candidate.LoadMetrics.ConcurrentRatio,
},
"queued": candidate.LoadMetrics.QueuedCount,
}
}
if attemptNo > 0 {
metrics["attempt"] = attemptNo
}

View File

@ -46,6 +46,24 @@ func TestFailoverTimeBudgetExceeded(t *testing.T) {
}
}
// TestLoadAvoidanceFallbackContinuesToAvoidedCandidate verifies two pieces of
// the load-avoidance failover mechanics: a non-avoided candidate reports a
// pending fallback when an avoided candidate sits later in the list (and the
// avoided candidate itself does not), and the synthesized fallback decision
// forces retry+next even for an otherwise terminal error.
func TestLoadAvoidanceFallbackContinuesToAvoidedCandidate(t *testing.T) {
	candidates := []store.RuntimeModelCandidate{
		{PlatformID: "available-candidate"},
		{PlatformID: "avoided-full-candidate", LoadAvoided: true},
	}
	if !hasLoadAvoidanceFallback(candidates, 0, 99) {
		t.Fatal("expected non-avoided candidate to fall back to later avoided candidate")
	}
	if hasLoadAvoidanceFallback(candidates, 1, 99) {
		t.Fatal("avoided candidate should not force another load-avoidance fallback")
	}
	// bad_request with Retryable=false is normally a hard stop; the
	// load-avoidance decision must override it with retry + next candidate.
	decision := loadAvoidanceFallbackDecision(&clients.ClientError{Code: "bad_request", StatusCode: 400, Retryable: false})
	if !decision.Retry || decision.Reason != "load_avoidance_fallback" || decision.Action != "next" {
		t.Fatalf("expected active load avoidance fallback to force next candidate, got %+v", decision)
	}
}
func TestFailoverHardStopBeatsModelOverride(t *testing.T) {
runnerPolicy := store.RunnerPolicy{
Status: "active",

View File

@ -290,6 +290,9 @@ candidatesLoop:
break
}
decision := failoverDecisionForCandidate(runnerPolicy, candidate, candidateErr)
if !decision.Retry && hasLoadAvoidanceFallback(candidates, index, maxPlatforms) {
decision = loadAvoidanceFallbackDecision(candidateErr)
}
s.recordAttemptTrace(ctx, task.ID, attemptNo, failoverTraceEntry(decision))
if !decision.Retry {
break
@ -792,6 +795,37 @@ func failoverTimeBudgetExceeded(start time.Time, maxDuration time.Duration) bool
return maxDuration > 0 && time.Since(start) >= maxDuration
}
// hasLoadAvoidanceFallback reports whether the candidate at index — itself not
// load-avoided — still has at least one load-avoided candidate behind it
// within the maxPlatforms window, i.e. whether stopping now would skip a
// usable fallback. maxPlatforms <= 0 means no window cap.
func hasLoadAvoidanceFallback(candidates []store.RuntimeModelCandidate, index int, maxPlatforms int) bool {
	if index < 0 || index >= len(candidates) {
		return false
	}
	if candidates[index].LoadAvoided {
		return false
	}
	upper := len(candidates)
	if maxPlatforms > 0 && maxPlatforms < upper {
		upper = maxPlatforms
	}
	for later := index + 1; later < upper; later++ {
		if candidates[later].LoadAvoided {
			return true
		}
	}
	return false
}
// loadAvoidanceFallbackDecision builds the synthetic failover decision applied
// when a terminal error would normally end the chain but load-avoided
// candidates remain: always retry, always advance to the next candidate, and
// record the load-avoidance rule match so the attempt trace explains why.
func loadAvoidanceFallbackDecision(err error) failoverDecision {
	return failoverDecision{
		Retry:  true,
		Action: "next",
		Reason: "load_avoidance_fallback",
		Match: policyRuleMatch{
			Source: "runtime_client_load",
			Policy: "loadAvoidance",
			Rule:   "fallback",
			Value:  "loadRatio>=1",
		},
		Info: failureInfoFromError(err),
	}
}
func normalizeRequest(kind string, body map[string]any) map[string]any {
out := cloneMap(body)
if kind == "responses" && out["messages"] == nil && out["input"] != nil {

View File

@ -3,6 +3,7 @@ package store
import (
"context"
"fmt"
"sort"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
@ -23,10 +24,18 @@ SELECT p.id::text, p.platform_key, p.name, p.provider,
COALESCE(b.base_billing_config, '{}'::jsonb), m.billing_config, m.billing_config_override,
m.pricing_mode, COALESCE(m.discount_factor, 0)::float8, COALESCE(m.pricing_rule_set_id::text, ''),
COALESCE(b.pricing_rule_set_id::text, ''),
m.permission_config, m.retry_policy, m.rate_limit_policy, COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')),
COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb),
COALESCE(rp.retry_policy, '{}'::jsonb), COALESCE(rp.rate_limit_policy, '{}'::jsonb),
COALESCE(rp.auto_disable_policy, '{}'::jsonb), COALESCE(rp.degrade_policy, '{}'::jsonb)
m.permission_config, m.retry_policy, m.rate_limit_policy, COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')),
COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb),
COALESCE(rp.retry_policy, '{}'::jsonb), COALESCE(rp.rate_limit_policy, '{}'::jsonb),
COALESCE(rp.auto_disable_policy, '{}'::jsonb), COALESCE(rp.degrade_policy, '{}'::jsonb),
COALESCE(con.active, 0)::float8,
COALESCE(queued.waiting, 0)::float8,
COALESCE(rpm.used_value, 0)::float8, COALESCE(rpm.reserved_value, 0)::float8,
COALESCE(tpm.used_value, 0)::float8, COALESCE(tpm.reserved_value, 0)::float8,
COALESCE(s.running_count, 0)::float8,
COALESCE(s.waiting_count, 0)::float8,
COALESCE(s.limiter_ratio, 0)::float8,
COALESCE(EXTRACT(EPOCH FROM s.last_assigned_at), 0)::float8
FROM platform_models m
JOIN integration_platforms p ON p.id = m.platform_id
LEFT JOIN model_catalog_providers cp ON cp.provider_key = p.provider OR cp.provider_code = p.provider
@ -34,6 +43,57 @@ LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id)
LEFT JOIN runtime_client_states s
ON s.client_id = p.platform_key || ':' || $2::text || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name)
LEFT JOIN (
SELECT scope_key, SUM(lease_value) AS active
FROM gateway_concurrency_leases
WHERE scope_type = 'platform_model'
AND released_at IS NULL
AND expires_at > now()
GROUP BY scope_key
) con ON con.scope_key = m.id::text
LEFT JOIN (
SELECT queued_sources.platform_model_id, COUNT(DISTINCT queued_sources.task_id) AS waiting
FROM (
SELECT t.id::text AS task_id, qm.id::text AS platform_model_id
FROM gateway_tasks t
JOIN integration_platforms qp ON TRUE
JOIN platform_models qm ON qm.platform_id = qp.id
WHERE t.async_mode = true
AND t.status = 'queued'
AND NULLIF(t.model_type, '') IS NOT NULL
AND qm.model_type @> jsonb_build_array(t.model_type)
AND t.queue_key = qp.platform_key || ':' || t.model_type || ':' || COALESCE(NULLIF(qm.provider_model_name, ''), qm.model_name)
AND NOT EXISTS (SELECT 1 FROM gateway_task_attempts existing_attempt WHERE existing_attempt.task_id = t.id)
UNION ALL
SELECT latest.task_id::text AS task_id, latest.platform_model_id
FROM (
SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id
FROM gateway_tasks t
JOIN gateway_task_attempts a ON a.task_id = t.id
WHERE t.async_mode = true
AND t.status = 'queued'
AND a.platform_model_id IS NOT NULL
ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC
) latest
) queued_sources
GROUP BY queued_sources.platform_model_id
) queued ON queued.platform_model_id = m.id::text
LEFT JOIN (
SELECT DISTINCT ON (scope_key) scope_key, used_value, reserved_value
FROM gateway_rate_limit_counters
WHERE scope_type = 'platform_model'
AND metric = 'rpm'
AND reset_at > now()
ORDER BY scope_key, window_start DESC
) rpm ON rpm.scope_key = m.id::text
LEFT JOIN (
SELECT scope_key, SUM(used_value) AS used_value, SUM(reserved_value) AS reserved_value
FROM gateway_rate_limit_counters
WHERE scope_type = 'platform_model'
AND metric LIKE 'tpm%'
AND reset_at > now()
GROUP BY scope_key
) tpm ON tpm.scope_key = m.id::text
WHERE p.status = 'enabled'
AND p.deleted_at IS NULL
AND m.enabled = true
@ -52,7 +112,6 @@ WHERE p.status = 'enabled'
)
)
ORDER BY effective_priority ASC,
COALESCE(s.limiter_ratio, 0) ASC,
COALESCE(s.running_count, 0) ASC,
COALESCE(s.waiting_count, 0) ASC,
COALESCE(s.last_assigned_at, to_timestamp(0)) ASC,
@ -82,6 +141,16 @@ ORDER BY effective_priority ASC,
var runtimeRateLimitPolicy []byte
var autoDisablePolicy []byte
var degradePolicy []byte
var concurrentActive float64
var queuedWaiting float64
var rpmUsed float64
var rpmReserved float64
var tpmUsed float64
var tpmReserved float64
var stateRunningCount float64
var stateWaitingCount float64
var stateLimiterRatio float64
var lastAssignedUnix float64
if err := rows.Scan(
&item.PlatformID,
&item.PlatformKey,
@ -124,6 +193,16 @@ ORDER BY effective_priority ASC,
&runtimeRateLimitPolicy,
&autoDisablePolicy,
&degradePolicy,
&concurrentActive,
&queuedWaiting,
&rpmUsed,
&rpmReserved,
&tpmUsed,
&tpmReserved,
&stateRunningCount,
&stateWaitingCount,
&stateLimiterRatio,
&lastAssignedUnix,
); err != nil {
return nil, err
}
@ -147,6 +226,21 @@ ORDER BY effective_priority ASC,
upstreamModelName := firstNonEmpty(item.ProviderModelName, item.ModelName)
item.ClientID = fmt.Sprintf("%s:%s:%s", item.PlatformKey, item.ModelType, upstreamModelName)
item.QueueKey = item.ClientID
item.RunningCount = stateRunningCount
item.WaitingCount = maxFloat(queuedWaiting, stateWaitingCount)
item.LastAssignedUnix = lastAssignedUnix
applyRuntimeCandidateLoad(&item, runtimeCandidateLoadInput{
Policy: effectiveModelRateLimitPolicy(item.PlatformRateLimitPolicy, item.RuntimeRateLimitPolicy, item.RuntimePolicyOverride, item.ModelRateLimitPolicy),
ConcurrentActive: concurrentActive,
QueuedWaiting: queuedWaiting,
RPMUsed: rpmUsed,
RPMReserved: rpmReserved,
TPMUsed: tpmUsed,
TPMReserved: tpmReserved,
StateRunningCount: stateRunningCount,
StateWaitingCount: stateWaitingCount,
StateLimiterRatio: stateLimiterRatio,
})
items = append(items, item)
}
if err := rows.Err(); err != nil {
@ -167,9 +261,102 @@ ORDER BY effective_priority ASC,
if len(items) == 0 {
return nil, ErrNoModelCandidate
}
sortRuntimeModelCandidates(items)
return items, nil
}
// runtimeCandidateLoadInput carries the raw load counters gathered for one
// candidate by the candidate query, before applyRuntimeCandidateLoad
// normalizes them into per-metric ratios.
type runtimeCandidateLoadInput struct {
	Policy            map[string]any // effective rate-limit policy ("rules" entries with metric/limit)
	ConcurrentActive  float64        // unexpired, unreleased concurrency leases for the platform model
	QueuedWaiting     float64        // queued async tasks currently targeting this candidate
	RPMUsed           float64        // consumed requests in the current rpm counter window
	RPMReserved       float64        // reserved-but-unsettled requests in that window
	TPMUsed           float64        // consumed tokens summed across tpm* counters
	TPMReserved       float64        // reserved-but-unsettled tokens across tpm* counters
	StateRunningCount float64        // running count from runtime_client_states
	StateWaitingCount float64        // waiting count from runtime_client_states
	StateLimiterRatio float64        // limiter ratio from runtime_client_states
}
// applyRuntimeCandidateLoad fills the candidate's load snapshot: one
// current/limit/ratio triple per metric (ratio is 0 when the metric has no
// limit), LoadLimited when any limit is configured, and LoadRatio as the
// worst of the three metric ratios.
func applyRuntimeCandidateLoad(candidate *RuntimeModelCandidate, input runtimeCandidateLoadInput) {
	limitRPM := rateLimitForMetric(input.Policy, "rpm")
	limitTPM := tpmLimit(input.Policy)
	limitConcurrent := rateLimitForMetric(input.Policy, "concurrent")
	// Reserved usage counts against the window; queued tasks count against
	// the concurrency budget alongside active leases.
	currentRPM := input.RPMUsed + input.RPMReserved
	currentTPM := input.TPMUsed + input.TPMReserved
	currentConcurrent := input.ConcurrentActive + input.QueuedWaiting
	candidate.LoadMetrics = RuntimeCandidateLoadMetrics{
		RPMCurrent:        currentRPM,
		RPMLimit:          limitRPM,
		RPMRatio:          ratioIfLimited(currentRPM, limitRPM),
		TPMCurrent:        currentTPM,
		TPMLimit:          limitTPM,
		TPMRatio:          ratioIfLimited(currentTPM, limitTPM),
		ConcurrentCurrent: currentConcurrent,
		ConcurrentLimit:   limitConcurrent,
		ConcurrentRatio:   ratioIfLimited(currentConcurrent, limitConcurrent),
		QueuedCount:       input.QueuedWaiting,
		StateRunningCount: input.StateRunningCount,
		StateWaitingCount: input.StateWaitingCount,
		StateLimiterRatio: input.StateLimiterRatio,
	}
	candidate.LoadLimited = limitRPM > 0 || limitTPM > 0 || limitConcurrent > 0
	candidate.LoadRatio = maxFloat(candidate.LoadMetrics.RPMRatio, candidate.LoadMetrics.TPMRatio, candidate.LoadMetrics.ConcurrentRatio)
}
// ratioIfLimited returns current/limit, treating a non-positive limit as
// "unlimited" and reporting zero load in that case.
func ratioIfLimited(current float64, limit float64) float64 {
	if limit > 0 {
		return current / limit
	}
	return 0
}
// sortRuntimeModelCandidates orders candidates so saturated ("full") ones sink
// to the back as avoided fallbacks, while candidates within the same fullness
// tier keep the original ordering criteria: priority, load ratio, running
// count, waiting count, then least-recently-assigned. The sort is stable so
// fully tied candidates preserve the database ordering.
func sortRuntimeModelCandidates(items []RuntimeModelCandidate) {
	fullSeen, availableSeen := false, false
	for i := range items {
		items[i].LoadAvoided = false
		if runtimeCandidateFull(items[i]) {
			fullSeen = true
		} else {
			availableSeen = true
		}
	}
	// Only flag saturated candidates as avoided when a healthy alternative
	// exists; if everything is full, nothing is skipped.
	if fullSeen && availableSeen {
		for i := range items {
			items[i].LoadAvoided = runtimeCandidateFull(items[i])
		}
	}
	sort.SliceStable(items, func(i, j int) bool {
		left, right := &items[i], &items[j]
		if leftFull, rightFull := runtimeCandidateFull(*left), runtimeCandidateFull(*right); leftFull != rightFull {
			return !leftFull
		}
		switch {
		case left.PlatformPriority != right.PlatformPriority:
			return left.PlatformPriority < right.PlatformPriority
		case left.LoadRatio != right.LoadRatio:
			return left.LoadRatio < right.LoadRatio
		case left.RunningCount != right.RunningCount:
			return left.RunningCount < right.RunningCount
		case left.WaitingCount != right.WaitingCount:
			return left.WaitingCount < right.WaitingCount
		case left.LastAssignedUnix != right.LastAssignedUnix:
			return left.LastAssignedUnix < right.LastAssignedUnix
		default:
			return false
		}
	})
}
// runtimeCandidateFull reports whether a candidate both has at least one
// configured rate limit and its worst metric ratio has reached saturation.
func runtimeCandidateFull(candidate RuntimeModelCandidate) bool {
	if !candidate.LoadLimited {
		return false
	}
	return candidate.LoadRatio >= 1
}
func (s *Store) modelCandidateCooldownError(ctx context.Context, model string, modelType string) (error, error) {
rows, err := s.pool.Query(ctx, `
SELECT p.name,

View File

@ -0,0 +1,61 @@
package store
import "testing"
// TestRuntimeCandidateLoadUsesMaxLimitedMetric checks that the load snapshot
// computes one ratio per limited metric (reserved usage and queued load count
// against their budgets) and adopts the worst ratio as the overall LoadRatio.
func TestRuntimeCandidateLoadUsesMaxLimitedMetric(t *testing.T) {
	candidate := RuntimeModelCandidate{}
	applyRuntimeCandidateLoad(&candidate, runtimeCandidateLoadInput{
		Policy: map[string]any{"rules": []any{
			map[string]any{"metric": "rpm", "limit": 100},
			map[string]any{"metric": "tpm_total", "limit": 1000},
			map[string]any{"metric": "concurrent", "limit": 10},
		}},
		RPMUsed:          40,
		RPMReserved:      10,
		TPMUsed:          900,
		ConcurrentActive: 3,
		QueuedWaiting:    2,
	})
	if !candidate.LoadLimited {
		t.Fatal("expected load to be limited when rate limit rules exist")
	}
	// rpm: (40 used + 10 reserved) / 100 = 0.5
	if candidate.LoadMetrics.RPMRatio != 0.5 {
		t.Fatalf("expected rpm ratio 0.5, got %v", candidate.LoadMetrics.RPMRatio)
	}
	// tpm: 900 / 1000 = 0.9
	if candidate.LoadMetrics.TPMRatio != 0.9 {
		t.Fatalf("expected tpm ratio 0.9, got %v", candidate.LoadMetrics.TPMRatio)
	}
	// concurrent: (3 active + 2 queued) / 10 = 0.5
	if candidate.LoadMetrics.ConcurrentRatio != 0.5 {
		t.Fatalf("expected concurrent ratio 0.5, got %v", candidate.LoadMetrics.ConcurrentRatio)
	}
	// Overall load is the max of the three ratios: 0.9 from tpm.
	if candidate.LoadRatio != 0.9 {
		t.Fatalf("expected max load ratio 0.9, got %v", candidate.LoadRatio)
	}
}
// TestRuntimeCandidateSortingAvoidsFullCandidatesButKeepsFallback checks that
// sorting demotes a saturated candidate behind an available one even when the
// saturated candidate has a better platform priority, and that the demoted
// candidate is flagged LoadAvoided (kept as fallback) rather than dropped.
func TestRuntimeCandidateSortingAvoidsFullCandidatesButKeepsFallback(t *testing.T) {
	candidates := []RuntimeModelCandidate{
		{
			PlatformID:       "high-priority-full",
			PlatformPriority: 1,
			LoadLimited:      true,
			LoadRatio:        1.2, // saturated: ratio >= 1 with limits configured
		},
		{
			PlatformID:       "lower-priority-available",
			PlatformPriority: 50,
			LoadLimited:      true,
			LoadRatio:        0.2, // healthy despite worse priority
		},
	}
	sortRuntimeModelCandidates(candidates)
	if candidates[0].PlatformID != "lower-priority-available" {
		t.Fatalf("expected non-full candidate to be tried first, got %+v", candidates)
	}
	if candidates[1].PlatformID != "high-priority-full" || !candidates[1].LoadAvoided {
		t.Fatalf("expected full high-priority candidate to remain as avoided fallback, got %+v", candidates)
	}
}

View File

@ -137,6 +137,29 @@ type RuntimeModelCandidate struct {
DegradePolicy map[string]any
ClientID string
QueueKey string
LoadRatio float64
LoadLimited bool
LoadAvoided bool
LoadMetrics RuntimeCandidateLoadMetrics
RunningCount float64
WaitingCount float64
LastAssignedUnix float64
}
// RuntimeCandidateLoadMetrics is the per-candidate load snapshot surfaced in
// attempt metrics: current usage, configured limit, and current/limit ratio
// for each tracked metric. A ratio of 0 means the metric has no limit.
type RuntimeCandidateLoadMetrics struct {
	RPMCurrent        float64 // used + reserved requests in the current window
	RPMLimit          float64
	RPMRatio          float64
	TPMCurrent        float64 // used + reserved tokens across tpm* counters
	TPMLimit          float64
	TPMRatio          float64
	ConcurrentCurrent float64 // active leases + queued waiting tasks
	ConcurrentLimit   float64
	ConcurrentRatio   float64
	QueuedCount       float64 // queued async tasks targeting this candidate
	StateRunningCount float64 // mirrored from runtime_client_states
	StateWaitingCount float64 // mirrored from runtime_client_states
	StateLimiterRatio float64 // mirrored from runtime_client_states
}
type RateLimitReservation struct {