diff --git a/apps/api/internal/httpapi/core_flow_integration_test.go b/apps/api/internal/httpapi/core_flow_integration_test.go index e5198f4..1e2cef7 100644 --- a/apps/api/internal/httpapi/core_flow_integration_test.go +++ b/apps/api/internal/httpapi/core_flow_integration_test.go @@ -1193,6 +1193,8 @@ WHERE m.platform_id = $1::uuid t.Fatalf("failover events should include retrying event status=%d body=%s", resp.StatusCode, string(body)) } + assertLoadAvoidanceSimulatedRetryChain(t, ctx, testPool, server.URL, loginResponse.AccessToken, apiKeyResponse.Secret, suffixText) + var callbackRows int if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_task_callback_outbox WHERE task_id = $1::uuid`, taskResponse.Task.ID).Scan(&callbackRows); err != nil { t.Fatalf("read callback outbox: %v", err) @@ -1396,6 +1398,212 @@ func waitForTaskStatus(t *testing.T, baseURL string, token string, taskID string return detail } +func assertLoadAvoidanceSimulatedRetryChain(t *testing.T, ctx context.Context, testPool *pgxpool.Pool, baseURL string, adminToken string, runtimeToken string, suffixText string) { + t.Helper() + model := "load-avoidance-smoke-" + suffixText + type scenario struct { + keySuffix string + name string + failure string + priority int + full bool + } + scenarios := []scenario{ + {keySuffix: "hard-stop", name: "Load Avoidance Hard Stop", failure: "fatal_failure", priority: 20}, + {keySuffix: "retryable", name: "Load Avoidance Retryable", failure: "retryable_failure", priority: 30}, + {keySuffix: "full-rate-limit", name: "Load Avoidance Full Rate Limit", failure: "rate_limit", priority: 1, full: true}, + {keySuffix: "full-overloaded", name: "Load Avoidance Full Overloaded", failure: "overloaded", priority: 2, full: true}, + {keySuffix: "full-fatal", name: "Load Avoidance Full Fatal", failure: "fatal_failure", priority: 3, full: true}, + } + for _, item := range scenarios { + var platform struct { + ID string `json:"id"` + PlatformKey string `json:"platformKey"` 
+ Name string `json:"name"` + } + credentials := map[string]any{"mode": "simulation"} + if item.failure != "" { + credentials["simulationFailure"] = item.failure + } + doJSON(t, baseURL, http.MethodPost, "/api/admin/platforms", adminToken, map[string]any{ + "provider": "openai", + "platformKey": "openai-load-" + item.keySuffix + "-" + suffixText, + "name": item.name, + "baseUrl": "https://api.openai.com/v1", + "authType": "bearer", + "credentials": credentials, + "priority": item.priority, + }, http.StatusCreated, &platform) + if platform.ID == "" || platform.PlatformKey == "" { + t.Fatalf("load avoidance platform was not created: %+v", platform) + } + + var platformModel struct { + ID string `json:"id"` + } + payload := map[string]any{ + "canonicalModelKey": "openai:gpt-4o-mini", + "modelName": model, + "providerModelName": model, + "modelAlias": model, + "modelType": []string{"text_generate"}, + "displayName": item.name, + "retryPolicy": map[string]any{"enabled": true, "maxAttempts": 1}, + } + if item.full { + payload["rateLimitPolicy"] = map[string]any{ + "rules": []map[string]any{ + {"metric": "rpm", "limit": 10, "windowSeconds": 60}, + {"metric": "tpm_total", "limit": 200, "windowSeconds": 60}, + {"metric": "concurrent", "limit": 1, "leaseTtlSeconds": 120}, + }, + } + } + doJSON(t, baseURL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", adminToken, payload, http.StatusCreated, &platformModel) + if platformModel.ID == "" { + t.Fatalf("load avoidance platform model was not created for %s: %+v", item.name, platformModel) + } + if item.full { + seedQueuedConcurrencyLoad(t, ctx, testPool, platform.PlatformKey, model, platformModel.ID) + } + } + + var taskResponse struct { + Task struct { + ID string `json:"id"` + Status string `json:"status"` + ErrorCode string `json:"errorCode"` + } `json:"task"` + } + doJSON(t, baseURL, http.MethodPost, "/api/v1/chat/completions", runtimeToken, map[string]any{ + "model": model, + "runMode": "simulation", + 
"simulation": true, + "simulationDurationMs": 5, + "messages": []map[string]any{{"role": "user", "content": "load avoidance retry chain"}}, + }, http.StatusAccepted, &taskResponse) + if taskResponse.Task.ID == "" || taskResponse.Task.Status != "failed" || taskResponse.Task.ErrorCode != "bad_request" { + t.Fatalf("load avoidance task should only fail after avoided clients are retried, got %+v", taskResponse.Task) + } + + var detail struct { + ID string `json:"id"` + Status string `json:"status"` + ErrorCode string `json:"errorCode"` + Attempts []struct { + AttemptNo int `json:"attemptNo"` + PlatformName string `json:"platformName"` + Status string `json:"status"` + Retryable bool `json:"retryable"` + ErrorCode string `json:"errorCode"` + Metrics map[string]any `json:"metrics"` + } `json:"attempts"` + Metrics map[string]any `json:"metrics"` + } + doJSON(t, baseURL, http.MethodGet, "/api/v1/tasks/"+taskResponse.Task.ID, runtimeToken, nil, http.StatusOK, &detail) + if detail.Status != "failed" || len(detail.Attempts) != len(scenarios) { + t.Fatalf("load avoidance detail should expose every attempted client, got status=%s attempts=%+v", detail.Status, detail.Attempts) + } + expected := []struct { + name string + code string + retryable bool + avoided bool + }{ + {name: "Load Avoidance Hard Stop", code: "bad_request"}, + {name: "Load Avoidance Retryable", code: "server_error", retryable: true}, + {name: "Load Avoidance Full Rate Limit", code: "rate_limit", retryable: true, avoided: true}, + {name: "Load Avoidance Full Overloaded", code: "overloaded", retryable: true, avoided: true}, + {name: "Load Avoidance Full Fatal", code: "bad_request", avoided: true}, + } + attemptSummary := make([]string, 0, len(detail.Attempts)) + for index, want := range expected { + got := detail.Attempts[index] + attemptSummary = append(attemptSummary, got.PlatformName+":"+got.Status+":"+got.ErrorCode) + if got.AttemptNo != index+1 || got.PlatformName != want.name || got.Status != "failed" || 
got.ErrorCode != want.code || got.Retryable != want.retryable { + t.Fatalf("unexpected load avoidance attempt %d: got %+v want %+v", index+1, got, want) + } + if boolFromTestMap(got.Metrics, "loadAvoided") != want.avoided { + t.Fatalf("loadAvoided mismatch for %s metrics=%+v", got.PlatformName, got.Metrics) + } + if !want.avoided && floatFromTestAny(got.Metrics["loadRatio"]) != 0 { + t.Fatalf("non-full candidate should not carry load pressure, got %s metrics=%+v", got.PlatformName, got.Metrics) + } + if want.avoided { + if ratio := floatFromTestAny(got.Metrics["loadRatio"]); ratio < 1 { + t.Fatalf("avoided candidate should expose full load ratio, got %s ratio=%v metrics=%+v", got.PlatformName, ratio, got.Metrics) + } + loadMetrics, _ := got.Metrics["loadMetrics"].(map[string]any) + concurrent, _ := loadMetrics["concurrent"].(map[string]any) + if ratio := floatFromTestAny(concurrent["ratio"]); ratio < 1 { + t.Fatalf("avoided candidate should expose concurrent saturation, got %s loadMetrics=%+v", got.PlatformName, loadMetrics) + } + if queued := floatFromTestAny(loadMetrics["queued"]); queued < 1 { + t.Fatalf("avoided candidate should expose queued waiting load, got %s loadMetrics=%+v", got.PlatformName, loadMetrics) + } + } + } + if !attemptTraceHasReason(detail.Attempts[0].Metrics, "load_avoidance_fallback") { + t.Fatalf("first hard-stop attempt should continue because avoided candidates remain, metrics=%+v", detail.Attempts[0].Metrics) + } + if summary, ok := detail.Metrics["attempts"].([]any); !ok || len(summary) != len(scenarios) { + t.Fatalf("final task metrics should preserve load-avoidance attempt summary, got %+v", detail.Metrics) + } + t.Logf("load avoidance retry chain: %s", strings.Join(attemptSummary, " -> ")) +} + +func seedQueuedConcurrencyLoad(t *testing.T, ctx context.Context, testPool *pgxpool.Pool, platformKey string, model string, platformModelID string) { + t.Helper() + queueKey := platformKey + ":text_generate:" + model + if _, err := 
testPool.Exec(ctx, ` +INSERT INTO gateway_tasks ( + kind, run_mode, user_id, model, requested_model, model_type, + request, normalized_request, status, queue_key, priority, async_mode, + next_run_at, result, billings +) +VALUES ( + 'chat.completions', 'simulation', $1, $2, $2, 'text_generate', + '{}'::jsonb, '{}'::jsonb, 'queued', $3, 999, true, + now() + interval '1 hour', '{}'::jsonb, '[]'::jsonb +)`, "load-avoidance-seed-"+platformModelID, model, queueKey); err != nil { + t.Fatalf("seed queued load for %s: %v", queueKey, err) + } +} + +func boolFromTestMap(values map[string]any, key string) bool { + value, _ := values[key].(bool) + return value +} + +func floatFromTestAny(value any) float64 { + switch typed := value.(type) { + case float64: + return typed + case float32: + return float64(typed) + case int: + return float64(typed) + case int64: + return float64(typed) + case json.Number: + out, _ := typed.Float64() + return out + default: + return 0 + } +} + +func attemptTraceHasReason(metrics map[string]any, reason string) bool { + trace, _ := metrics["trace"].([]any) + for _, raw := range trace { + item, _ := raw.(map[string]any) + if item["reason"] == reason { + return true + } + } + return false +} + func waitForRateLimitWindowHead(t *testing.T, windowSeconds int) { t.Helper() if windowSeconds <= 0 { diff --git a/apps/api/internal/runner/recording.go b/apps/api/internal/runner/recording.go index 502ee74..cf6078c 100644 --- a/apps/api/internal/runner/recording.go +++ b/apps/api/internal/runner/recording.go @@ -112,8 +112,30 @@ func attemptMetrics(candidate store.RuntimeModelCandidate, attemptNo int, simula "platformModelId": candidate.PlatformModelID, "clientId": candidate.ClientID, "queueKey": candidate.QueueKey, + "loadRatio": candidate.LoadRatio, + "loadAvoided": candidate.LoadAvoided, "simulated": simulated, } + if candidate.LoadLimited { + metrics["loadMetrics"] = map[string]any{ + "rpm": map[string]any{ + "current": candidate.LoadMetrics.RPMCurrent, + 
"limit": candidate.LoadMetrics.RPMLimit, + "ratio": candidate.LoadMetrics.RPMRatio, + }, + "tpm": map[string]any{ + "current": candidate.LoadMetrics.TPMCurrent, + "limit": candidate.LoadMetrics.TPMLimit, + "ratio": candidate.LoadMetrics.TPMRatio, + }, + "concurrent": map[string]any{ + "current": candidate.LoadMetrics.ConcurrentCurrent, + "limit": candidate.LoadMetrics.ConcurrentLimit, + "ratio": candidate.LoadMetrics.ConcurrentRatio, + }, + "queued": candidate.LoadMetrics.QueuedCount, + } + } if attemptNo > 0 { metrics["attempt"] = attemptNo } diff --git a/apps/api/internal/runner/retry_decision_test.go b/apps/api/internal/runner/retry_decision_test.go index 413a713..25d3981 100644 --- a/apps/api/internal/runner/retry_decision_test.go +++ b/apps/api/internal/runner/retry_decision_test.go @@ -46,6 +46,24 @@ func TestFailoverTimeBudgetExceeded(t *testing.T) { } } +func TestLoadAvoidanceFallbackContinuesToAvoidedCandidate(t *testing.T) { + candidates := []store.RuntimeModelCandidate{ + {PlatformID: "available-candidate"}, + {PlatformID: "avoided-full-candidate", LoadAvoided: true}, + } + + if !hasLoadAvoidanceFallback(candidates, 0, 99) { + t.Fatal("expected non-avoided candidate to fall back to later avoided candidate") + } + if hasLoadAvoidanceFallback(candidates, 1, 99) { + t.Fatal("avoided candidate should not force another load-avoidance fallback") + } + decision := loadAvoidanceFallbackDecision(&clients.ClientError{Code: "bad_request", StatusCode: 400, Retryable: false}) + if !decision.Retry || decision.Reason != "load_avoidance_fallback" || decision.Action != "next" { + t.Fatalf("expected active load avoidance fallback to force next candidate, got %+v", decision) + } +} + func TestFailoverHardStopBeatsModelOverride(t *testing.T) { runnerPolicy := store.RunnerPolicy{ Status: "active", diff --git a/apps/api/internal/runner/service.go b/apps/api/internal/runner/service.go index ae0d2b3..01cf582 100644 --- a/apps/api/internal/runner/service.go +++ 
b/apps/api/internal/runner/service.go @@ -290,6 +290,9 @@ candidatesLoop: break } decision := failoverDecisionForCandidate(runnerPolicy, candidate, candidateErr) + if !decision.Retry && hasLoadAvoidanceFallback(candidates, index, maxPlatforms) { + decision = loadAvoidanceFallbackDecision(candidateErr) + } s.recordAttemptTrace(ctx, task.ID, attemptNo, failoverTraceEntry(decision)) if !decision.Retry { break @@ -792,6 +795,37 @@ func failoverTimeBudgetExceeded(start time.Time, maxDuration time.Duration) bool return maxDuration > 0 && time.Since(start) >= maxDuration } +func hasLoadAvoidanceFallback(candidates []store.RuntimeModelCandidate, index int, maxPlatforms int) bool { + if index < 0 || index >= len(candidates) || candidates[index].LoadAvoided { + return false + } + limit := len(candidates) + if maxPlatforms > 0 && maxPlatforms < limit { + limit = maxPlatforms + } + for next := index + 1; next < limit; next++ { + if candidates[next].LoadAvoided { + return true + } + } + return false +} + +func loadAvoidanceFallbackDecision(err error) failoverDecision { + return failoverDecision{ + Retry: true, + Action: "next", + Reason: "load_avoidance_fallback", + Match: policyRuleMatch{ + Source: "runtime_client_load", + Policy: "loadAvoidance", + Rule: "fallback", + Value: "loadRatio>=1", + }, + Info: failureInfoFromError(err), + } +} + func normalizeRequest(kind string, body map[string]any) map[string]any { out := cloneMap(body) if kind == "responses" && out["messages"] == nil && out["input"] != nil { diff --git a/apps/api/internal/store/candidates.go b/apps/api/internal/store/candidates.go index a01409b..4e33c9e 100644 --- a/apps/api/internal/store/candidates.go +++ b/apps/api/internal/store/candidates.go @@ -3,6 +3,7 @@ package store import ( "context" "fmt" + "sort" "strings" "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" @@ -23,10 +24,18 @@ SELECT p.id::text, p.platform_key, p.name, p.provider, COALESCE(b.base_billing_config, '{}'::jsonb), 
m.billing_config, m.billing_config_override, m.pricing_mode, COALESCE(m.discount_factor, 0)::float8, COALESCE(m.pricing_rule_set_id::text, ''), COALESCE(b.pricing_rule_set_id::text, ''), - m.permission_config, m.retry_policy, m.rate_limit_policy, COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')), - COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb), - COALESCE(rp.retry_policy, '{}'::jsonb), COALESCE(rp.rate_limit_policy, '{}'::jsonb), - COALESCE(rp.auto_disable_policy, '{}'::jsonb), COALESCE(rp.degrade_policy, '{}'::jsonb) + m.permission_config, m.retry_policy, m.rate_limit_policy, COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')), + COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb), + COALESCE(rp.retry_policy, '{}'::jsonb), COALESCE(rp.rate_limit_policy, '{}'::jsonb), + COALESCE(rp.auto_disable_policy, '{}'::jsonb), COALESCE(rp.degrade_policy, '{}'::jsonb), + COALESCE(con.active, 0)::float8, + COALESCE(queued.waiting, 0)::float8, + COALESCE(rpm.used_value, 0)::float8, COALESCE(rpm.reserved_value, 0)::float8, + COALESCE(tpm.used_value, 0)::float8, COALESCE(tpm.reserved_value, 0)::float8, + COALESCE(s.running_count, 0)::float8, + COALESCE(s.waiting_count, 0)::float8, + COALESCE(s.limiter_ratio, 0)::float8, + COALESCE(EXTRACT(EPOCH FROM s.last_assigned_at), 0)::float8 FROM platform_models m JOIN integration_platforms p ON p.id = m.platform_id LEFT JOIN model_catalog_providers cp ON cp.provider_key = p.provider OR cp.provider_code = p.provider @@ -34,6 +43,57 @@ LEFT JOIN base_model_catalog b ON b.id = m.base_model_id LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id) LEFT JOIN runtime_client_states s ON s.client_id = p.platform_key || ':' || $2::text || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name) +LEFT JOIN ( + SELECT scope_key, 
SUM(lease_value) AS active + FROM gateway_concurrency_leases + WHERE scope_type = 'platform_model' + AND released_at IS NULL + AND expires_at > now() + GROUP BY scope_key + ) con ON con.scope_key = m.id::text +LEFT JOIN ( + SELECT queued_sources.platform_model_id, COUNT(DISTINCT queued_sources.task_id) AS waiting + FROM ( + SELECT t.id::text AS task_id, qm.id::text AS platform_model_id + FROM gateway_tasks t + JOIN integration_platforms qp ON TRUE + JOIN platform_models qm ON qm.platform_id = qp.id + WHERE t.async_mode = true + AND t.status = 'queued' + AND NULLIF(t.model_type, '') IS NOT NULL + AND qm.model_type @> jsonb_build_array(t.model_type) + AND t.queue_key = qp.platform_key || ':' || t.model_type || ':' || COALESCE(NULLIF(qm.provider_model_name, ''), qm.model_name) + AND NOT EXISTS (SELECT 1 FROM gateway_task_attempts existing_attempt WHERE existing_attempt.task_id = t.id) + UNION ALL + SELECT latest.task_id::text AS task_id, latest.platform_model_id + FROM ( + SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id + FROM gateway_tasks t + JOIN gateway_task_attempts a ON a.task_id = t.id + WHERE t.async_mode = true + AND t.status = 'queued' + AND a.platform_model_id IS NOT NULL + ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC + ) latest + ) queued_sources + GROUP BY queued_sources.platform_model_id + ) queued ON queued.platform_model_id = m.id::text +LEFT JOIN ( + SELECT DISTINCT ON (scope_key) scope_key, used_value, reserved_value + FROM gateway_rate_limit_counters + WHERE scope_type = 'platform_model' + AND metric = 'rpm' + AND reset_at > now() + ORDER BY scope_key, window_start DESC + ) rpm ON rpm.scope_key = m.id::text +LEFT JOIN ( + SELECT scope_key, SUM(used_value) AS used_value, SUM(reserved_value) AS reserved_value + FROM gateway_rate_limit_counters + WHERE scope_type = 'platform_model' + AND metric LIKE 'tpm%' + AND reset_at > now() + GROUP BY scope_key + ) tpm ON tpm.scope_key = m.id::text WHERE 
p.status = 'enabled' AND p.deleted_at IS NULL AND m.enabled = true @@ -52,7 +112,6 @@ WHERE p.status = 'enabled' ) ) ORDER BY effective_priority ASC, - COALESCE(s.limiter_ratio, 0) ASC, COALESCE(s.running_count, 0) ASC, COALESCE(s.waiting_count, 0) ASC, COALESCE(s.last_assigned_at, to_timestamp(0)) ASC, @@ -82,6 +141,16 @@ ORDER BY effective_priority ASC, var runtimeRateLimitPolicy []byte var autoDisablePolicy []byte var degradePolicy []byte + var concurrentActive float64 + var queuedWaiting float64 + var rpmUsed float64 + var rpmReserved float64 + var tpmUsed float64 + var tpmReserved float64 + var stateRunningCount float64 + var stateWaitingCount float64 + var stateLimiterRatio float64 + var lastAssignedUnix float64 if err := rows.Scan( &item.PlatformID, &item.PlatformKey, @@ -124,6 +193,16 @@ ORDER BY effective_priority ASC, &runtimeRateLimitPolicy, &autoDisablePolicy, &degradePolicy, + &concurrentActive, + &queuedWaiting, + &rpmUsed, + &rpmReserved, + &tpmUsed, + &tpmReserved, + &stateRunningCount, + &stateWaitingCount, + &stateLimiterRatio, + &lastAssignedUnix, ); err != nil { return nil, err } @@ -147,6 +226,21 @@ ORDER BY effective_priority ASC, upstreamModelName := firstNonEmpty(item.ProviderModelName, item.ModelName) item.ClientID = fmt.Sprintf("%s:%s:%s", item.PlatformKey, item.ModelType, upstreamModelName) item.QueueKey = item.ClientID + item.RunningCount = stateRunningCount + item.WaitingCount = maxFloat(queuedWaiting, stateWaitingCount) + item.LastAssignedUnix = lastAssignedUnix + applyRuntimeCandidateLoad(&item, runtimeCandidateLoadInput{ + Policy: effectiveModelRateLimitPolicy(item.PlatformRateLimitPolicy, item.RuntimeRateLimitPolicy, item.RuntimePolicyOverride, item.ModelRateLimitPolicy), + ConcurrentActive: concurrentActive, + QueuedWaiting: queuedWaiting, + RPMUsed: rpmUsed, + RPMReserved: rpmReserved, + TPMUsed: tpmUsed, + TPMReserved: tpmReserved, + StateRunningCount: stateRunningCount, + StateWaitingCount: stateWaitingCount, + StateLimiterRatio: 
stateLimiterRatio, + }) items = append(items, item) } if err := rows.Err(); err != nil { @@ -167,9 +261,102 @@ ORDER BY effective_priority ASC, if len(items) == 0 { return nil, ErrNoModelCandidate } + sortRuntimeModelCandidates(items) return items, nil } +type runtimeCandidateLoadInput struct { + Policy map[string]any + ConcurrentActive float64 + QueuedWaiting float64 + RPMUsed float64 + RPMReserved float64 + TPMUsed float64 + TPMReserved float64 + StateRunningCount float64 + StateWaitingCount float64 + StateLimiterRatio float64 +} + +func applyRuntimeCandidateLoad(candidate *RuntimeModelCandidate, input runtimeCandidateLoadInput) { + rpmLimit := rateLimitForMetric(input.Policy, "rpm") + tpmLimitValue := tpmLimit(input.Policy) + concurrentLimit := rateLimitForMetric(input.Policy, "concurrent") + rpmCurrent := input.RPMUsed + input.RPMReserved + tpmCurrent := input.TPMUsed + input.TPMReserved + concurrentCurrent := input.ConcurrentActive + input.QueuedWaiting + metrics := RuntimeCandidateLoadMetrics{ + RPMCurrent: rpmCurrent, + RPMLimit: rpmLimit, + RPMRatio: ratioIfLimited(rpmCurrent, rpmLimit), + TPMCurrent: tpmCurrent, + TPMLimit: tpmLimitValue, + TPMRatio: ratioIfLimited(tpmCurrent, tpmLimitValue), + ConcurrentCurrent: concurrentCurrent, + ConcurrentLimit: concurrentLimit, + ConcurrentRatio: ratioIfLimited(concurrentCurrent, concurrentLimit), + QueuedCount: input.QueuedWaiting, + StateRunningCount: input.StateRunningCount, + StateWaitingCount: input.StateWaitingCount, + StateLimiterRatio: input.StateLimiterRatio, + } + candidate.LoadMetrics = metrics + candidate.LoadLimited = rpmLimit > 0 || tpmLimitValue > 0 || concurrentLimit > 0 + candidate.LoadRatio = maxFloat(metrics.RPMRatio, metrics.TPMRatio, metrics.ConcurrentRatio) +} + +func ratioIfLimited(current float64, limit float64) float64 { + if limit <= 0 { + return 0 + } + return current / limit +} + +func sortRuntimeModelCandidates(items []RuntimeModelCandidate) { + hasFull := false + hasNonFull := false + 
for index := range items { + items[index].LoadAvoided = false + if runtimeCandidateFull(items[index]) { + hasFull = true + } else { + hasNonFull = true + } + } + if hasFull && hasNonFull { + for index := range items { + items[index].LoadAvoided = runtimeCandidateFull(items[index]) + } + } + sort.SliceStable(items, func(i, j int) bool { + aFull := runtimeCandidateFull(items[i]) + bFull := runtimeCandidateFull(items[j]) + if aFull != bFull { + return !aFull + } + if items[i].PlatformPriority != items[j].PlatformPriority { + return items[i].PlatformPriority < items[j].PlatformPriority + } + if items[i].LoadRatio != items[j].LoadRatio { + return items[i].LoadRatio < items[j].LoadRatio + } + if items[i].RunningCount != items[j].RunningCount { + return items[i].RunningCount < items[j].RunningCount + } + if items[i].WaitingCount != items[j].WaitingCount { + return items[i].WaitingCount < items[j].WaitingCount + } + if items[i].LastAssignedUnix != items[j].LastAssignedUnix { + return items[i].LastAssignedUnix < items[j].LastAssignedUnix + } + return false + }) +} + +func runtimeCandidateFull(candidate RuntimeModelCandidate) bool { + return candidate.LoadLimited && candidate.LoadRatio >= 1 +} + func (s *Store) modelCandidateCooldownError(ctx context.Context, model string, modelType string) (error, error) { rows, err := s.pool.Query(ctx, ` SELECT p.name, diff --git a/apps/api/internal/store/candidates_test.go b/apps/api/internal/store/candidates_test.go new file mode 100644 index 0000000..18659bd --- /dev/null +++ b/apps/api/internal/store/candidates_test.go @@ -0,0 +1,61 @@ +package store + +import "testing" + +func TestRuntimeCandidateLoadUsesMaxLimitedMetric(t *testing.T) { + candidate := RuntimeModelCandidate{} + applyRuntimeCandidateLoad(&candidate, runtimeCandidateLoadInput{ + Policy: map[string]any{"rules": []any{ + map[string]any{"metric": "rpm", "limit": 100}, + map[string]any{"metric": "tpm_total", "limit": 1000}, + map[string]any{"metric": "concurrent", "limit": 
10}, + }}, + RPMUsed: 40, + RPMReserved: 10, + TPMUsed: 900, + ConcurrentActive: 3, + QueuedWaiting: 2, + }) + + if !candidate.LoadLimited { + t.Fatal("expected load to be limited when rate limit rules exist") + } + if candidate.LoadMetrics.RPMRatio != 0.5 { + t.Fatalf("expected rpm ratio 0.5, got %v", candidate.LoadMetrics.RPMRatio) + } + if candidate.LoadMetrics.TPMRatio != 0.9 { + t.Fatalf("expected tpm ratio 0.9, got %v", candidate.LoadMetrics.TPMRatio) + } + if candidate.LoadMetrics.ConcurrentRatio != 0.5 { + t.Fatalf("expected concurrent ratio 0.5, got %v", candidate.LoadMetrics.ConcurrentRatio) + } + if candidate.LoadRatio != 0.9 { + t.Fatalf("expected max load ratio 0.9, got %v", candidate.LoadRatio) + } +} + +func TestRuntimeCandidateSortingAvoidsFullCandidatesButKeepsFallback(t *testing.T) { + candidates := []RuntimeModelCandidate{ + { + PlatformID: "high-priority-full", + PlatformPriority: 1, + LoadLimited: true, + LoadRatio: 1.2, + }, + { + PlatformID: "lower-priority-available", + PlatformPriority: 50, + LoadLimited: true, + LoadRatio: 0.2, + }, + } + + sortRuntimeModelCandidates(candidates) + + if candidates[0].PlatformID != "lower-priority-available" { + t.Fatalf("expected non-full candidate to be tried first, got %+v", candidates) + } + if candidates[1].PlatformID != "high-priority-full" || !candidates[1].LoadAvoided { + t.Fatalf("expected full high-priority candidate to remain as avoided fallback, got %+v", candidates) + } +} diff --git a/apps/api/internal/store/runtime_types.go b/apps/api/internal/store/runtime_types.go index fd80fab..d2784a7 100644 --- a/apps/api/internal/store/runtime_types.go +++ b/apps/api/internal/store/runtime_types.go @@ -137,6 +137,29 @@ type RuntimeModelCandidate struct { DegradePolicy map[string]any ClientID string QueueKey string + LoadRatio float64 + LoadLimited bool + LoadAvoided bool + LoadMetrics RuntimeCandidateLoadMetrics + RunningCount float64 + WaitingCount float64 + LastAssignedUnix float64 +} + +type 
RuntimeCandidateLoadMetrics struct { + RPMCurrent float64 + RPMLimit float64 + RPMRatio float64 + TPMCurrent float64 + TPMLimit float64 + TPMRatio float64 + ConcurrentCurrent float64 + ConcurrentLimit float64 + ConcurrentRatio float64 + QueuedCount float64 + StateRunningCount float64 + StateWaitingCount float64 + StateLimiterRatio float64 } type RateLimitReservation struct {