package store

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5"
)

// RuntimeRecoveryResult counts the runtime state that
// RecoverInterruptedRuntimeState cleaned up.
type RuntimeRecoveryResult struct {
	ReleasedConcurrencyLeases int64 `json:"releasedConcurrencyLeases"`
	ReleasedRateReservations  int64 `json:"releasedRateReservations"`
	FailedAttempts            int64 `json:"failedAttempts"`
	FailedTasks               int64 `json:"failedTasks"`
	RequeuedAsyncTasks        int64 `json:"requeuedAsyncTasks"`
}

// ReserveRateLimits reserves capacity for every entry in reservations on
// behalf of taskID (and attemptID, which may be empty). All entries are
// reserved inside one transaction, so either all of them are taken or none is.
//
// Entries with a non-positive Limit or Amount are skipped. An entry with an
// empty Metric, or whose Amount exceeds its Limit, can never be satisfied. It
// fails with a non-retryable *RateLimitExceededError. WindowSeconds defaults
// to 60 when unset.
//
// The "concurrent" metric is enforced with a concurrency lease, and the lease
// ID is returned in LeaseIDs. Every other metric is enforced with a
// fixed-window counter and returned in Reservations. A lease or window without
// spare capacity fails with a retryable *RateLimitExceededError.
func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, attemptID string, reservations []RateLimitReservation) (RateLimitResult, error) {
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return RateLimitResult{}, err
	}
	// Safe even after a successful Commit: pgx makes the late Rollback a no-op.
	defer tx.Rollback(ctx)

	result := RateLimitResult{}
	for _, reservation := range reservations {
		if reservation.Limit <= 0 || reservation.Amount <= 0 {
			continue
		}
		if reservation.Metric == "" || reservation.Amount > reservation.Limit {
			return RateLimitResult{}, &RateLimitExceededError{
				Metric:    reservation.Metric,
				Message:   fmt.Sprintf("rate limit exceeded: %s request amount %.0f is greater than limit %.0f", reservation.Metric, reservation.Amount, reservation.Limit),
				Retryable: false,
			}
		}
		if reservation.WindowSeconds <= 0 {
			reservation.WindowSeconds = 60
		}

		if reservation.Metric == "concurrent" {
			leaseID, err := reserveConcurrencyLease(ctx, tx, taskID, attemptID, reservation)
			if err != nil {
				return RateLimitResult{}, err
			}
			result.LeaseIDs = append(result.LeaseIDs, leaseID)
			continue
		}

		normalized, err := reserveCounterWindow(ctx, tx, taskID, attemptID, reservation)
		if err != nil {
			return RateLimitResult{}, err
		}
		result.Reservations = append(result.Reservations, normalized)
	}

	if err := tx.Commit(ctx); err != nil {
		// Nothing was persisted, so do not hand back lease or reservation IDs.
		return RateLimitResult{}, err
	}
	return result, nil
}

// reserveConcurrencyLease admits the request only if the active lease total
// for the scope, plus reservation.Amount, stays within reservation.Limit.
// Active means unreleased and unexpired. On success it records a new lease
// that expires after LeaseTTLSeconds (default 120) and returns its ID.
// Otherwise it returns a retryable *RateLimitExceededError.
func reserveConcurrencyLease(ctx context.Context, tx pgx.Tx, taskID string, attemptID string, reservation RateLimitReservation) (string, error) {
	if reservation.LeaseTTLSeconds <= 0 {
		reservation.LeaseTTLSeconds = 120
	}

	// The capacity check below is a read followed by an insert. Without
	// serialization, two transactions racing on the same scope could both see
	// spare capacity and both insert, overshooting the limit. This
	// transaction-scoped advisory lock serializes reservers of one scope and is
	// released automatically at commit or rollback.
	// NOTE(review): this relies on READ COMMITTED (the PostgreSQL default used
	// by a plain Begin). There, the SELECT below takes a fresh snapshot after
	// the lock is granted.
	if _, err := tx.Exec(ctx,
		`SELECT pg_advisory_xact_lock(hashtext($1::text), hashtext($2::text))`,
		reservation.ScopeType, reservation.ScopeKey,
	); err != nil {
		return "", err
	}

	var active float64
	var nextAvailableAt time.Time
	if err := tx.QueryRow(ctx, `
		SELECT
			COALESCE(SUM(lease_value), 0)::float8,
			COALESCE(MIN(expires_at), now() + ($3::int * interval '1 second'))
		FROM gateway_concurrency_leases
		WHERE scope_type = $1
			AND scope_key = $2
			AND released_at IS NULL
			AND expires_at > now()`,
		reservation.ScopeType, reservation.ScopeKey, reservation.LeaseTTLSeconds,
	).Scan(&active, &nextAvailableAt); err != nil {
		return "", err
	}
	if active+reservation.Amount > reservation.Limit {
		return "", &RateLimitExceededError{
			Metric:     reservation.Metric,
			Message:    fmt.Sprintf("rate limit exceeded: concurrent active %.0f plus request %.0f is greater than limit %.0f", active, reservation.Amount, reservation.Limit),
			RetryAfter: concurrencyRetryAfter(nextAvailableAt),
			Retryable:  true,
		}
	}

	var leaseID string
	if err := tx.QueryRow(ctx, `
		INSERT INTO gateway_concurrency_leases
			(task_id, attempt_id, scope_type, scope_key, lease_value, expires_at)
		VALUES
			($1::uuid, NULLIF($2, '')::uuid, $3, $4, $5, now() + ($6::int * interval '1 second'))
		RETURNING id::text`,
		taskID, attemptID, reservation.ScopeType, reservation.ScopeKey, reservation.Amount, reservation.LeaseTTLSeconds,
	).Scan(&leaseID); err != nil {
		return "", err
	}
	return leaseID, nil
}

