easyai-ai-gateway/apps/api/internal/store/rate_limits.go

362 lines
11 KiB
Go

package store
import (
"context"
"errors"
"time"
"github.com/jackc/pgx/v5"
)
// RuntimeRecoveryResult reports how many rows RecoverInterruptedRuntimeState
// touched in each of its cleanup steps.
type RuntimeRecoveryResult struct {
	// ReleasedConcurrencyLeases is the number of concurrency leases that were
	// still unreleased and unexpired and have now been released.
	ReleasedConcurrencyLeases int64 `json:"releasedConcurrencyLeases"`
	// ReleasedRateReservations is the number of rate-limit reservations moved
	// from 'reserved' to 'released'.
	ReleasedRateReservations int64 `json:"releasedRateReservations"`
	// FailedAttempts is the number of running task attempts marked failed.
	FailedAttempts int64 `json:"failedAttempts"`
	// FailedTasks is the number of queued or running tasks marked failed.
	FailedTasks int64 `json:"failedTasks"`
}
// ReserveRateLimits reserves capacity for every applicable entry in
// reservations on behalf of taskID/attemptID, all inside one transaction.
//
// Entries with a non-positive Limit or Amount are skipped. An entry with an
// empty Metric, or whose Amount alone exceeds its Limit, fails the whole call
// with ErrRateLimited. Metric "concurrent" takes a concurrency lease. Every
// other metric reserves capacity in a counter window, and WindowSeconds
// defaults to 60 when it is not positive.
//
// Any error aborts the transaction, so either every reservation is persisted
// or none is. On success the result carries the lease IDs and the normalized
// counter reservations needed to release or commit them later. On error the
// result is the zero value.
func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID string, reservations []RateLimitReservation) (RateLimitResult, error) {
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return RateLimitResult{}, err
	}
	// Safe after a successful Commit: pgx treats Rollback on a closed tx as a no-op.
	defer tx.Rollback(ctx)

	result := RateLimitResult{}
	for _, reservation := range reservations {
		if reservation.Limit <= 0 || reservation.Amount <= 0 {
			continue
		}
		if reservation.Metric == "" || reservation.Amount > reservation.Limit {
			return RateLimitResult{}, ErrRateLimited
		}
		if reservation.WindowSeconds <= 0 {
			reservation.WindowSeconds = 60
		}
		if reservation.Metric == "concurrent" {
			leaseID, err := reserveConcurrencyLease(ctx, tx, taskID, attemptID, reservation)
			if err != nil {
				return RateLimitResult{}, err
			}
			result.LeaseIDs = append(result.LeaseIDs, leaseID)
			continue
		}
		normalized, err := reserveCounterWindow(ctx, tx, taskID, attemptID, reservation)
		if err != nil {
			return RateLimitResult{}, err
		}
		result.Reservations = append(result.Reservations, normalized)
	}

	// Nothing was persisted if Commit fails. Returning the collected IDs would
	// point callers at rows that do not exist.
	if err := tx.Commit(ctx); err != nil {
		return RateLimitResult{}, err
	}
	return result, nil
}
// reserveConcurrencyLease takes a concurrency lease of reservation.Amount for
// (ScopeType, ScopeKey) inside tx and returns the new lease ID.
//
// The check counts leases for the scope that are unreleased and unexpired. If
// those leases plus the requested amount would exceed reservation.Limit, it
// returns ErrRateLimited. Otherwise it inserts a lease that expires after
// LeaseTTLSeconds (default 120).
func reserveConcurrencyLease(ctx context.Context, tx pgx.Tx, taskID string, attemptID string, reservation RateLimitReservation) (string, error) {
	if reservation.LeaseTTLSeconds <= 0 {
		reservation.LeaseTTLSeconds = 120
	}

	// The capacity check is a read followed by an insert. Without serialization,
	// two concurrent transactions can both observe spare capacity and both
	// insert, overshooting the limit. A transaction-scoped advisory lock keyed
	// on the scope serializes check+insert per scope. It is released
	// automatically at commit/rollback. Hash collisions only make unrelated
	// scopes wait briefly; they never corrupt accounting.
	// NOTE(review): assumes READ COMMITTED (the pgx Begin default), so the
	// SELECT below takes a fresh snapshot after the lock is granted.
	if _, err := tx.Exec(ctx, `SELECT pg_advisory_xact_lock(hashtext($1), hashtext($2))`,
		reservation.ScopeType,
		reservation.ScopeKey,
	); err != nil {
		return "", err
	}

	var active float64
	if err := tx.QueryRow(ctx, `
SELECT COALESCE(SUM(lease_value), 0)::float8
FROM gateway_concurrency_leases
WHERE scope_type = $1
AND scope_key = $2
AND released_at IS NULL
AND expires_at > now()`,
		reservation.ScopeType,
		reservation.ScopeKey,
	).Scan(&active); err != nil {
		return "", err
	}
	if active+reservation.Amount > reservation.Limit {
		return "", ErrRateLimited
	}

	var leaseID string
	if err := tx.QueryRow(ctx, `
INSERT INTO gateway_concurrency_leases (task_id, attempt_id, scope_type, scope_key, lease_value, expires_at)
VALUES ($1::uuid, NULLIF($2, '')::uuid, $3, $4, $5, now() + ($6::int * interval '1 second'))
RETURNING id::text`,
		taskID,
		attemptID,
		reservation.ScopeType,
		reservation.ScopeKey,
		reservation.Amount,
		reservation.LeaseTTLSeconds,
	).Scan(&leaseID); err != nil {
		return "", err
	}
	return leaseID, nil
}
// reserveCounterWindow reserves reservation.Amount of reservation.Metric for
// (ScopeType, ScopeKey) in the current counter window. It then records a
// matching 'reserved' row in gateway_rate_limit_reservations.
//
// The window is the WindowSeconds-long interval (default 60) that contains the
// database's now(), aligned to the Unix epoch. Its counter row carries
// reset_at = window start + WindowSeconds. For 60-second windows this is the
// same minute boundary date_trunc('minute', now()) produces.
//
// The capacity check and the increment are one INSERT ... ON CONFLICT DO
// UPDATE ... WHERE statement. When the existing used + reserved plus the new
// amount would exceed reservation.Limit, the update does not apply, no row is
// returned, and ErrRateLimited is reported.
//
// It returns reservation with WindowSeconds normalized and WindowStart and
// ReservationID filled in.
func reserveCounterWindow(ctx context.Context, tx pgx.Tx, taskID string, attemptID string, reservation RateLimitReservation) (RateLimitReservation, error) {
	if reservation.WindowSeconds <= 0 {
		reservation.WindowSeconds = 60
	}
	// A brand-new counter row is inserted without evaluating the ON CONFLICT
	// WHERE clause, so an oversized single request must be rejected up front.
	if reservation.Amount > reservation.Limit {
		return RateLimitReservation{}, ErrRateLimited
	}
	// Only reserved capacity grows here; used_value is adjusted on commit.
	usedAmount := 0.0
	reservedAmount := reservation.Amount
	var windowStart time.Time
	err := tx.QueryRow(ctx, `
INSERT INTO gateway_rate_limit_counters (
scope_type, scope_key, metric, window_start, limit_value, used_value, reserved_value, reset_at
)
VALUES (
$1, $2, $3,
to_timestamp(floor(extract(epoch FROM now())::float8 / $7::int) * $7::int),
$4, $5, $6,
to_timestamp(floor(extract(epoch FROM now())::float8 / $7::int) * $7::int) + ($7::int * interval '1 second')
)
ON CONFLICT (scope_type, scope_key, metric, window_start) DO UPDATE
SET limit_value = EXCLUDED.limit_value,
used_value = gateway_rate_limit_counters.used_value + EXCLUDED.used_value,
reserved_value = gateway_rate_limit_counters.reserved_value + EXCLUDED.reserved_value,
reset_at = EXCLUDED.reset_at,
updated_at = now()
WHERE gateway_rate_limit_counters.used_value + gateway_rate_limit_counters.reserved_value + EXCLUDED.used_value + EXCLUDED.reserved_value <= EXCLUDED.limit_value
RETURNING window_start`,
		reservation.ScopeType,
		reservation.ScopeKey,
		reservation.Metric,
		reservation.Limit,
		usedAmount,
		reservedAmount,
		reservation.WindowSeconds,
	).Scan(&windowStart)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			// The conflicting row exists but the WHERE clause rejected the update:
			// the window is full.
			return RateLimitReservation{}, ErrRateLimited
		}
		return RateLimitReservation{}, err
	}
	reservation.WindowStart = windowStart

	if err := tx.QueryRow(ctx, `
INSERT INTO gateway_rate_limit_reservations (
task_id, attempt_id, scope_type, scope_key, metric, window_start, limit_value, reserved_amount, status
)
VALUES (
$1::uuid, NULLIF($2, '')::uuid, $3, $4, $5, $6, $7, $8, 'reserved'
)
RETURNING id::text`,
		taskID,
		attemptID,
		reservation.ScopeType,
		reservation.ScopeKey,
		reservation.Metric,
		windowStart,
		reservation.Limit,
		reservedAmount,
	).Scan(&reservation.ReservationID); err != nil {
		return RateLimitReservation{}, err
	}
	return reservation, nil
}
// CommitRateLimitReservations finalizes reservations as consumed. Each one
// still in status 'reserved' is marked 'committed' with reason "success". Its
// reserved amount is removed from the counter and its actual usage is added to
// used_value. Actual usage is actualByMetric[Metric], or the reservation's
// Amount when that entry is missing or non-positive.
func (s *Store) CommitRateLimitReservations(ctx context.Context, reservations []RateLimitReservation, actualByMetric map[string]float64) error {
	const (
		finalStatus = "committed"
		finalReason = "success"
	)
	return s.finishRateLimitReservations(ctx, reservations, actualByMetric, finalStatus, finalReason)
}
// ReleaseRateLimitReservations gives reserved capacity back without recording
// usage. Each reservation still in status 'reserved' is marked 'released' with
// the given reason, and its reserved amount is removed from the counter.
func (s *Store) ReleaseRateLimitReservations(ctx context.Context, reservations []RateLimitReservation, reason string) error {
	// Released reservations record no actual usage.
	var noActuals map[string]float64
	return s.finishRateLimitReservations(ctx, reservations, noActuals, "released", reason)
}
// ReleaseConcurrencyLeases marks the given concurrency leases as released.
//
// Empty IDs are ignored. Leases that are already released are left untouched,
// so releasing the same lease twice is harmless. All IDs are released in a
// single statement, so either every listed lease is updated or, on error
// (e.g. an ID that is not a valid UUID), none is.
func (s *Store) ReleaseConcurrencyLeases(ctx context.Context, leaseIDs []string) error {
	ids := make([]string, 0, len(leaseIDs))
	for _, id := range leaseIDs {
		if id != "" {
			ids = append(ids, id)
		}
	}
	if len(ids) == 0 {
		return nil
	}
	// The IDs are sent as text[] and cast server-side, so pgx only has to
	// encode a plain []string.
	_, err := s.pool.Exec(ctx, `
UPDATE gateway_concurrency_leases
SET released_at = now()
WHERE id = ANY($1::text[]::uuid[]) AND released_at IS NULL`, ids)
	return err
}
// RecoverInterruptedRuntimeState cleans up runtime state left in flight by a
// previous process. In a single transaction it:
//
//   - releases every rate-limit reservation still in status 'reserved' (reason
//     'server_restarted') and returns its reserved amount to its counter;
//   - releases every concurrency lease that is unreleased and not yet expired;
//   - marks every 'running' task attempt as failed and non-retryable;
//   - marks every 'queued' or 'running' task as failed and appends a
//     'task.recovered' event to each of them.
//
// The result reports how many rows each step touched. On error nothing is
// committed and the zero result is returned.
// NOTE(review): presumably intended to run once at startup, before new work is
// accepted — confirm with the caller.
func (s *Store) RecoverInterruptedRuntimeState(ctx context.Context) (RuntimeRecoveryResult, error) {
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	defer tx.Rollback(ctx)

	var result RuntimeRecoveryResult

	result.ReleasedRateReservations, err = releaseInterruptedRateReservations(ctx, tx)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}

	tag, err := tx.Exec(ctx, `
UPDATE gateway_concurrency_leases
SET released_at = now()
WHERE released_at IS NULL
AND expires_at > now()`)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	result.ReleasedConcurrencyLeases = tag.RowsAffected()

	tag, err = tx.Exec(ctx, `
UPDATE gateway_task_attempts
SET status = 'failed',
retryable = false,
error_code = 'server_restarted',
error_message = 'attempt interrupted by service restart',
finished_at = now()
WHERE status = 'running'`)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	result.FailedAttempts = tag.RowsAffected()

	result.FailedTasks, err = failInterruptedTasks(ctx, tx)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}

	if err := tx.Commit(ctx); err != nil {
		return RuntimeRecoveryResult{}, err
	}
	return result, nil
}

