feat(api): add load-aware client fallback
parent c5cede2359
commit c2696e7bbe
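Runtime model candidates now carry per-metric load data (rpm, tpm, and concurrent usage plus queued waiting tasks). Candidates whose worst load ratio reaches 1 are flagged LoadAvoided and sorted behind available ones rather than filtered out, and the failover loop keeps falling back into avoided candidates even after an otherwise-fatal error, recording a load_avoidance_fallback trace entry and exposing loadRatio/loadMetrics in attempt metrics.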
@@ -1193,6 +1193,8 @@ WHERE m.platform_id = $1::uuid
        t.Fatalf("failover events should include retrying event status=%d body=%s", resp.StatusCode, string(body))
    }

    assertLoadAvoidanceSimulatedRetryChain(t, ctx, testPool, server.URL, loginResponse.AccessToken, apiKeyResponse.Secret, suffixText)

    var callbackRows int
    if err := testPool.QueryRow(ctx, `SELECT count(*) FROM gateway_task_callback_outbox WHERE task_id = $1::uuid`, taskResponse.Task.ID).Scan(&callbackRows); err != nil {
        t.Fatalf("read callback outbox: %v", err)
@@ -1396,6 +1398,212 @@ func waitForTaskStatus(t *testing.T, baseURL string, token string, taskID string
    return detail
}

func assertLoadAvoidanceSimulatedRetryChain(t *testing.T, ctx context.Context, testPool *pgxpool.Pool, baseURL string, adminToken string, runtimeToken string, suffixText string) {
    t.Helper()
    model := "load-avoidance-smoke-" + suffixText
    type scenario struct {
        keySuffix string
        name      string
        failure   string
        priority  int
        full      bool
    }
    // The three "full" scenarios carry the best (lowest) priorities so that only
    // load avoidance, not priority ordering, can push them behind the ordinary ones.
    scenarios := []scenario{
        {keySuffix: "hard-stop", name: "Load Avoidance Hard Stop", failure: "fatal_failure", priority: 20},
        {keySuffix: "retryable", name: "Load Avoidance Retryable", failure: "retryable_failure", priority: 30},
        {keySuffix: "full-rate-limit", name: "Load Avoidance Full Rate Limit", failure: "rate_limit", priority: 1, full: true},
        {keySuffix: "full-overloaded", name: "Load Avoidance Full Overloaded", failure: "overloaded", priority: 2, full: true},
        {keySuffix: "full-fatal", name: "Load Avoidance Full Fatal", failure: "fatal_failure", priority: 3, full: true},
    }
    for _, item := range scenarios {
        var platform struct {
            ID          string `json:"id"`
            PlatformKey string `json:"platformKey"`
            Name        string `json:"name"`
        }
        credentials := map[string]any{"mode": "simulation"}
        if item.failure != "" {
            credentials["simulationFailure"] = item.failure
        }
        doJSON(t, baseURL, http.MethodPost, "/api/admin/platforms", adminToken, map[string]any{
            "provider":    "openai",
            "platformKey": "openai-load-" + item.keySuffix + "-" + suffixText,
            "name":        item.name,
            "baseUrl":     "https://api.openai.com/v1",
            "authType":    "bearer",
            "credentials": credentials,
            "priority":    item.priority,
        }, http.StatusCreated, &platform)
        if platform.ID == "" || platform.PlatformKey == "" {
            t.Fatalf("load avoidance platform was not created: %+v", platform)
        }

        var platformModel struct {
            ID string `json:"id"`
        }
        payload := map[string]any{
            "canonicalModelKey": "openai:gpt-4o-mini",
            "modelName":         model,
            "providerModelName": model,
            "modelAlias":        model,
            "modelType":         []string{"text_generate"},
            "displayName":       item.name,
            "retryPolicy":       map[string]any{"enabled": true, "maxAttempts": 1},
        }
        if item.full {
            payload["rateLimitPolicy"] = map[string]any{
                "rules": []map[string]any{
                    {"metric": "rpm", "limit": 10, "windowSeconds": 60},
                    {"metric": "tpm_total", "limit": 200, "windowSeconds": 60},
                    {"metric": "concurrent", "limit": 1, "leaseTtlSeconds": 120},
                },
            }
        }
        doJSON(t, baseURL, http.MethodPost, "/api/admin/platforms/"+platform.ID+"/models", adminToken, payload, http.StatusCreated, &platformModel)
        if platformModel.ID == "" {
            t.Fatalf("load avoidance platform model was not created for %s: %+v", item.name, platformModel)
        }
        // Saturate the concurrent limit (limit 1) so "full" candidates become load-avoided.
        if item.full {
            seedQueuedConcurrencyLoad(t, ctx, testPool, platform.PlatformKey, model, platformModel.ID)
        }
    }

    var taskResponse struct {
        Task struct {
            ID        string `json:"id"`
            Status    string `json:"status"`
            ErrorCode string `json:"errorCode"`
        } `json:"task"`
    }
    doJSON(t, baseURL, http.MethodPost, "/api/v1/chat/completions", runtimeToken, map[string]any{
        "model":                model,
        "runMode":              "simulation",
        "simulation":           true,
        "simulationDurationMs": 5,
        "messages":             []map[string]any{{"role": "user", "content": "load avoidance retry chain"}},
    }, http.StatusAccepted, &taskResponse)
    if taskResponse.Task.ID == "" || taskResponse.Task.Status != "failed" || taskResponse.Task.ErrorCode != "bad_request" {
        t.Fatalf("load avoidance task should only fail after avoided clients are retried, got %+v", taskResponse.Task)
    }

    var detail struct {
        ID        string `json:"id"`
        Status    string `json:"status"`
        ErrorCode string `json:"errorCode"`
        Attempts  []struct {
            AttemptNo    int            `json:"attemptNo"`
            PlatformName string         `json:"platformName"`
            Status       string         `json:"status"`
            Retryable    bool           `json:"retryable"`
            ErrorCode    string         `json:"errorCode"`
            Metrics      map[string]any `json:"metrics"`
        } `json:"attempts"`
        Metrics map[string]any `json:"metrics"`
    }
    doJSON(t, baseURL, http.MethodGet, "/api/v1/tasks/"+taskResponse.Task.ID, runtimeToken, nil, http.StatusOK, &detail)
    if detail.Status != "failed" || len(detail.Attempts) != len(scenarios) {
        t.Fatalf("load avoidance detail should expose every attempted client, got status=%s attempts=%+v", detail.Status, detail.Attempts)
    }
    expected := []struct {
        name      string
        code      string
        retryable bool
        avoided   bool
    }{
        {name: "Load Avoidance Hard Stop", code: "bad_request"},
        {name: "Load Avoidance Retryable", code: "server_error", retryable: true},
        {name: "Load Avoidance Full Rate Limit", code: "rate_limit", retryable: true, avoided: true},
        {name: "Load Avoidance Full Overloaded", code: "overloaded", retryable: true, avoided: true},
        {name: "Load Avoidance Full Fatal", code: "bad_request", avoided: true},
    }
    attemptSummary := make([]string, 0, len(detail.Attempts))
    for index, want := range expected {
        got := detail.Attempts[index]
        attemptSummary = append(attemptSummary, got.PlatformName+":"+got.Status+":"+got.ErrorCode)
        if got.AttemptNo != index+1 || got.PlatformName != want.name || got.Status != "failed" || got.ErrorCode != want.code || got.Retryable != want.retryable {
            t.Fatalf("unexpected load avoidance attempt %d: got %+v want %+v", index+1, got, want)
        }
        if boolFromTestMap(got.Metrics, "loadAvoided") != want.avoided {
            t.Fatalf("loadAvoided mismatch for %s metrics=%+v", got.PlatformName, got.Metrics)
        }
        if !want.avoided && floatFromTestAny(got.Metrics["loadRatio"]) != 0 {
            t.Fatalf("non-full candidate should not carry load pressure, got %s metrics=%+v", got.PlatformName, got.Metrics)
        }
        if want.avoided {
            if ratio := floatFromTestAny(got.Metrics["loadRatio"]); ratio < 1 {
                t.Fatalf("avoided candidate should expose full load ratio, got %s ratio=%v metrics=%+v", got.PlatformName, ratio, got.Metrics)
            }
            loadMetrics, _ := got.Metrics["loadMetrics"].(map[string]any)
            concurrent, _ := loadMetrics["concurrent"].(map[string]any)
            if ratio := floatFromTestAny(concurrent["ratio"]); ratio < 1 {
                t.Fatalf("avoided candidate should expose concurrent saturation, got %s loadMetrics=%+v", got.PlatformName, loadMetrics)
            }
            if queued := floatFromTestAny(loadMetrics["queued"]); queued < 1 {
                t.Fatalf("avoided candidate should expose queued waiting load, got %s loadMetrics=%+v", got.PlatformName, loadMetrics)
            }
        }
    }
    if !attemptTraceHasReason(detail.Attempts[0].Metrics, "load_avoidance_fallback") {
        t.Fatalf("first hard-stop attempt should continue because avoided candidates remain, metrics=%+v", detail.Attempts[0].Metrics)
    }
    if summary, ok := detail.Metrics["attempts"].([]any); !ok || len(summary) != len(scenarios) {
        t.Fatalf("final task metrics should preserve load-avoidance attempt summary, got %+v", detail.Metrics)
    }
    t.Logf("load avoidance retry chain: %s", strings.Join(attemptSummary, " -> "))
}

func seedQueuedConcurrencyLoad(t *testing.T, ctx context.Context, testPool *pgxpool.Pool, platformKey string, model string, platformModelID string) {
    t.Helper()
    queueKey := platformKey + ":text_generate:" + model
    // Insert a queued async task whose next_run_at is an hour out, so it counts
    // as waiting load on this queue key without ever being picked up by a worker.
    if _, err := testPool.Exec(ctx, `
        INSERT INTO gateway_tasks (
            kind, run_mode, user_id, model, requested_model, model_type,
            request, normalized_request, status, queue_key, priority, async_mode,
            next_run_at, result, billings
        )
        VALUES (
            'chat.completions', 'simulation', $1, $2, $2, 'text_generate',
            '{}'::jsonb, '{}'::jsonb, 'queued', $3, 999, true,
            now() + interval '1 hour', '{}'::jsonb, '[]'::jsonb
        )`, "load-avoidance-seed-"+platformModelID, model, queueKey); err != nil {
        t.Fatalf("seed queued load for %s: %v", queueKey, err)
    }
}

func boolFromTestMap(values map[string]any, key string) bool {
    value, _ := values[key].(bool)
    return value
}

func floatFromTestAny(value any) float64 {
    switch typed := value.(type) {
    case float64:
        return typed
    case float32:
        return float64(typed)
    case int:
        return float64(typed)
    case int64:
        return float64(typed)
    case json.Number:
        out, _ := typed.Float64()
        return out
    default:
        return 0
    }
}

func attemptTraceHasReason(metrics map[string]any, reason string) bool {
    trace, _ := metrics["trace"].([]any)
    for _, raw := range trace {
        item, _ := raw.(map[string]any)
        if item["reason"] == reason {
            return true
        }
    }
    return false
}

func waitForRateLimitWindowHead(t *testing.T, windowSeconds int) {
    t.Helper()
    if windowSeconds <= 0 {
@@ -112,8 +112,30 @@ func attemptMetrics(candidate store.RuntimeModelCandidate, attemptNo int, simula
        "platformModelId": candidate.PlatformModelID,
        "clientId":        candidate.ClientID,
        "queueKey":        candidate.QueueKey,
        "loadRatio":       candidate.LoadRatio,
        "loadAvoided":     candidate.LoadAvoided,
        "simulated":       simulated,
    }
    // Expose the per-metric load snapshot whenever any rate-limit rule applies.
    if candidate.LoadLimited {
        metrics["loadMetrics"] = map[string]any{
            "rpm": map[string]any{
                "current": candidate.LoadMetrics.RPMCurrent,
                "limit":   candidate.LoadMetrics.RPMLimit,
                "ratio":   candidate.LoadMetrics.RPMRatio,
            },
            "tpm": map[string]any{
                "current": candidate.LoadMetrics.TPMCurrent,
                "limit":   candidate.LoadMetrics.TPMLimit,
                "ratio":   candidate.LoadMetrics.TPMRatio,
            },
            "concurrent": map[string]any{
                "current": candidate.LoadMetrics.ConcurrentCurrent,
                "limit":   candidate.LoadMetrics.ConcurrentLimit,
                "ratio":   candidate.LoadMetrics.ConcurrentRatio,
            },
            "queued": candidate.LoadMetrics.QueuedCount,
        }
    }
    if attemptNo > 0 {
        metrics["attempt"] = attemptNo
    }
@@ -46,6 +46,24 @@ func TestFailoverTimeBudgetExceeded(t *testing.T) {
    }
}

func TestLoadAvoidanceFallbackContinuesToAvoidedCandidate(t *testing.T) {
    candidates := []store.RuntimeModelCandidate{
        {PlatformID: "available-candidate"},
        {PlatformID: "avoided-full-candidate", LoadAvoided: true},
    }

    if !hasLoadAvoidanceFallback(candidates, 0, 99) {
        t.Fatal("expected non-avoided candidate to fall back to later avoided candidate")
    }
    if hasLoadAvoidanceFallback(candidates, 1, 99) {
        t.Fatal("avoided candidate should not force another load-avoidance fallback")
    }
    decision := loadAvoidanceFallbackDecision(&clients.ClientError{Code: "bad_request", StatusCode: 400, Retryable: false})
    if !decision.Retry || decision.Reason != "load_avoidance_fallback" || decision.Action != "next" {
        t.Fatalf("expected active load avoidance fallback to force next candidate, got %+v", decision)
    }
}

func TestFailoverHardStopBeatsModelOverride(t *testing.T) {
    runnerPolicy := store.RunnerPolicy{
        Status: "active",
@@ -290,6 +290,9 @@ candidatesLoop:
            break
        }
        decision := failoverDecisionForCandidate(runnerPolicy, candidate, candidateErr)
        // A non-retryable failure no longer ends the chain while avoided
        // (load-saturated) candidates still wait further down the list.
        if !decision.Retry && hasLoadAvoidanceFallback(candidates, index, maxPlatforms) {
            decision = loadAvoidanceFallbackDecision(candidateErr)
        }
        s.recordAttemptTrace(ctx, task.ID, attemptNo, failoverTraceEntry(decision))
        if !decision.Retry {
            break
@@ -792,6 +795,37 @@ func failoverTimeBudgetExceeded(start time.Time, maxDuration time.Duration) bool
    return maxDuration > 0 && time.Since(start) >= maxDuration
}

// hasLoadAvoidanceFallback reports whether a later candidate, within the
// maxPlatforms window, was deprioritized for load reasons; the current
// candidate must not itself be load-avoided.
func hasLoadAvoidanceFallback(candidates []store.RuntimeModelCandidate, index int, maxPlatforms int) bool {
    if index < 0 || index >= len(candidates) || candidates[index].LoadAvoided {
        return false
    }
    limit := len(candidates)
    if maxPlatforms > 0 && maxPlatforms < limit {
        limit = maxPlatforms
    }
    for next := index + 1; next < limit; next++ {
        if candidates[next].LoadAvoided {
            return true
        }
    }
    return false
}

func loadAvoidanceFallbackDecision(err error) failoverDecision {
    return failoverDecision{
        Retry:  true,
        Action: "next",
        Reason: "load_avoidance_fallback",
        Match: policyRuleMatch{
            Source: "runtime_client_load",
            Policy: "loadAvoidance",
            Rule:   "fallback",
            Value:  "loadRatio>=1",
        },
        Info: failureInfoFromError(err),
    }
}
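A minimal standalone sketch of that interplay, using simplified stand-in types (candidate and decision here are illustrative, not the real store.RuntimeModelCandidate or failoverDecision): a non-retryable error would normally end the chain, but while an avoided candidate remains further down the ordered list, the loop upgrades the decision to try the next candidate.

package main

import "fmt"

// candidate and decision are simplified stand-ins for the real types.
type candidate struct {
    id          string
    loadAvoided bool
}

type decision struct {
    retry  bool
    reason string
}

// laterAvoidedCandidateExists mirrors the core loop of hasLoadAvoidanceFallback.
func laterAvoidedCandidateExists(cands []candidate, index int) bool {
    if index < 0 || index >= len(cands) || cands[index].loadAvoided {
        return false
    }
    for _, c := range cands[index+1:] {
        if c.loadAvoided {
            return true
        }
    }
    return false
}

func main() {
    cands := []candidate{
        {id: "available"},
        {id: "full-but-kept", loadAvoided: true},
    }
    // A fatal error on the first candidate would normally stop the chain...
    d := decision{retry: false, reason: "bad_request"}
    // ...but an avoided candidate is still waiting, so the chain continues.
    if !d.retry && laterAvoidedCandidateExists(cands, 0) {
        d = decision{retry: true, reason: "load_avoidance_fallback"}
    }
    fmt.Printf("%+v\n", d) // {retry:true reason:load_avoidance_fallback}
}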

func normalizeRequest(kind string, body map[string]any) map[string]any {
    out := cloneMap(body)
    if kind == "responses" && out["messages"] == nil && out["input"] != nil {
@@ -3,6 +3,7 @@ package store

import (
    "context"
    "fmt"
    "sort"
    "strings"

    "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
@@ -23,10 +24,18 @@ SELECT p.id::text, p.platform_key, p.name, p.provider,
       COALESCE(b.base_billing_config, '{}'::jsonb), m.billing_config, m.billing_config_override,
       m.pricing_mode, COALESCE(m.discount_factor, 0)::float8, COALESCE(m.pricing_rule_set_id::text, ''),
       COALESCE(b.pricing_rule_set_id::text, ''),
       m.permission_config, m.retry_policy, m.rate_limit_policy, COALESCE(m.runtime_policy_set_id::text, COALESCE(b.runtime_policy_set_id::text, '')),
       COALESCE(NULLIF(m.runtime_policy_override, '{}'::jsonb), b.runtime_policy_override, '{}'::jsonb),
       COALESCE(rp.retry_policy, '{}'::jsonb), COALESCE(rp.rate_limit_policy, '{}'::jsonb),
       COALESCE(rp.auto_disable_policy, '{}'::jsonb), COALESCE(rp.degrade_policy, '{}'::jsonb),
       COALESCE(con.active, 0)::float8,
       COALESCE(queued.waiting, 0)::float8,
       COALESCE(rpm.used_value, 0)::float8, COALESCE(rpm.reserved_value, 0)::float8,
       COALESCE(tpm.used_value, 0)::float8, COALESCE(tpm.reserved_value, 0)::float8,
       COALESCE(s.running_count, 0)::float8,
       COALESCE(s.waiting_count, 0)::float8,
       COALESCE(s.limiter_ratio, 0)::float8,
       COALESCE(EXTRACT(EPOCH FROM s.last_assigned_at), 0)::float8
FROM platform_models m
JOIN integration_platforms p ON p.id = m.platform_id
LEFT JOIN model_catalog_providers cp ON cp.provider_key = p.provider OR cp.provider_code = p.provider
@@ -34,6 +43,57 @@ LEFT JOIN base_model_catalog b ON b.id = m.base_model_id
LEFT JOIN model_runtime_policy_sets rp ON rp.id = COALESCE(m.runtime_policy_set_id, b.runtime_policy_set_id)
LEFT JOIN runtime_client_states s
  ON s.client_id = p.platform_key || ':' || $2::text || ':' || COALESCE(NULLIF(m.provider_model_name, ''), m.model_name)
-- Active concurrency leases held against each platform model.
LEFT JOIN (
  SELECT scope_key, SUM(lease_value) AS active
  FROM gateway_concurrency_leases
  WHERE scope_type = 'platform_model'
    AND released_at IS NULL
    AND expires_at > now()
  GROUP BY scope_key
) con ON con.scope_key = m.id::text
-- Queued async tasks waiting on each platform model: tasks with no attempt yet
-- are matched through their queue_key; retried tasks through their latest attempt.
LEFT JOIN (
  SELECT queued_sources.platform_model_id, COUNT(DISTINCT queued_sources.task_id) AS waiting
  FROM (
    SELECT t.id::text AS task_id, qm.id::text AS platform_model_id
    FROM gateway_tasks t
    JOIN integration_platforms qp ON TRUE
    JOIN platform_models qm ON qm.platform_id = qp.id
    WHERE t.async_mode = true
      AND t.status = 'queued'
      AND NULLIF(t.model_type, '') IS NOT NULL
      AND qm.model_type @> jsonb_build_array(t.model_type)
      AND t.queue_key = qp.platform_key || ':' || t.model_type || ':' || COALESCE(NULLIF(qm.provider_model_name, ''), qm.model_name)
      AND NOT EXISTS (SELECT 1 FROM gateway_task_attempts existing_attempt WHERE existing_attempt.task_id = t.id)
    UNION ALL
    SELECT latest.task_id::text AS task_id, latest.platform_model_id
    FROM (
      SELECT DISTINCT ON (a.task_id) a.task_id, a.platform_model_id::text AS platform_model_id
      FROM gateway_tasks t
      JOIN gateway_task_attempts a ON a.task_id = t.id
      WHERE t.async_mode = true
        AND t.status = 'queued'
        AND a.platform_model_id IS NOT NULL
      ORDER BY a.task_id, a.attempt_no DESC, a.started_at DESC
    ) latest
  ) queued_sources
  GROUP BY queued_sources.platform_model_id
) queued ON queued.platform_model_id = m.id::text
-- Most recent unexpired rpm window, and tpm counters summed across tpm metrics.
LEFT JOIN (
  SELECT DISTINCT ON (scope_key) scope_key, used_value, reserved_value
  FROM gateway_rate_limit_counters
  WHERE scope_type = 'platform_model'
    AND metric = 'rpm'
    AND reset_at > now()
  ORDER BY scope_key, window_start DESC
) rpm ON rpm.scope_key = m.id::text
LEFT JOIN (
  SELECT scope_key, SUM(used_value) AS used_value, SUM(reserved_value) AS reserved_value
  FROM gateway_rate_limit_counters
  WHERE scope_type = 'platform_model'
    AND metric LIKE 'tpm%'
    AND reset_at > now()
  GROUP BY scope_key
) tpm ON tpm.scope_key = m.id::text
WHERE p.status = 'enabled'
  AND p.deleted_at IS NULL
  AND m.enabled = true
@@ -52,7 +112,6 @@ WHERE p.status = 'enabled'
    )
  )
ORDER BY effective_priority ASC,
         COALESCE(s.limiter_ratio, 0) ASC,
         COALESCE(s.running_count, 0) ASC,
         COALESCE(s.waiting_count, 0) ASC,
         COALESCE(s.last_assigned_at, to_timestamp(0)) ASC,
@@ -82,6 +141,16 @@ ORDER BY effective_priority ASC,
        var runtimeRateLimitPolicy []byte
        var autoDisablePolicy []byte
        var degradePolicy []byte
        var concurrentActive float64
        var queuedWaiting float64
        var rpmUsed float64
        var rpmReserved float64
        var tpmUsed float64
        var tpmReserved float64
        var stateRunningCount float64
        var stateWaitingCount float64
        var stateLimiterRatio float64
        var lastAssignedUnix float64
        if err := rows.Scan(
            &item.PlatformID,
            &item.PlatformKey,
@@ -124,6 +193,16 @@ ORDER BY effective_priority ASC,
            &runtimeRateLimitPolicy,
            &autoDisablePolicy,
            &degradePolicy,
            &concurrentActive,
            &queuedWaiting,
            &rpmUsed,
            &rpmReserved,
            &tpmUsed,
            &tpmReserved,
            &stateRunningCount,
            &stateWaitingCount,
            &stateLimiterRatio,
            &lastAssignedUnix,
        ); err != nil {
            return nil, err
        }
@@ -147,6 +226,21 @@ ORDER BY effective_priority ASC,
        upstreamModelName := firstNonEmpty(item.ProviderModelName, item.ModelName)
        item.ClientID = fmt.Sprintf("%s:%s:%s", item.PlatformKey, item.ModelType, upstreamModelName)
        item.QueueKey = item.ClientID
        item.RunningCount = stateRunningCount
        item.WaitingCount = maxFloat(queuedWaiting, stateWaitingCount)
        item.LastAssignedUnix = lastAssignedUnix
        applyRuntimeCandidateLoad(&item, runtimeCandidateLoadInput{
            Policy:            effectiveModelRateLimitPolicy(item.PlatformRateLimitPolicy, item.RuntimeRateLimitPolicy, item.RuntimePolicyOverride, item.ModelRateLimitPolicy),
            ConcurrentActive:  concurrentActive,
            QueuedWaiting:     queuedWaiting,
            RPMUsed:           rpmUsed,
            RPMReserved:       rpmReserved,
            TPMUsed:           tpmUsed,
            TPMReserved:       tpmReserved,
            StateRunningCount: stateRunningCount,
            StateWaitingCount: stateWaitingCount,
            StateLimiterRatio: stateLimiterRatio,
        })
        items = append(items, item)
    }
    if err := rows.Err(); err != nil {
@@ -167,9 +261,102 @@ ORDER BY effective_priority ASC,
    if len(items) == 0 {
        return nil, ErrNoModelCandidate
    }
    sortRuntimeModelCandidates(items)
    return items, nil
}

type runtimeCandidateLoadInput struct {
    Policy            map[string]any
    ConcurrentActive  float64
    QueuedWaiting     float64
    RPMUsed           float64
    RPMReserved       float64
    TPMUsed           float64
    TPMReserved       float64
    StateRunningCount float64
    StateWaitingCount float64
    StateLimiterRatio float64
}

// applyRuntimeCandidateLoad derives per-metric load ratios from the effective
// rate-limit policy and sets the candidate's aggregate LoadRatio to the worst one.
func applyRuntimeCandidateLoad(candidate *RuntimeModelCandidate, input runtimeCandidateLoadInput) {
    rpmLimit := rateLimitForMetric(input.Policy, "rpm")
    tpmLimitValue := tpmLimit(input.Policy)
    concurrentLimit := rateLimitForMetric(input.Policy, "concurrent")
    rpmCurrent := input.RPMUsed + input.RPMReserved
    tpmCurrent := input.TPMUsed + input.TPMReserved
    concurrentCurrent := input.ConcurrentActive + input.QueuedWaiting
    metrics := RuntimeCandidateLoadMetrics{
        RPMCurrent:        rpmCurrent,
        RPMLimit:          rpmLimit,
        RPMRatio:          ratioIfLimited(rpmCurrent, rpmLimit),
        TPMCurrent:        tpmCurrent,
        TPMLimit:          tpmLimitValue,
        TPMRatio:          ratioIfLimited(tpmCurrent, tpmLimitValue),
        ConcurrentCurrent: concurrentCurrent,
        ConcurrentLimit:   concurrentLimit,
        ConcurrentRatio:   ratioIfLimited(concurrentCurrent, concurrentLimit),
        QueuedCount:       input.QueuedWaiting,
        StateRunningCount: input.StateRunningCount,
        StateWaitingCount: input.StateWaitingCount,
        StateLimiterRatio: input.StateLimiterRatio,
    }
    candidate.LoadMetrics = metrics
    candidate.LoadLimited = rpmLimit > 0 || tpmLimitValue > 0 || concurrentLimit > 0
    candidate.LoadRatio = maxFloat(metrics.RPMRatio, metrics.TPMRatio, metrics.ConcurrentRatio)
}

func ratioIfLimited(current float64, limit float64) float64 {
    if limit <= 0 {
        return 0
    }
    return current / limit
}
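A worked example with illustrative numbers (the same values the new candidates_test.go below exercises): with limits rpm=100, tpm=1000, concurrent=10 and current usage of 50 requests, 900 tokens, and 5 active-or-queued slots, the per-metric ratios are 0.5, 0.9, and 0.5, so LoadRatio is 0.9; the candidate only counts as "full" once the worst ratio reaches 1.

// Illustrative sketch only: mirrors ratioIfLimited and the max-ratio rule on hand-picked numbers.
package main

import "fmt"

func ratio(current, limit float64) float64 {
    if limit <= 0 {
        return 0 // unlimited metrics contribute no load pressure
    }
    return current / limit
}

func main() {
    rpm := ratio(40+10, 100)  // used + reserved = 50 of 100 -> 0.5
    tpm := ratio(900+0, 1000) // 900 of 1000 -> 0.9
    conc := ratio(3+2, 10)    // active + queued = 5 of 10 -> 0.5
    load := rpm
    for _, r := range []float64{tpm, conc} {
        if r > load {
            load = r
        }
    }
    fmt.Println(rpm, tpm, conc, load) // 0.5 0.9 0.5 0.9
}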

// sortRuntimeModelCandidates deprioritizes "full" candidates instead of dropping
// them: when both full and non-full candidates exist, full ones are flagged
// LoadAvoided and sorted last so they remain reachable as a final fallback.
func sortRuntimeModelCandidates(items []RuntimeModelCandidate) {
    hasFull := false
    hasNonFull := false
    for index := range items {
        items[index].LoadAvoided = false
        if runtimeCandidateFull(items[index]) {
            hasFull = true
        } else {
            hasNonFull = true
        }
    }
    if hasFull && hasNonFull {
        for index := range items {
            items[index].LoadAvoided = runtimeCandidateFull(items[index])
        }
    }
    sort.SliceStable(items, func(i, j int) bool {
        aFull := runtimeCandidateFull(items[i])
        bFull := runtimeCandidateFull(items[j])
        if aFull != bFull {
            return !aFull
        }
        if items[i].PlatformPriority != items[j].PlatformPriority {
            return items[i].PlatformPriority < items[j].PlatformPriority
        }
        if items[i].LoadRatio != items[j].LoadRatio {
            return items[i].LoadRatio < items[j].LoadRatio
        }
        if items[i].RunningCount != items[j].RunningCount {
            return items[i].RunningCount < items[j].RunningCount
        }
        if items[i].WaitingCount != items[j].WaitingCount {
            return items[i].WaitingCount < items[j].WaitingCount
        }
        if items[i].LastAssignedUnix != items[j].LastAssignedUnix {
            return items[i].LastAssignedUnix < items[j].LastAssignedUnix
        }
        return false
    })
}

func runtimeCandidateFull(candidate RuntimeModelCandidate) bool {
    return candidate.LoadLimited && candidate.LoadRatio >= 1
}

func (s *Store) modelCandidateCooldownError(ctx context.Context, model string, modelType string) (error, error) {
    rows, err := s.pool.Query(ctx, `
SELECT p.name,
apps/api/internal/store/candidates_test.go (new file, 61 lines)
@@ -0,0 +1,61 @@
package store

import "testing"

func TestRuntimeCandidateLoadUsesMaxLimitedMetric(t *testing.T) {
    candidate := RuntimeModelCandidate{}
    applyRuntimeCandidateLoad(&candidate, runtimeCandidateLoadInput{
        Policy: map[string]any{"rules": []any{
            map[string]any{"metric": "rpm", "limit": 100},
            map[string]any{"metric": "tpm_total", "limit": 1000},
            map[string]any{"metric": "concurrent", "limit": 10},
        }},
        RPMUsed:          40,
        RPMReserved:      10,
        TPMUsed:          900,
        ConcurrentActive: 3,
        QueuedWaiting:    2,
    })

    if !candidate.LoadLimited {
        t.Fatal("expected load to be limited when rate limit rules exist")
    }
    if candidate.LoadMetrics.RPMRatio != 0.5 {
        t.Fatalf("expected rpm ratio 0.5, got %v", candidate.LoadMetrics.RPMRatio)
    }
    if candidate.LoadMetrics.TPMRatio != 0.9 {
        t.Fatalf("expected tpm ratio 0.9, got %v", candidate.LoadMetrics.TPMRatio)
    }
    if candidate.LoadMetrics.ConcurrentRatio != 0.5 {
        t.Fatalf("expected concurrent ratio 0.5, got %v", candidate.LoadMetrics.ConcurrentRatio)
    }
    if candidate.LoadRatio != 0.9 {
        t.Fatalf("expected max load ratio 0.9, got %v", candidate.LoadRatio)
    }
}

func TestRuntimeCandidateSortingAvoidsFullCandidatesButKeepsFallback(t *testing.T) {
    candidates := []RuntimeModelCandidate{
        {
            PlatformID:       "high-priority-full",
            PlatformPriority: 1,
            LoadLimited:      true,
            LoadRatio:        1.2,
        },
        {
            PlatformID:       "lower-priority-available",
            PlatformPriority: 50,
            LoadLimited:      true,
            LoadRatio:        0.2,
        },
    }

    sortRuntimeModelCandidates(candidates)

    if candidates[0].PlatformID != "lower-priority-available" {
        t.Fatalf("expected non-full candidate to be tried first, got %+v", candidates)
    }
    if candidates[1].PlatformID != "high-priority-full" || !candidates[1].LoadAvoided {
        t.Fatalf("expected full high-priority candidate to remain as avoided fallback, got %+v", candidates)
    }
}
@@ -137,6 +137,29 @@ type RuntimeModelCandidate struct {
    DegradePolicy    map[string]any
    ClientID         string
    QueueKey         string
    LoadRatio        float64
    LoadLimited      bool
    LoadAvoided      bool
    LoadMetrics      RuntimeCandidateLoadMetrics
    RunningCount     float64
    WaitingCount     float64
    LastAssignedUnix float64
}

type RuntimeCandidateLoadMetrics struct {
    RPMCurrent        float64
    RPMLimit          float64
    RPMRatio          float64
    TPMCurrent        float64
    TPMLimit          float64
    TPMRatio          float64
    ConcurrentCurrent float64
    ConcurrentLimit   float64
    ConcurrentRatio   float64
    QueuedCount       float64
    StateRunningCount float64
    StateWaitingCount float64
    StateLimiterRatio float64
}

type RateLimitReservation struct {