// reserveCounterWindow reserves reservation.Amount in the fixed window that
// contains the database's now(). Windows are aligned to WindowSeconds on the
// Unix epoch. It also records a 'reserved' row in
// gateway_rate_limit_reservations. The returned reservation carries the
// window start and the new ReservationID.
//
// The upsert is atomic: ON CONFLICT DO UPDATE locks the existing counter row.
// Its WHERE clause only accepts the request while used + reserved + Amount
// stays within Limit; otherwise a retryable *RateLimitExceededError is
// returned. A brand-new window row is inserted without that check, relying on
// ReserveRateLimits having already rejected Amount > Limit.
func reserveCounterWindow(ctx context.Context, tx pgx.Tx, taskID string, attemptID string, reservation RateLimitReservation) (RateLimitReservation, error) {
	usedAmount := 0.0
	reservedAmount := reservation.Amount

	var windowStart time.Time
	err := tx.QueryRow(ctx, `
		WITH bounds AS (
			SELECT
				to_timestamp(floor(extract(epoch FROM now()) / $7::int) * $7::int) AS window_start,
				to_timestamp(floor(extract(epoch FROM now()) / $7::int) * $7::int) + ($7::int * interval '1 second') AS reset_at
		)
		INSERT INTO gateway_rate_limit_counters (
			scope_type, scope_key, metric, window_start, limit_value, used_value, reserved_value, reset_at
		)
		SELECT $1, $2, $3, bounds.window_start, $4, $5, $6, bounds.reset_at
		FROM bounds
		ON CONFLICT (scope_type, scope_key, metric, window_start) DO UPDATE SET
			limit_value = EXCLUDED.limit_value,
			used_value = gateway_rate_limit_counters.used_value + EXCLUDED.used_value,
			reserved_value = gateway_rate_limit_counters.reserved_value + EXCLUDED.reserved_value,
			reset_at = EXCLUDED.reset_at,
			updated_at = now()
		WHERE gateway_rate_limit_counters.used_value + gateway_rate_limit_counters.reserved_value
			+ EXCLUDED.used_value + EXCLUDED.reserved_value <= EXCLUDED.limit_value
		RETURNING window_start`,
		reservation.ScopeType, reservation.ScopeKey, reservation.Metric, reservation.Limit,
		usedAmount, reservedAmount, reservation.WindowSeconds,
	).Scan(&windowStart)
	if err != nil {
		if !errors.Is(err, pgx.ErrNoRows) {
			return RateLimitReservation{}, err
		}
		// No row came back: the conflict-update WHERE rejected the request
		// because the window is full. Estimate the reset conservatively, then
		// try to read the window's stored reset_at for a precise Retry-After.
		// The lookup is best effort; if it finds nothing, the estimate stands.
		resetAt := time.Now().Add(time.Duration(reservation.WindowSeconds) * time.Second)
		_ = tx.QueryRow(ctx, `
			WITH bounds AS (
				SELECT to_timestamp(floor(extract(epoch FROM now()) / $4::int) * $4::int) AS window_start
			)
			SELECT counters.reset_at
			FROM gateway_rate_limit_counters counters
			JOIN bounds ON counters.window_start = bounds.window_start
			WHERE scope_type = $1
				AND scope_key = $2
				AND metric = $3`,
			reservation.ScopeType, reservation.ScopeKey, reservation.Metric, reservation.WindowSeconds,
		).Scan(&resetAt)
		return RateLimitReservation{}, &RateLimitExceededError{
			Metric:     reservation.Metric,
			Message:    fmt.Sprintf("rate limit exceeded: %s window has no remaining capacity", reservation.Metric),
			RetryAfter: retryAfterUntil(resetAt),
			Retryable:  true,
		}
	}

	reservation.WindowStart = windowStart
	if err := tx.QueryRow(ctx, `
		INSERT INTO gateway_rate_limit_reservations (
			task_id, attempt_id, scope_type, scope_key, metric, window_start, limit_value, reserved_amount, status
		) VALUES (
			$1::uuid, NULLIF($2, '')::uuid, $3, $4, $5, $6, $7, $8, 'reserved'
		)
		RETURNING id::text`,
		taskID, attemptID, reservation.ScopeType, reservation.ScopeKey, reservation.Metric,
		windowStart, reservation.Limit, reservedAmount,
	).Scan(&reservation.ReservationID); err != nil {
		return RateLimitReservation{}, err
	}
	return reservation, nil
}

