package store

import (
	"context"
	"errors"
	"fmt"
)

// ReserveRateLimits records every reservation in a single transaction: either
// all of them are persisted or none are.
//
// Reservations with a non-positive Limit or Amount are ignored. A reservation
// with an empty Metric, or whose Amount alone exceeds its Limit, fails the
// whole call with ErrRateLimited.
//
// Metric "concurrent" is admitted against the sum of lease_value over leases
// in gateway_concurrency_leases for the same scope that are neither released
// nor expired. On success a new lease is inserted, expiring after
// LeaseTTLSeconds (default 120), and its ID is appended to
// RateLimitResult.LeaseIDs so the caller can later hand it to
// ReleaseConcurrencyLeases.
//
// Every other metric adds Amount to a counter row in
// gateway_rate_limit_counters for the current window of WindowSeconds
// (default 60). The increment is rejected if it would push used_value past the
// limit. Windows are aligned to multiples of WindowSeconds since the Unix
// epoch, which for the 60-second default is exactly the wall-clock minute.
//
// ErrRateLimited is returned when any reservation would exceed its limit.
// Other errors come from the database. On any error the returned result is
// the zero value and the transaction is rolled back.
func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, reservations []RateLimitReservation) (RateLimitResult, error) {
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return RateLimitResult{}, err
	}
	// On early returns this discards partial work. After a successful Commit
	// there is nothing left to roll back, so its error is intentionally ignored.
	defer tx.Rollback(ctx)

	var result RateLimitResult
	for _, reservation := range reservations {
		if reservation.Limit <= 0 || reservation.Amount <= 0 {
			continue
		}
		if reservation.Metric == "" || reservation.Amount > reservation.Limit {
			return RateLimitResult{}, ErrRateLimited
		}
		if reservation.WindowSeconds <= 0 {
			reservation.WindowSeconds = 60
		}

		if reservation.Metric == "concurrent" {
			if reservation.LeaseTTLSeconds <= 0 {
				reservation.LeaseTTLSeconds = 120
			}

			// Without serialization, two transactions could both read the same
			// active total, both pass the check below, and both insert a lease,
			// overshooting the limit. This transaction-scoped advisory lock makes
			// check-then-insert atomic per scope and is released automatically on
			// commit or rollback. It assumes READ COMMITTED isolation (the
			// PostgreSQL default), so the SELECT below takes a fresh snapshot and
			// sees leases committed by the previous lock holder.
			// NOTE(review): hashed keys can collide with other two-int advisory
			// locks in the database; collisions only over-serialize.
			if _, err := tx.Exec(ctx,
				`SELECT pg_advisory_xact_lock(hashtext($1), hashtext($2))`,
				reservation.ScopeType, reservation.ScopeKey,
			); err != nil {
				return RateLimitResult{}, err
			}

			var active float64
			if err := tx.QueryRow(ctx, `
				SELECT COALESCE(SUM(lease_value), 0)::float8
				FROM gateway_concurrency_leases
				WHERE scope_type = $1
				  AND scope_key = $2
				  AND released_at IS NULL
				  AND expires_at > now()`,
				reservation.ScopeType, reservation.ScopeKey,
			).Scan(&active); err != nil {
				return RateLimitResult{}, err
			}
			if active+reservation.Amount > reservation.Limit {
				return RateLimitResult{}, ErrRateLimited
			}

			var leaseID string
			if err := tx.QueryRow(ctx, `
				INSERT INTO gateway_concurrency_leases
					(task_id, scope_type, scope_key, lease_value, expires_at)
				VALUES
					($1::uuid, $2, $3, $4, now() + ($5::int * interval '1 second'))
				RETURNING id::text`,
				taskID, reservation.ScopeType, reservation.ScopeKey,
				reservation.Amount, reservation.LeaseTTLSeconds,
			).Scan(&leaseID); err != nil {
				return RateLimitResult{}, err
			}
			result.LeaseIDs = append(result.LeaseIDs, leaseID)
			continue
		}

		// window_start is now() floored to a multiple of WindowSeconds since the
		// epoch. Truncating to the minute, as before, made every window length
		// other than 60s behave like a one-minute window. The conditional
		// DO UPDATE leaves the row untouched (0 rows affected) when the
		// increment would exceed the limit.
		tag, err := tx.Exec(ctx, `
			INSERT INTO gateway_rate_limit_counters (
				scope_type, scope_key, metric, window_start,
				limit_value, used_value, reserved_value, reset_at
			) VALUES (
				$1, $2, $3,
				to_timestamp(floor(extract(epoch FROM now())::float8 / $6::int) * $6::int),
				$4, $5, 0,
				to_timestamp(floor(extract(epoch FROM now())::float8 / $6::int) * $6::int)
					+ ($6::int * interval '1 second')
			)
			ON CONFLICT (scope_type, scope_key, metric, window_start) DO UPDATE SET
				limit_value = EXCLUDED.limit_value,
				used_value = gateway_rate_limit_counters.used_value + EXCLUDED.used_value,
				reset_at = EXCLUDED.reset_at,
				updated_at = now()
			WHERE gateway_rate_limit_counters.used_value + EXCLUDED.used_value <= EXCLUDED.limit_value`,
			reservation.ScopeType, reservation.ScopeKey, reservation.Metric,
			reservation.Limit, reservation.Amount, reservation.WindowSeconds,
		)
		if err != nil {
			return RateLimitResult{}, err
		}
		if tag.RowsAffected() == 0 {
			return RateLimitResult{}, ErrRateLimited
		}
	}

	// Do not hand back lease IDs from a transaction that failed to commit.
	if err := tx.Commit(ctx); err != nil {
		return RateLimitResult{}, err
	}
	return result, nil
}

// ReleaseConcurrencyLeases marks each lease in leaseIDs as released so it no
// longer counts toward its scope's concurrency limit. Empty IDs are skipped,
// and leases that are already released are left unchanged.
//
// Each lease is released in its own statement. A failure on one ID therefore
// does not keep the remaining leases held until their TTL expires: the loop
// moves on and the first error, wrapped with the offending lease ID, is
// returned after all leases have been attempted. If the context is cancelled
// or its deadline passes, the remaining leases are not attempted, since
// further statements would presumably fail as well.
func (s *Store) ReleaseConcurrencyLeases(ctx context.Context, leaseIDs []string) error {
	var firstErr error
	for _, leaseID := range leaseIDs {
		if leaseID == "" {
			continue
		}
		_, err := s.pool.Exec(ctx, `
			UPDATE gateway_concurrency_leases
			SET released_at = now()
			WHERE id = $1::uuid AND released_at IS NULL`, leaseID)
		if err == nil {
			continue
		}
		if firstErr == nil {
			firstErr = fmt.Errorf("release concurrency lease %q: %w", leaseID, err)
		}
		if ctx.Err() != nil || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
			break
		}
	}
	return firstErr
}