// releaseInterruptedRateReservations moves every 'reserved' rate-limit
// reservation to 'released' (reason 'server_restarted'), gives the reserved
// amounts back to their counter windows, and returns how many were released.
func releaseInterruptedRateReservations(ctx context.Context, tx pgx.Tx) (int64, error) {
	rows, err := tx.Query(ctx, `
UPDATE gateway_rate_limit_reservations
SET status = 'released',
reason = 'server_restarted',
finalized_at = now(),
updated_at = now()
WHERE status = 'reserved'
RETURNING scope_type, scope_key, metric, window_start, reserved_amount::float8`)
	if err != nil {
		return 0, err
	}
	// Drain and close the result set before issuing the counter updates. pgx
	// refuses to run another statement on the same connection while rows is
	// still open ("conn busy").
	released, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (RateLimitReservation, error) {
		var r RateLimitReservation
		err := row.Scan(&r.ScopeType, &r.ScopeKey, &r.Metric, &r.WindowStart, &r.Amount)
		return r, err
	})
	if err != nil {
		return 0, err
	}
	for _, r := range released {
		if err := releaseCounterReservation(ctx, tx, r.ScopeType, r.ScopeKey, r.Metric, r.WindowStart, r.Amount); err != nil {
			return 0, err
		}
	}
	return int64(len(released)), nil
}