// retryAfterUntil converts an absolute retry time into a RetryAfter duration.
// It returns 0 for the zero time. Otherwise it returns the time remaining,
// but never less than one second.
func retryAfterUntil(when time.Time) time.Duration {
	if when.IsZero() {
		return 0
	}
	duration := time.Until(when)
	if duration < time.Second {
		return time.Second
	}
	return duration
}

// concurrencyRetryAfter returns the RetryAfter hint for a rejected
// concurrency lease. It always returns one second.
//
// NOTE(review): leaseExpiresAt (the earliest active lease expiry) is
// currently ignored. This is presumably deliberate, because leases are
// normally released through ReleaseConcurrencyLeases well before they expire.
// Confirm before switching to time.Until(leaseExpiresAt), which could tell
// clients to wait up to the full lease TTL.
func concurrencyRetryAfter(leaseExpiresAt time.Time) time.Duration {
	return time.Second
}

// CommitRateLimitReservations finalizes reservations as 'committed' with
// reason "success". In each counter window it subtracts the reserved amount
// and adds the amount actually consumed. actualByMetric maps a metric to the
// consumed amount; a missing or non-positive entry falls back to the
// reservation's own Amount.
func (s *Store) CommitRateLimitReservations(ctx context.Context, reservations []RateLimitReservation, actualByMetric map[string]float64) error {
	return s.finishRateLimitReservations(ctx, reservations, actualByMetric, "committed", "success")
}

// ReleaseRateLimitReservations finalizes reservations as 'released' with the
// given reason. It returns their reserved amount to the counter window
// without recording any usage.
func (s *Store) ReleaseRateLimitReservations(ctx context.Context, reservations []RateLimitReservation, reason string) error {
	return s.finishRateLimitReservations(ctx, reservations, nil, "released", reason)
}

// ReleaseConcurrencyLeases marks the given leases as released. Empty IDs and
// already-released leases are ignored. Each lease is updated in its own
// statement, outside any transaction. If an update fails, leases processed
// earlier stay released and the error is returned.
func (s *Store) ReleaseConcurrencyLeases(ctx context.Context, leaseIDs []string) error {
	if len(leaseIDs) == 0 {
		return nil
	}
	for _, leaseID := range leaseIDs {
		if leaseID == "" {
			continue
		}
		// NOTE(review): a plain UPDATE is not expected to return ErrRateLimited;
		// this exemption looks vestigial.
		if _, err := s.pool.Exec(ctx, `
			UPDATE gateway_concurrency_leases
			SET released_at = now()
			WHERE id = $1::uuid
				AND released_at IS NULL`,
			leaseID,
		); err != nil && !errors.Is(err, ErrRateLimited) {
			return err
		}
	}
	return nil
}

