easyai-ai-gateway/apps/api/internal/store/rate_limits.go

110 lines
3.0 KiB
Go

package store
import (
"context"
"errors"
)
// ReserveRateLimits reserves capacity for taskID against every entry in
// reservations inside a single database transaction.
//
// Entries whose Limit or Amount is not positive are skipped. An entry with an
// empty Metric, or whose Amount alone exceeds its Limit, fails the whole call
// with ErrRateLimited. For the "concurrent" metric a row is inserted into
// gateway_concurrency_leases and its ID is appended to the result's LeaseIDs;
// for every other metric Amount is added to used_value of the matching row in
// gateway_rate_limit_counters.
//
// If any entry cannot be satisfied the call returns ErrRateLimited and the
// deferred rollback discards all work done for earlier entries. On any error,
// including a failed commit, the returned RateLimitResult is empty.
func (s *Store) ReserveRateLimits(ctx context.Context, taskID string, reservations []RateLimitReservation) (RateLimitResult, error) {
	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return RateLimitResult{}, err
	}
	// Rollback after a successful Commit is a harmless no-op, so the deferred
	// call only has an effect on the early-return paths below.
	defer tx.Rollback(ctx)

	result := RateLimitResult{}
	for _, reservation := range reservations {
		if reservation.Limit <= 0 || reservation.Amount <= 0 {
			continue
		}
		if reservation.Metric == "" || reservation.Amount > reservation.Limit {
			return RateLimitResult{}, ErrRateLimited
		}
		// The defaults below are applied to the loop's copy only; the caller's
		// slice is not modified.
		if reservation.WindowSeconds <= 0 {
			reservation.WindowSeconds = 60
		}

		if reservation.Metric == "concurrent" {
			if reservation.LeaseTTLSeconds <= 0 {
				reservation.LeaseTTLSeconds = 120
			}

			// Serialize reservations for the same scope. Without this, two
			// transactions can both read the same active total (neither sees
			// the other's uncommitted lease) and both insert, exceeding Limit.
			// The advisory lock is released automatically when the transaction
			// commits or rolls back.
			if _, err := tx.Exec(ctx, `
SELECT pg_advisory_xact_lock(hashtext($1::text), hashtext($2::text))`,
				reservation.ScopeType,
				reservation.ScopeKey,
			); err != nil {
				return RateLimitResult{}, err
			}

			// Only leases that are neither released nor expired count as active.
			var active float64
			if err := tx.QueryRow(ctx, `
SELECT COALESCE(SUM(lease_value), 0)::float8
FROM gateway_concurrency_leases
WHERE scope_type = $1
AND scope_key = $2
AND released_at IS NULL
AND expires_at > now()`,
				reservation.ScopeType,
				reservation.ScopeKey,
			).Scan(&active); err != nil {
				return RateLimitResult{}, err
			}
			if active+reservation.Amount > reservation.Limit {
				return RateLimitResult{}, ErrRateLimited
			}

			var leaseID string
			if err := tx.QueryRow(ctx, `
INSERT INTO gateway_concurrency_leases (task_id, scope_type, scope_key, lease_value, expires_at)
VALUES ($1::uuid, $2, $3, $4, now() + ($5::int * interval '1 second'))
RETURNING id::text`,
				taskID,
				reservation.ScopeType,
				reservation.ScopeKey,
				reservation.Amount,
				reservation.LeaseTTLSeconds,
			).Scan(&leaseID); err != nil {
				return RateLimitResult{}, err
			}
			result.LeaseIDs = append(result.LeaseIDs, leaseID)
			continue
		}

		// Upsert the counter row for the current window. The WHERE clause on
		// the conflict update makes the statement affect zero rows when adding
		// Amount would exceed the limit, which is reported as ErrRateLimited.
		//
		// NOTE(review): window_start is always truncated to the minute while
		// reset_at is derived from WindowSeconds, so for windows other than
		// 60s the counter row still appears to roll over every minute —
		// confirm this is intended.
		tag, err := tx.Exec(ctx, `
INSERT INTO gateway_rate_limit_counters (
scope_type, scope_key, metric, window_start, limit_value, used_value, reserved_value, reset_at
)
VALUES (
$1, $2, $3, date_trunc('minute', now()), $4, $5, 0,
date_trunc('minute', now()) + ($6::int * interval '1 second')
)
ON CONFLICT (scope_type, scope_key, metric, window_start) DO UPDATE
SET limit_value = EXCLUDED.limit_value,
used_value = gateway_rate_limit_counters.used_value + EXCLUDED.used_value,
reset_at = EXCLUDED.reset_at,
updated_at = now()
WHERE gateway_rate_limit_counters.used_value + EXCLUDED.used_value <= EXCLUDED.limit_value`,
			reservation.ScopeType,
			reservation.ScopeKey,
			reservation.Metric,
			reservation.Limit,
			reservation.Amount,
			reservation.WindowSeconds,
		)
		if err != nil {
			return RateLimitResult{}, err
		}
		if tag.RowsAffected() == 0 {
			return RateLimitResult{}, ErrRateLimited
		}
	}

	// Do not hand back lease IDs when the commit fails: those rows were never
	// persisted, so the IDs would refer to nothing.
	if err := tx.Commit(ctx); err != nil {
		return RateLimitResult{}, err
	}
	return result, nil
}
// ReleaseConcurrencyLeases marks each lease in leaseIDs as released by setting
// released_at on the matching gateway_concurrency_leases row, provided the row
// has not already been released. Empty IDs are skipped, and a nil or empty
// slice is a no-op.
//
// Every ID is attempted even when an earlier one fails, so a single bad ID
// does not leave the remaining leases held until they expire. The first error
// encountered, if any, is returned after all IDs have been tried.
func (s *Store) ReleaseConcurrencyLeases(ctx context.Context, leaseIDs []string) error {
	var firstErr error
	for _, leaseID := range leaseIDs {
		if leaseID == "" {
			continue
		}
		_, err := s.pool.Exec(ctx, `
UPDATE gateway_concurrency_leases
SET released_at = now()
WHERE id = $1::uuid AND released_at IS NULL`, leaseID)
		// ErrRateLimited is never treated as a release failure.
		if err != nil && !errors.Is(err, ErrRateLimited) && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}