// failInterruptedTasks marks queued/running tasks as failed, appends a
// 'task.recovered' event to each, and returns how many tasks were failed.
func failInterruptedTasks(ctx context.Context, tx pgx.Tx) (int64, error) {
	rows, err := tx.Query(ctx, `
UPDATE gateway_tasks
SET status = 'failed',
error = 'task interrupted by service restart',
error_code = 'server_restarted',
error_message = 'task interrupted by service restart',
finished_at = now(),
updated_at = now()
WHERE status IN ('queued', 'running')
RETURNING id::text`)
	if err != nil {
		return 0, err
	}
	// Collect (and close) the IDs first; the event inserts below need the connection.
	taskIDs, err := pgx.CollectRows(rows, pgx.RowTo[string])
	if err != nil {
		return 0, err
	}
	for _, taskID := range taskIDs {
		if _, err := tx.Exec(ctx, `
INSERT INTO gateway_task_events (task_id, seq, event_type, status, phase, progress, message, payload, simulated)
VALUES (
$1::uuid,
COALESCE((SELECT MAX(seq) + 1 FROM gateway_task_events WHERE task_id = $1::uuid), 1),
'task.recovered',
'failed',
'recovered',
1,
'task interrupted by service restart',
'{"code":"server_restarted"}'::jsonb,
false
)`, taskID); err != nil {
			return 0, err
		}
	}
	return int64(len(taskIDs)), nil
}
// finishRateLimitReservations moves each reservation from 'reserved' to
// status, recording reason (NULL when empty), and settles its counter row, all
// in one transaction.
//
// For status "committed" the actual amount is actualByMetric[Metric], falling
// back to the reservation's Amount when missing or non-positive. That amount is
// stored on the reservation and added to the counter's used_value. For any
// other status only the reserved amount is removed from the counter.
// Reservations with an empty ReservationID, or that are no longer 'reserved',
// are skipped.
func (s *Store) finishRateLimitReservations(ctx context.Context, reservations []RateLimitReservation, actualByMetric map[string]float64, status string, reason string) error {
	if len(reservations) == 0 {
		return nil
	}
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return err
	}
	defer tx.Rollback(ctx)

	committing := status == "committed"
	for _, r := range reservations {
		if r.ReservationID == "" {
			continue
		}
		actual := actualByMetric[r.Metric]
		if committing && actual <= 0 {
			actual = r.Amount
		}

		var row RateLimitReservation
		scanErr := tx.QueryRow(ctx, `
UPDATE gateway_rate_limit_reservations
SET status = $2,
reason = NULLIF($3, ''),
actual_amount = CASE WHEN $2 = 'committed' THEN $4 ELSE actual_amount END,
finalized_at = now(),
updated_at = now()
WHERE id = $1::uuid
AND status = 'reserved'
RETURNING scope_type, scope_key, metric, window_start, reserved_amount::float8`,
			r.ReservationID,
			status,
			reason,
			actual,
		).Scan(&row.ScopeType, &row.ScopeKey, &row.Metric, &row.WindowStart, &row.Amount)
		switch {
		case errors.Is(scanErr, pgx.ErrNoRows):
			// Already finalized (or unknown): nothing to settle.
			continue
		case scanErr != nil:
			return scanErr
		}

		var settleErr error
		if committing {
			settleErr = commitCounterReservation(ctx, tx, row.ScopeType, row.ScopeKey, row.Metric, row.WindowStart, row.Amount, actual)
		} else {
			settleErr = releaseCounterReservation(ctx, tx, row.ScopeType, row.ScopeKey, row.Metric, row.WindowStart, row.Amount)
		}
		if settleErr != nil {
			return settleErr
		}
	}
	return tx.Commit(ctx)
}
// commitCounterReservation settles one committed reservation against its
// counter row, identified by scope, metric and window start. reservedAmount is
// subtracted from reserved_value (floored at 0) and actualAmount is added to
// used_value.
func commitCounterReservation(ctx context.Context, tx pgx.Tx, scopeType string, scopeKey string, metric string, windowStart time.Time, reservedAmount float64, actualAmount float64) error {
	const query = `
UPDATE gateway_rate_limit_counters
SET reserved_value = GREATEST(reserved_value - $5, 0),
used_value = used_value + $6,
updated_at = now()
WHERE scope_type = $1
AND scope_key = $2
AND metric = $3
AND window_start = $4`
	args := []any{scopeType, scopeKey, metric, windowStart, reservedAmount, actualAmount}
	if _, err := tx.Exec(ctx, query, args...); err != nil {
		return err
	}
	return nil
}
// releaseCounterReservation returns reservedAmount of unused capacity to the
// counter row identified by scope, metric and window start. reserved_value is
// decreased and floored at 0; used_value is left unchanged.
func releaseCounterReservation(ctx context.Context, tx pgx.Tx, scopeType string, scopeKey string, metric string, windowStart time.Time, reservedAmount float64) error {
	const query = `
UPDATE gateway_rate_limit_counters
SET reserved_value = GREATEST(reserved_value - $5, 0),
updated_at = now()
WHERE scope_type = $1
AND scope_key = $2
AND metric = $3
AND window_start = $4`
	args := []any{scopeType, scopeKey, metric, windowStart, reservedAmount}
	if _, err := tx.Exec(ctx, query, args...); err != nil {
		return err
	}
	return nil
}