// RecoverInterruptedRuntimeState cleans up work left in flight, all in one
// transaction. It:
//   - releases every 'reserved' rate-limit reservation and returns its
//     capacity to the counter window;
//   - releases every live concurrency lease;
//   - fails every 'running' attempt as retryable;
//   - requeues running async tasks;
//   - fails running synchronous tasks.
//
// Each affected task gets a 'task.recovered' event.
//
// NOTE(review): judging by the 'server_restarted' markers it writes, this is
// meant to run once at start-up, before new work is accepted. Running it
// while requests are live would release their reservations and fail their
// attempts.
func (s *Store) RecoverInterruptedRuntimeState(ctx context.Context) (RuntimeRecoveryResult, error) {
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	defer tx.Rollback(ctx)

	result := RuntimeRecoveryResult{}

	// Release outstanding counter reservations and give their capacity back.
	rows, err := tx.Query(ctx, `
		UPDATE gateway_rate_limit_reservations
		SET status = 'released',
			reason = 'server_restarted',
			finalized_at = now(),
			updated_at = now()
		WHERE status = 'reserved'
		RETURNING scope_type, scope_key, metric, window_start, reserved_amount::float8`)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	// CollectRows drains and closes rows. That must happen before the next
	// statement runs on this transaction's connection.
	reservations, err := pgx.CollectRows(rows, scanRecoveredRateReservation)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	for _, reservation := range reservations {
		if err := releaseCounterReservation(ctx, tx, reservation.ScopeType, reservation.ScopeKey, reservation.Metric, reservation.WindowStart, reservation.Amount); err != nil {
			return RuntimeRecoveryResult{}, err
		}
	}
	result.ReleasedRateReservations = int64(len(reservations))

	// Only live leases need releasing; expired ones already stop counting
	// against the limit.
	tag, err := tx.Exec(ctx, `
		UPDATE gateway_concurrency_leases
		SET released_at = now()
		WHERE released_at IS NULL
			AND expires_at > now()`)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	result.ReleasedConcurrencyLeases = tag.RowsAffected()

	tag, err = tx.Exec(ctx, `
		UPDATE gateway_task_attempts
		SET status = 'failed',
			retryable = true,
			error_code = 'server_restarted',
			error_message = 'attempt interrupted by service restart',
			finished_at = now()
		WHERE status = 'running'`)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	result.FailedAttempts = tag.RowsAffected()

	// Async tasks can be resumed, so put them back on the queue.
	rows, err = tx.Query(ctx, `
		UPDATE gateway_tasks
		SET status = 'queued',
			error = NULL,
			error_code = NULL,
			error_message = NULL,
			locked_by = NULL,
			locked_at = NULL,
			heartbeat_at = NULL,
			next_run_at = now(),
			finished_at = NULL,
			updated_at = now()
		WHERE async_mode = true
			AND status = 'running'
		RETURNING id::text`)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	asyncTaskIDs, err := pgx.CollectRows(rows, pgx.RowTo[string])
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	if err := insertTaskRecoveryEvents(ctx, tx, asyncTaskIDs, `
		INSERT INTO gateway_task_events
			(task_id, seq, event_type, status, phase, progress, message, payload, simulated)
		VALUES (
			$1::uuid,
			COALESCE((SELECT MAX(seq) + 1 FROM gateway_task_events WHERE task_id = $1::uuid), 1),
			'task.recovered',
			'queued',
			'recovered',
			0.2,
			'async task recovered after service restart',
			'{"code":"server_restarted"}'::jsonb,
			false
		)`); err != nil {
		return RuntimeRecoveryResult{}, err
	}
	result.RequeuedAsyncTasks = int64(len(asyncTaskIDs))

	// Synchronous tasks fail instead of being requeued.
	rows, err = tx.Query(ctx, `
		UPDATE gateway_tasks
		SET status = 'failed',
			error = 'task interrupted by service restart',
			error_code = 'server_restarted',
			error_message = 'task interrupted by service restart',
			finished_at = now(),
			updated_at = now()
		WHERE async_mode = false
			AND status = 'running'
		RETURNING id::text`)
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	taskIDs, err := pgx.CollectRows(rows, pgx.RowTo[string])
	if err != nil {
		return RuntimeRecoveryResult{}, err
	}
	if err := insertTaskRecoveryEvents(ctx, tx, taskIDs, `
		INSERT INTO gateway_task_events
			(task_id, seq, event_type, status, phase, progress, message, payload, simulated)
		VALUES (
			$1::uuid,
			COALESCE((SELECT MAX(seq) + 1 FROM gateway_task_events WHERE task_id = $1::uuid), 1),
			'task.recovered',
			'failed',
			'recovered',
			1,
			'task interrupted by service restart',
			'{"code":"server_restarted"}'::jsonb,
			false
		)`); err != nil {
		return RuntimeRecoveryResult{}, err
	}
	result.FailedTasks = int64(len(taskIDs))

	if err := tx.Commit(ctx); err != nil {
		// Nothing was persisted, so the counts would be misleading.
		return RuntimeRecoveryResult{}, err
	}
	return result, nil
}

