// Source: easyai-ai-gateway/apps/api/internal/store/rate_limits.go (477 lines, 15 KiB, Go)

package store
import (
"context"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
)
// RuntimeRecoveryResult reports how many rows each step of
// RecoverInterruptedRuntimeState changed.
type RuntimeRecoveryResult struct {
	// ReleasedConcurrencyLeases is the number of unreleased, unexpired
	// concurrency leases that were marked released.
	ReleasedConcurrencyLeases int64 `json:"releasedConcurrencyLeases"`
	// ReleasedRateReservations is the number of rate limit reservations moved
	// from 'reserved' to 'released'.
	ReleasedRateReservations int64 `json:"releasedRateReservations"`
	// FailedAttempts is the number of 'running' task attempts marked 'failed'.
	FailedAttempts int64 `json:"failedAttempts"`
	// FailedTasks is the number of 'running' non-async tasks marked 'failed'.
	FailedTasks int64 `json:"failedTasks"`
	// RequeuedAsyncTasks is the number of 'running' async tasks put back to 'queued'.
	RequeuedAsyncTasks int64 `json:"requeuedAsyncTasks"`
}
// ReserveRateLimits reserves capacity for every entry of reservations inside
// one transaction; any failure rolls back everything reserved so far.
//
// Entries with a non-positive Limit or Amount are skipped. An entry with an
// empty Metric, or whose Amount is larger than its Limit, aborts the call with
// a non-retryable *RateLimitExceededError. A non-positive WindowSeconds is
// replaced by 60. Entries whose Metric is "concurrent" are turned into
// concurrency leases (their IDs are returned in LeaseIDs); all other metrics
// are reserved against a counter window and returned in Reservations.
func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID string, reservations []RateLimitReservation) (RateLimitResult, error) {
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return RateLimitResult{}, err
	}
	// Covers every early return below.
	defer tx.Rollback(ctx)

	var granted RateLimitResult
	for _, item := range reservations {
		switch {
		case item.Limit <= 0 || item.Amount <= 0:
			// Nothing to enforce or nothing requested.
			continue
		case item.Metric == "" || item.Amount > item.Limit:
			// A request that can never fit is reported as non-retryable.
			return RateLimitResult{}, &RateLimitExceededError{
				Metric:    item.Metric,
				Message:   fmt.Sprintf("rate limit exceeded: %s request amount %.0f is greater than limit %.0f", item.Metric, item.Amount, item.Limit),
				Retryable: false,
			}
		}

		if item.WindowSeconds <= 0 {
			item.WindowSeconds = 60
		}

		if item.Metric != "concurrent" {
			held, err := reserveCounterWindow(ctx, tx, taskID, attemptID, item)
			if err != nil {
				return RateLimitResult{}, err
			}
			granted.Reservations = append(granted.Reservations, held)
			continue
		}

		leaseID, err := reserveConcurrencyLease(ctx, tx, taskID, attemptID, item)
		if err != nil {
			return RateLimitResult{}, err
		}
		granted.LeaseIDs = append(granted.LeaseIDs, leaseID)
	}

	commitErr := tx.Commit(ctx)
	return granted, commitErr
}
// reserveConcurrencyLease inserts a concurrency lease for reservation's scope
// and returns the new lease ID.
//
// It first sums the lease_value of all leases for the scope that are neither
// released nor expired. If that sum plus reservation.Amount exceeds
// reservation.Limit, a retryable *RateLimitExceededError is returned instead.
// A non-positive LeaseTTLSeconds is replaced by 120. An empty attemptID is
// stored as NULL.
//
// NOTE(review): the capacity check and the insert are separate statements;
// confirm that something outside this function prevents two concurrent
// transactions from both passing the check for the same scope.
func reserveConcurrencyLease(ctx context.Context, tx pgx.Tx, taskID string, attemptID string, reservation RateLimitReservation) (string, error) {
	// Returns the active total and the earliest expiry among active leases;
	// with no active lease the expiry falls back to now() + TTL.
	const activeLeasesSQL = `
SELECT COALESCE(SUM(lease_value), 0)::float8,
COALESCE(MIN(expires_at), now() + ($3::int * interval '1 second'))
FROM gateway_concurrency_leases
WHERE scope_type = $1
AND scope_key = $2
AND released_at IS NULL
AND expires_at > now()`

	const insertLeaseSQL = `
INSERT INTO gateway_concurrency_leases (task_id, attempt_id, scope_type, scope_key, lease_value, expires_at)
VALUES ($1::uuid, NULLIF($2, '')::uuid, $3, $4, $5, now() + ($6::int * interval '1 second'))
RETURNING id::text`

	if reservation.LeaseTTLSeconds <= 0 {
		reservation.LeaseTTLSeconds = 120
	}

	var (
		inUse          float64
		earliestExpiry time.Time
	)
	usage := tx.QueryRow(ctx, activeLeasesSQL,
		reservation.ScopeType,
		reservation.ScopeKey,
		reservation.LeaseTTLSeconds,
	)
	if err := usage.Scan(&inUse, &earliestExpiry); err != nil {
		return "", err
	}

	if inUse+reservation.Amount > reservation.Limit {
		return "", &RateLimitExceededError{
			Metric:     reservation.Metric,
			Message:    fmt.Sprintf("rate limit exceeded: concurrent active %.0f plus request %.0f is greater than limit %.0f", inUse, reservation.Amount, reservation.Limit),
			RetryAfter: concurrencyRetryAfter(earliestExpiry),
			Retryable:  true,
		}
	}

	var leaseID string
	inserted := tx.QueryRow(ctx, insertLeaseSQL,
		taskID,
		attemptID,
		reservation.ScopeType,
		reservation.ScopeKey,
		reservation.Amount,
		reservation.LeaseTTLSeconds,
	)
	if err := inserted.Scan(&leaseID); err != nil {
		return "", err
	}
	return leaseID, nil
}
// reserveCounterWindow adds reservation.Amount to the reserved_value of the
// current counter window for (scope_type, scope_key, metric) and records a
// 'reserved' row in gateway_rate_limit_reservations.
//
// The window start is now() rounded down to a multiple of WindowSeconds. The
// upsert only updates an existing counter row when used + reserved plus the
// new amount stays within the limit; otherwise no row is returned and a
// retryable *RateLimitExceededError is produced whose RetryAfter is derived
// from the counter's reset_at. On success the returned reservation carries
// WindowStart and the new ReservationID.
func reserveCounterWindow(ctx context.Context, tx pgx.Tx, taskID string, attemptID string, reservation RateLimitReservation) (RateLimitReservation, error) {
	const upsertCounterSQL = `
WITH bounds AS (
SELECT
to_timestamp(floor(extract(epoch FROM now()) / $7::int) * $7::int) AS window_start,
to_timestamp(floor(extract(epoch FROM now()) / $7::int) * $7::int) + ($7::int * interval '1 second') AS reset_at
)
INSERT INTO gateway_rate_limit_counters (
scope_type, scope_key, metric, window_start, limit_value, used_value, reserved_value, reset_at
)
SELECT $1, $2, $3, bounds.window_start, $4, $5, $6, bounds.reset_at
FROM bounds
ON CONFLICT (scope_type, scope_key, metric, window_start) DO UPDATE
SET limit_value = EXCLUDED.limit_value,
used_value = gateway_rate_limit_counters.used_value + EXCLUDED.used_value,
reserved_value = gateway_rate_limit_counters.reserved_value + EXCLUDED.reserved_value,
reset_at = EXCLUDED.reset_at,
updated_at = now()
WHERE gateway_rate_limit_counters.used_value + gateway_rate_limit_counters.reserved_value + EXCLUDED.used_value + EXCLUDED.reserved_value <= EXCLUDED.limit_value
RETURNING window_start`

	const currentResetSQL = `
WITH bounds AS (
SELECT to_timestamp(floor(extract(epoch FROM now()) / $4::int) * $4::int) AS window_start
)
SELECT counters.reset_at
FROM gateway_rate_limit_counters counters
JOIN bounds ON counters.window_start = bounds.window_start
WHERE scope_type = $1
AND scope_key = $2
AND metric = $3`

	const insertReservationSQL = `
INSERT INTO gateway_rate_limit_reservations (
task_id, attempt_id, scope_type, scope_key, metric, window_start, limit_value, reserved_amount, status
)
VALUES (
$1::uuid, NULLIF($2, '')::uuid, $3, $4, $5, $6, $7, $8, 'reserved'
)
RETURNING id::text`

	// A reservation only touches reserved_value; used_value is contributed as zero.
	used := 0.0
	held := reservation.Amount

	var windowStart time.Time
	err := tx.QueryRow(ctx, upsertCounterSQL,
		reservation.ScopeType,
		reservation.ScopeKey,
		reservation.Metric,
		reservation.Limit,
		used,
		held,
		reservation.WindowSeconds,
	).Scan(&windowStart)

	switch {
	case errors.Is(err, pgx.ErrNoRows):
		// The conditional update was filtered out: the window is full.
		// Start from "one window from now" and overwrite it with the stored
		// reset_at when that lookup succeeds; its error is deliberately
		// ignored because the fallback is still usable.
		resetAt := time.Now().Add(time.Duration(reservation.WindowSeconds) * time.Second)
		_ = tx.QueryRow(ctx, currentResetSQL,
			reservation.ScopeType,
			reservation.ScopeKey,
			reservation.Metric,
			reservation.WindowSeconds,
		).Scan(&resetAt)
		return RateLimitReservation{}, &RateLimitExceededError{
			Metric:     reservation.Metric,
			Message:    fmt.Sprintf("rate limit exceeded: %s window has no remaining capacity", reservation.Metric),
			RetryAfter: retryAfterUntil(resetAt),
			Retryable:  true,
		}
	case err != nil:
		return RateLimitReservation{}, err
	}

	reservation.WindowStart = windowStart
	recorded := tx.QueryRow(ctx, insertReservationSQL,
		taskID,
		attemptID,
		reservation.ScopeType,
		reservation.ScopeKey,
		reservation.Metric,
		windowStart,
		reservation.Limit,
		held,
	)
	if err := recorded.Scan(&reservation.ReservationID); err != nil {
		return RateLimitReservation{}, err
	}
	return reservation, nil
}
// retryAfterUntil returns how long to wait until when, never less than one
// second. The zero time yields 0.
func retryAfterUntil(when time.Time) time.Duration {
	if when.IsZero() {
		return 0
	}
	if remaining := time.Until(when); remaining >= time.Second {
		return remaining
	}
	// Covers instants in the past as well as sub-second waits.
	return time.Second
}
// concurrencyRetryAfter returns how long a caller should wait before retrying
// a rejected concurrency reservation: the time remaining until leaseExpiresAt,
// with a floor of one second. The zero time, an expiry in the past, or an
// expiry less than a second away all yield one second.
//
// Previously every branch returned time.Second, so the computed remaining
// time was discarded.
func concurrencyRetryAfter(leaseExpiresAt time.Time) time.Duration {
	if leaseExpiresAt.IsZero() {
		return time.Second
	}
	if remaining := time.Until(leaseExpiresAt); remaining > time.Second {
		return remaining
	}
	return time.Second
}
// CommitRateLimitReservations finalizes reservations with status "committed"
// and reason "success". actualByMetric supplies the amount to record per
// metric; it is passed through to finishRateLimitReservations unchanged.
func (s *Store) CommitRateLimitReservations(ctx context.Context, reservations []RateLimitReservation, actualByMetric map[string]float64) error {
	const (
		status = "committed"
		reason = "success"
	)
	return s.finishRateLimitReservations(ctx, reservations, actualByMetric, status, reason)
}
// ReleaseRateLimitReservations finalizes reservations with status "released"
// and the given reason, supplying no actual amounts.
func (s *Store) ReleaseRateLimitReservations(ctx context.Context, reservations []RateLimitReservation, reason string) error {
	const status = "released"
	// A release records no actual usage, so the per-metric map stays nil.
	var noActuals map[string]float64
	return s.finishRateLimitReservations(ctx, reservations, noActuals, status, reason)
}
// ReleaseConcurrencyLeases sets released_at on each listed lease that has not
// been released yet. Empty IDs are skipped. Each lease is updated with its own
// statement; the first error that does not match ErrRateLimited stops the loop
// and is returned.
//
// NOTE(review): it is not visible here how the update could yield
// ErrRateLimited; confirm whether that exemption is still needed.
func (s *Store) ReleaseConcurrencyLeases(ctx context.Context, leaseIDs []string) error {
	const releaseLeaseSQL = `
UPDATE gateway_concurrency_leases
SET released_at = now()
WHERE id = $1::uuid AND released_at IS NULL`

	for _, id := range leaseIDs {
		if id == "" {
			continue
		}
		_, err := s.pool.Exec(ctx, releaseLeaseSQL, id)
		if err == nil || errors.Is(err, ErrRateLimited) {
			continue
		}
		return err
	}
	return nil
}
// RecoverInterruptedRuntimeState cleans up state left in a running condition,
// marking it with the code 'server_restarted'. All steps run in one
// transaction, in this order:
//
//  1. 'reserved' rate limit reservations become 'released' and their amounts
//     are subtracted from the matching counters' reserved_value.
//  2. Unreleased, unexpired concurrency leases are marked released.
//  3. 'running' task attempts are marked 'failed' and retryable.
//  4. 'running' async tasks are reset to 'queued' with locks and errors
//     cleared, and a 'task.recovered' event is appended for each.
//  5. 'running' non-async tasks are marked 'failed', and a 'task.recovered'
//     event is appended for each.
//
// The returned RuntimeRecoveryResult holds the row count of each step. On any
// error the zero result is returned and the transaction is rolled back.
func (s *Store) RecoverInterruptedRuntimeState(ctx context.Context) (RuntimeRecoveryResult, error) {
	const (
		releaseReservationsSQL = `
UPDATE gateway_rate_limit_reservations
SET status = 'released',
reason = 'server_restarted',
finalized_at = now(),
updated_at = now()
WHERE status = 'reserved'
RETURNING scope_type, scope_key, metric, window_start, reserved_amount::float8`

		releaseLeasesSQL = `
UPDATE gateway_concurrency_leases
SET released_at = now()
WHERE released_at IS NULL
AND expires_at > now()`

		failAttemptsSQL = `
UPDATE gateway_task_attempts
SET status = 'failed',
retryable = true,
error_code = 'server_restarted',
error_message = 'attempt interrupted by service restart',
finished_at = now()
WHERE status = 'running'`

		requeueAsyncTasksSQL = `
UPDATE gateway_tasks
SET status = 'queued',
error = NULL,
error_code = NULL,
error_message = NULL,
locked_by = NULL,
locked_at = NULL,
heartbeat_at = NULL,
next_run_at = now(),
finished_at = NULL,
updated_at = now()
WHERE async_mode = true
AND status = 'running'
RETURNING id::text`

		requeuedEventSQL = `
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
VALUES (
$1::uuid,
COALESCE((SELECT MAX(seq) + 1 FROM gateway_task_events WHERE task_id = $1::uuid), 1),
'task.recovered',
'queued',
'recovered',
0.2,
'async task recovered after service restart',
'{"code":"server_restarted"}'::jsonb,
false
)`

		failSyncTasksSQL = `
UPDATE gateway_tasks
SET status = 'failed',
error = 'task interrupted by service restart',
error_code = 'server_restarted',
error_message = 'task interrupted by service restart',
finished_at = now(),
updated_at = now()
WHERE async_mode = false
AND status = 'running'
RETURNING id::text`

		failedEventSQL = `
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
VALUES (
$1::uuid,
COALESCE((SELECT MAX(seq) + 1 FROM gateway_task_events WHERE task_id = $1::uuid), 1),
'task.recovered',
'failed',
'recovered',
1,
'task interrupted by service restart',
'{"code":"server_restarted"}'::jsonb,
false
)`
	)

	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	defer tx.Rollback(ctx)

	// returnedIDs runs an UPDATE ... RETURNING id::text statement and collects
	// the IDs. The rows are closed before it returns so the next statement
	// can run on the same transaction.
	returnedIDs := func(query string) ([]string, error) {
		rows, err := tx.Query(ctx, query)
		if err != nil {
			return nil, err
		}
		defer rows.Close()

		var ids []string
		for rows.Next() {
			var id string
			if err := rows.Scan(&id); err != nil {
				return nil, err
			}
			ids = append(ids, id)
		}
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return ids, nil
	}

	// appendEvents executes query once per task ID, stopping at the first error.
	appendEvents := func(query string, taskIDs []string) error {
		for _, id := range taskIDs {
			if _, err := tx.Exec(ctx, query, id); err != nil {
				return err
			}
		}
		return nil
	}

	// Step 1: release outstanding reservations and remember what each one held.
	released, err := func() ([]RateLimitReservation, error) {
		rows, err := tx.Query(ctx, releaseReservationsSQL)
		if err != nil {
			return nil, err
		}
		defer rows.Close()

		var out []RateLimitReservation
		for rows.Next() {
			var held RateLimitReservation
			if err := rows.Scan(&held.ScopeType, &held.ScopeKey, &held.Metric, &held.WindowStart, &held.Amount); err != nil {
				return nil, err
			}
			out = append(out, held)
		}
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return out, nil
	}()
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	for _, held := range released {
		if err := releaseCounterReservation(ctx, tx, held.ScopeType, held.ScopeKey, held.Metric, held.WindowStart, held.Amount); err != nil {
			return RuntimeRecoveryResult{}, err
		}
	}

	var summary RuntimeRecoveryResult
	summary.ReleasedRateReservations = int64(len(released))

	// Step 2: concurrency leases.
	tag, err := tx.Exec(ctx, releaseLeasesSQL)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	summary.ReleasedConcurrencyLeases = tag.RowsAffected()

	// Step 3: task attempts.
	tag, err = tx.Exec(ctx, failAttemptsSQL)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	summary.FailedAttempts = tag.RowsAffected()

	// Step 4: async tasks go back to the queue.
	requeued, err := returnedIDs(requeueAsyncTasksSQL)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	if err := appendEvents(requeuedEventSQL, requeued); err != nil {
		return RuntimeRecoveryResult{}, err
	}
	summary.RequeuedAsyncTasks = int64(len(requeued))

	// Step 5: non-async tasks are failed.
	failed, err := returnedIDs(failSyncTasksSQL)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	if err := appendEvents(failedEventSQL, failed); err != nil {
		return RuntimeRecoveryResult{}, err
	}
	summary.FailedTasks = int64(len(failed))

	return summary, tx.Commit(ctx)
}
// finishRateLimitReservations moves each reservation that is still 'reserved'
// to status, records reason (an empty reason is stored as NULL), and adjusts
// the matching counter, all in one transaction.
//
// Entries without a ReservationID are skipped, as are reservations whose row
// is no longer 'reserved' (the update returns no row). When status is
// "committed" the actual amount is taken from actualByMetric by metric; a
// missing or non-positive value falls back to the entry's Amount. The counter
// is then settled with the stored reserved amount and that actual amount. For
// any other status only the stored reserved amount is given back.
func (s *Store) finishRateLimitReservations(ctx context.Context, reservations []RateLimitReservation, actualByMetric map[string]float64, status string, reason string) error {
	const finalizeSQL = `
UPDATE gateway_rate_limit_reservations
SET status = $2::text,
reason = NULLIF($3::text, ''),
actual_amount = CASE WHEN $2::text = 'committed' THEN $4 ELSE actual_amount END,
finalized_at = now(),
updated_at = now()
WHERE id = $1::uuid
AND status = 'reserved'
RETURNING scope_type, scope_key, metric, window_start, reserved_amount::float8`

	if len(reservations) == 0 {
		return nil
	}

	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return err
	}
	defer tx.Rollback(ctx)

	committing := status == "committed"
	for _, item := range reservations {
		if item.ReservationID == "" {
			continue
		}

		// Reading a nil map is fine and yields zero.
		actual := actualByMetric[item.Metric]
		if committing && actual <= 0 {
			actual = item.Amount
		}

		var row RateLimitReservation
		err := tx.QueryRow(ctx, finalizeSQL,
			item.ReservationID,
			status,
			reason,
			actual,
		).Scan(&row.ScopeType, &row.ScopeKey, &row.Metric, &row.WindowStart, &row.Amount)
		switch {
		case errors.Is(err, pgx.ErrNoRows):
			// No row was still 'reserved' under this ID; nothing to adjust.
			continue
		case err != nil:
			return err
		}

		if committing {
			err = commitCounterReservation(ctx, tx, row.ScopeType, row.ScopeKey, row.Metric, row.WindowStart, row.Amount, actual)
		} else {
			err = releaseCounterReservation(ctx, tx, row.ScopeType, row.ScopeKey, row.Metric, row.WindowStart, row.Amount)
		}
		if err != nil {
			return err
		}
	}

	return tx.Commit(ctx)
}
// commitCounterReservation settles a reservation against its counter row:
// reserved_value is lowered by reservedAmount (never below zero) and
// used_value is raised by actualAmount. The row is identified by scope type,
// scope key, metric and window start.
func commitCounterReservation(ctx context.Context, tx pgx.Tx, scopeType string, scopeKey string, metric string, windowStart time.Time, reservedAmount float64, actualAmount float64) error {
	const settleSQL = `
UPDATE gateway_rate_limit_counters
SET reserved_value = GREATEST(reserved_value - $5, 0),
used_value = used_value + $6,
updated_at = now()
WHERE scope_type = $1
AND scope_key = $2
AND metric = $3
AND window_start = $4`

	if _, err := tx.Exec(ctx, settleSQL, scopeType, scopeKey, metric, windowStart, reservedAmount, actualAmount); err != nil {
		return err
	}
	return nil
}
// releaseCounterReservation gives a reservation back to its counter row:
// reserved_value is lowered by reservedAmount (never below zero) and
// used_value is left untouched. The row is identified by scope type, scope
// key, metric and window start.
func releaseCounterReservation(ctx context.Context, tx pgx.Tx, scopeType string, scopeKey string, metric string, windowStart time.Time, reservedAmount float64) error {
	const giveBackSQL = `
UPDATE gateway_rate_limit_counters
SET reserved_value = GREATEST(reserved_value - $5, 0),
updated_at = now()
WHERE scope_type = $1
AND scope_key = $2
AND metric = $3
AND window_start = $4`

	if _, err := tx.Exec(ctx, giveBackSQL, scopeType, scopeKey, metric, windowStart, reservedAmount); err != nil {
		return err
	}
	return nil
}