110 lines
3.0 KiB
Go
110 lines
3.0 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
)
|
|
|
|
func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, reservations []RateLimitReservation) (RateLimitResult, error) {
|
|
tx, err := s.pool.Begin(ctx)
|
|
if err != nil {
|
|
return RateLimitResult{}, err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
result := RateLimitResult{}
|
|
for _, reservation := range reservations {
|
|
if reservation.Limit <= 0 || reservation.Amount <= 0 {
|
|
continue
|
|
}
|
|
if reservation.Metric == "" || reservation.Amount > reservation.Limit {
|
|
return RateLimitResult{}, ErrRateLimited
|
|
}
|
|
if reservation.WindowSeconds <= 0 {
|
|
reservation.WindowSeconds = 60
|
|
}
|
|
if reservation.Metric == "concurrent" {
|
|
if reservation.LeaseTTLSeconds <= 0 {
|
|
reservation.LeaseTTLSeconds = 120
|
|
}
|
|
var active float64
|
|
if err := tx.QueryRow(ctx, `
|
|
SELECT COALESCE(SUM(lease_value), 0)::float8
|
|
FROM gateway_concurrency_leases
|
|
WHERE scope_type = $1
|
|
AND scope_key = $2
|
|
AND released_at IS NULL
|
|
AND expires_at > now()`,
|
|
reservation.ScopeType,
|
|
reservation.ScopeKey,
|
|
).Scan(&active); err != nil {
|
|
return RateLimitResult{}, err
|
|
}
|
|
if active+reservation.Amount > reservation.Limit {
|
|
return RateLimitResult{}, ErrRateLimited
|
|
}
|
|
var leaseID string
|
|
if err := tx.QueryRow(ctx, `
|
|
INSERT INTO gateway_concurrency_leases (task_id, scope_type, scope_key, lease_value, expires_at)
|
|
VALUES ($1::uuid, $2, $3, $4, now() + ($5::int * interval '1 second'))
|
|
RETURNING id::text`,
|
|
taskID,
|
|
reservation.ScopeType,
|
|
reservation.ScopeKey,
|
|
reservation.Amount,
|
|
reservation.LeaseTTLSeconds,
|
|
).Scan(&leaseID); err != nil {
|
|
return RateLimitResult{}, err
|
|
}
|
|
result.LeaseIDs = append(result.LeaseIDs, leaseID)
|
|
continue
|
|
}
|
|
tag, err := tx.Exec(ctx, `
|
|
INSERT INTO gateway_rate_limit_counters (
|
|
scope_type, scope_key, metric, window_start, limit_value, used_value, reserved_value, reset_at
|
|
)
|
|
VALUES (
|
|
$1, $2, $3, date_trunc('minute', now()), $4, $5, 0,
|
|
date_trunc('minute', now()) + ($6::int * interval '1 second')
|
|
)
|
|
ON CONFLICT (scope_type, scope_key, metric, window_start) DO UPDATE
|
|
SET limit_value = EXCLUDED.limit_value,
|
|
used_value = gateway_rate_limit_counters.used_value + EXCLUDED.used_value,
|
|
reset_at = EXCLUDED.reset_at,
|
|
updated_at = now()
|
|
WHERE gateway_rate_limit_counters.used_value + EXCLUDED.used_value <= EXCLUDED.limit_value`,
|
|
reservation.ScopeType,
|
|
reservation.ScopeKey,
|
|
reservation.Metric,
|
|
reservation.Limit,
|
|
reservation.Amount,
|
|
reservation.WindowSeconds,
|
|
)
|
|
if err != nil {
|
|
return RateLimitResult{}, err
|
|
}
|
|
if tag.RowsAffected() == 0 {
|
|
return RateLimitResult{}, ErrRateLimited
|
|
}
|
|
}
|
|
return result, tx.Commit(ctx)
|
|
}
|
|
|
|
func (s *Store) ReleaseConcurrencyLeases(ctx context.Context, leaseIDs []string) error {
|
|
if len(leaseIDs) == 0 {
|
|
return nil
|
|
}
|
|
for _, leaseID := range leaseIDs {
|
|
if leaseID == "" {
|
|
continue
|
|
}
|
|
if _, err := s.pool.Exec(ctx, `
|
|
UPDATE gateway_concurrency_leases
|
|
SET released_at = now()
|
|
WHERE id = $1::uuid AND released_at IS NULL`, leaseID); err != nil && !errors.Is(err, ErrRateLimited) {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|