// scanRecoveredRateReservation reads one row produced by
// "RETURNING scope_type, scope_key, metric, window_start, reserved_amount".
func scanRecoveredRateReservation(row pgx.CollectableRow) (RateLimitReservation, error) {
	var reservation RateLimitReservation
	err := row.Scan(&reservation.ScopeType, &reservation.ScopeKey, &reservation.Metric, &reservation.WindowStart, &reservation.Amount)
	return reservation, err
}

// insertTaskRecoveryEvents executes insertSQL once per task ID, binding the
// ID to $1. It stops at the first error.
func insertTaskRecoveryEvents(ctx context.Context, tx pgx.Tx, taskIDs []string, insertSQL string) error {
	for _, taskID := range taskIDs {
		if _, err := tx.Exec(ctx, insertSQL, taskID); err != nil {
			return err
		}
	}
	return nil
}

// finishRateLimitReservations moves each still-'reserved' reservation to
// status, storing reason (NULL when empty). It then settles the counter
// window, all in one transaction:
//   - "committed" subtracts the reserved amount and adds the actual amount
//     to used_value;
//   - any other status only subtracts the reserved amount.
//
// Some entries are skipped, so repeated calls are harmless:
//   - entries without a ReservationID;
//   - reservations no longer in 'reserved' status.
func (s *Store) finishRateLimitReservations(ctx context.Context, reservations []RateLimitReservation, actualByMetric map[string]float64, status string, reason string) error {
	if len(reservations) == 0 {
		return nil
	}
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return err
	}
	defer tx.Rollback(ctx)

	for _, reservation := range reservations {
		if reservation.ReservationID == "" {
			continue
		}
		// Indexing a nil map is safe and yields 0, which triggers the fallback.
		actualAmount := actualByMetric[reservation.Metric]
		if status == "committed" && actualAmount <= 0 {
			actualAmount = reservation.Amount
		}

		// The status = 'reserved' guard makes finishing idempotent. The counter
		// is settled from the stored row, not from the caller's copy.
		var stored RateLimitReservation
		err := tx.QueryRow(ctx, `
			UPDATE gateway_rate_limit_reservations
			SET status = $2::text,
				reason = NULLIF($3::text, ''),
				actual_amount = CASE WHEN $2::text = 'committed' THEN $4 ELSE actual_amount END,
				finalized_at = now(),
				updated_at = now()
			WHERE id = $1::uuid
				AND status = 'reserved'
			RETURNING scope_type, scope_key, metric, window_start, reserved_amount::float8`,
			reservation.ReservationID, status, reason, actualAmount,
		).Scan(&stored.ScopeType, &stored.ScopeKey, &stored.Metric, &stored.WindowStart, &stored.Amount)
		if err != nil {
			if errors.Is(err, pgx.ErrNoRows) {
				continue
			}
			return err
		}

		if status == "committed" {
			if err := commitCounterReservation(ctx, tx, stored.ScopeType, stored.ScopeKey, stored.Metric, stored.WindowStart, stored.Amount, actualAmount); err != nil {
				return err
			}
			continue
		}
		if err := releaseCounterReservation(ctx, tx, stored.ScopeType, stored.ScopeKey, stored.Metric, stored.WindowStart, stored.Amount); err != nil {
			return err
		}
	}
	return tx.Commit(ctx)
}

// commitCounterReservation moves reservedAmount out of the window's reserved
// total (clamped at zero) and adds actualAmount to its used total.
func commitCounterReservation(ctx context.Context, tx pgx.Tx, scopeType string, scopeKey string, metric string, windowStart time.Time, reservedAmount float64, actualAmount float64) error {
	_, err := tx.Exec(ctx, `
		UPDATE gateway_rate_limit_counters
		SET reserved_value = GREATEST(reserved_value - $5, 0),
			used_value = used_value + $6,
			updated_at = now()
		WHERE scope_type = $1
			AND scope_key = $2
			AND metric = $3
			AND window_start = $4`,
		scopeType, scopeKey, metric, windowStart, reservedAmount, actualAmount)
	return err
}

// releaseCounterReservation returns reservedAmount to the window by
// subtracting it from the reserved total (clamped at zero), without recording
// any usage.
func releaseCounterReservation(ctx context.Context, tx pgx.Tx, scopeType string, scopeKey string, metric string, windowStart time.Time, reservedAmount float64) error {
	_, err := tx.Exec(ctx, `
		UPDATE gateway_rate_limit_counters
		SET reserved_value = GREATEST(reserved_value - $5, 0),
			updated_at = now()
		WHERE scope_type = $1
			AND scope_key = $2
			AND metric = $3
			AND window_start = $4`,
		scopeType, scopeKey, metric, windowStart, reservedAmount)
	return err
}