easyai-ai-gateway/apps/api/internal/runner/service.go

700 lines
26 KiB
Go

package runner
import (
"context"
"errors"
"fmt"
"log/slog"
"strconv"
"strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
"github.com/jackc/pgx/v5"
"github.com/riverqueue/river"
)
// Service executes gateway tasks against upstream model clients, recording
// task attempts, task events, rate-limit reservations and billing through the
// store.
type Service struct {
cfg config.Config
store *store.Store
// logger receives non-fatal warnings (e.g. when attempt history cannot be loaded).
logger *slog.Logger
// clients maps a lower-cased spec type or provider key ("openai", "gemini",
// "volces", "simulation") to its protocol client; see clientFor for lookup
// and the "openai" fallback.
clients map[string]clients.Client
httpClients *httpClientCache
// riverClient is not assigned by New. NOTE(review): presumably set elsewhere
// in the package — confirm.
riverClient *river.Client[pgx.Tx]
}
// Result is what Execute and ExecuteStream return.
//
// Task is the task row as persisted after execution (finished, failed or
// requeued). Output is the upstream result on success and the task's stored
// Result otherwise. Result is the zero value when a store or event write fails.
type Result struct {
Task store.GatewayTask
Output map[string]any
}
var (
	// ErrTaskQueued signals that execution ended by putting the task back on
	// the queue instead of finishing it. Match it with errors.Is; the concrete
	// value returned by the runner is a *TaskQueuedError.
	ErrTaskQueued = errors.New("task queued")
)

// TaskQueuedError is the concrete "task was requeued" error. Delay is the
// requeue delay that was applied (zero when a worker interruption requeued
// the task immediately).
type TaskQueuedError struct {
	Delay time.Duration
}

// Error renders exactly like the ErrTaskQueued sentinel. It does not read the
// receiver, so it is safe on a nil *TaskQueuedError.
func (*TaskQueuedError) Error() string {
	return ErrTaskQueued.Error()
}

// Is lets errors.Is(err, ErrTaskQueued) report true for a *TaskQueuedError.
// Only the sentinel itself matches.
func (*TaskQueuedError) Is(target error) bool {
	return target == ErrTaskQueued
}
// New constructs a Service bound to cfg, the task store db and logger.
//
// The openai, gemini and volces protocol clients are registered with the
// HTTP client cache's `none` client. NOTE(review): presumably the default,
// un-proxied client — confirm. The per-attempt HTTP client is chosen later by
// httpClientForCandidate. "simulation" maps to the simulation client. The
// river client is left unset.
func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service {
	cache := newHTTPClientCache()
	registry := map[string]clients.Client{
		"openai":     clients.OpenAIClient{HTTPClient: cache.none},
		"gemini":     clients.GeminiClient{HTTPClient: cache.none},
		"volces":     clients.VolcesClient{HTTPClient: cache.none},
		"simulation": clients.SimulationClient{},
	}
	svc := &Service{
		cfg:         cfg,
		store:       db,
		logger:      logger,
		httpClients: cache,
	}
	svc.clients = registry
	return svc
}
// Execute runs task without a streaming callback. It is ExecuteStream with a
// nil onDelta.
func (s *Service) Execute(ctx context.Context, task store.GatewayTask, user *auth.User) (Result, error) {
	return s.ExecuteStream(ctx, task, user, nil)
}

// ExecuteStream runs task and hands onDelta to each upstream client request
// (as clients.Request.StreamDelta), so streamed output can be forwarded while
// it is produced.
func (s *Service) ExecuteStream(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) {
	return s.execute(ctx, task, user, onDelta)
}
// execute is the shared implementation behind Execute and ExecuteStream
// (onDelta is nil for Execute). It runs task against its model candidates and
// persists the outcome.
//
// Flow:
//  1. Normalize the request body and mark the task running. Emit
//     "task.running" unless task.Status was already "running", then validate
//     the body.
//  2. Load the candidates. When there is at least one, check the user's
//     wallet against the estimated billing of the first candidate.
//  3. Walk the candidates, trying at most maxPlatformsForCandidates of them.
//     - Each candidate's client is tried up to clientAttemptsForCandidate
//       times, as long as retryDecisionForCandidate allows it.
//     - After a candidate fails, failoverDecisionForCandidate decides whether
//       to move on to the next one.
//     - Both loops also stop once maxFailoverDurationForCandidates has elapsed
//       since execute started.
//     - Each decision is recorded with recordAttemptTrace.
//  4. On the first successful attempt:
//     - persist the result and billings, and settle billing;
//     - emit "task.billing.settled" (only for a positive charge) and
//       "task.completed";
//     - return the finished task.
//  5. If nothing succeeded:
//     - Async tasks are requeued when ctx is done, or when the last error is a
//       retryable store.ErrRateLimited. The returned error is then a
//       *TaskQueuedError.
//     - Otherwise the task is marked failed with the last error's code.
//
// Validation, candidate-lookup, insufficient-balance and final failures return
// the failed task in Result together with the causing error. Store and event
// errors return a zero Result.
func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) {
executeStartedAt := time.Now()
body := normalizeRequest(task.Kind, task.Request)
modelType := modelTypeFromKind(task.Kind, body)
if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil {
return Result{}, err
}
// Emit "task.running" only when the task was not already in the running state.
if task.Status != "running" {
if err := s.emit(ctx, task.ID, "task.running", "running", "starting", 0.12, "task pulled from queue and started", map[string]any{"modelType": modelType}, task.RunMode == "simulation"); err != nil {
return Result{}, err
}
}
if err := validateRequest(task.Kind, body); err != nil {
failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), task.RunMode == "simulation", err)
if finishErr != nil {
return Result{}, finishErr
}
return Result{Task: failed, Output: failed.Result}, err
}
candidates, err := s.store.ListModelCandidates(ctx, task.Model, modelType, user)
if err != nil {
failed, finishErr := s.failTask(ctx, task.ID, store.ModelCandidateErrorCode(err), err.Error(), task.RunMode == "simulation", err)
if finishErr != nil {
return Result{}, finishErr
}
return Result{Task: failed, Output: failed.Result}, err
}
// Pre-flight balance check uses the first candidate's estimate; skipped when there are no candidates.
if len(candidates) > 0 {
estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, body, candidates[0])
if err := s.ensureWalletBalance(ctx, user, estimatedBillings); err != nil {
if errors.Is(err, store.ErrInsufficientWalletBalance) {
failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", err.Error(), task.RunMode == "simulation", err)
if finishErr != nil {
return Result{}, finishErr
}
return Result{Task: failed, Output: failed.Result}, err
}
return Result{}, err
}
}
if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": modelType}, task.RunMode == "simulation"); err != nil {
return Result{}, err
}
runnerPolicy, err := s.store.GetActiveRunnerPolicy(ctx)
if err != nil {
return Result{}, err
}
maxPlatforms := maxPlatformsForCandidates(candidates, runnerPolicy)
maxFailoverDuration := maxFailoverDurationForCandidates(candidates, runnerPolicy)
// Attempt numbers continue from the attempts already counted on the task.
attemptNo := task.AttemptCount
var lastErr error
// NOTE(review): ctx cancellation is not checked between attempts; whether a
// cancelled ctx stops retries depends on retryDecisionForCandidate and
// failoverDecisionForCandidate — confirm.
for index, candidate := range candidates {
if index >= maxPlatforms {
break
}
clientAttempts := clientAttemptsForCandidate(candidate)
var candidateErr error
// Same-client retry loop.
for clientAttempt := 1; clientAttempt <= clientAttempts; clientAttempt++ {
attemptNo++
response, err := s.runCandidate(ctx, task, user, body, candidate, attemptNo, onDelta)
if err == nil {
// Success: persist the result, settle billing, emit completion events and return.
billings := s.billings(ctx, user, task.Kind, body, candidate, response, isSimulation(task, candidate))
record := buildSuccessRecord(task, user, body, candidate, response, billings, isSimulation(task, candidate))
record.Metrics = s.withAttemptHistory(ctx, task.ID, record.Metrics)
finished, finishErr := s.store.FinishTaskSuccess(ctx, store.FinishTaskSuccessInput{
TaskID: task.ID,
Result: response.Result,
Billings: billings,
RequestID: record.RequestID,
ResolvedModel: record.ResolvedModel,
Usage: record.Usage,
Metrics: record.Metrics,
BillingSummary: record.BillingSummary,
FinalChargeAmount: record.FinalChargeAmount,
ResponseStartedAt: record.ResponseStartedAt,
ResponseFinishedAt: record.ResponseFinishedAt,
ResponseDurationMS: record.ResponseDurationMS,
})
if finishErr != nil {
return Result{}, finishErr
}
if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil {
return Result{}, settleErr
}
// Only emit a billing event when something was actually charged.
if finished.FinalChargeAmount > 0 {
if err := s.emit(ctx, task.ID, "task.billing.settled", "succeeded", "billing", 0.98, "task billing settled", map[string]any{
"amount": finished.FinalChargeAmount,
"currency": stringFromAny(record.BillingSummary["currency"]),
}, isSimulation(task, candidate)); err != nil {
return Result{}, err
}
}
if err := s.emit(ctx, task.ID, "task.completed", "succeeded", "completed", 1, "task completed", map[string]any{
"result": response.Result,
"billings": billings,
"usage": record.Usage,
"metrics": record.Metrics,
"billingSummary": record.BillingSummary,
"requestId": record.RequestID,
}, isSimulation(task, candidate)); err != nil {
return Result{}, err
}
return Result{Task: finished, Output: response.Result}, nil
}
lastErr = err
candidateErr = err
retryDecision := retryDecisionForCandidate(candidate, err)
retryAction := "retry"
if !retryDecision.Retry {
retryAction = "stop"
}
// Same-client attempts exhausted: override the policy decision so the trace records why retries stopped.
if clientAttempt >= clientAttempts {
retryDecision.Retry = false
retryDecision.Reason = "same_client_max_attempts"
retryDecision.Match = policyRuleMatch{
Source: "model_runtime_policy_sets.retry_policy",
Policy: "retryPolicy",
Rule: "maxAttempts",
Value: strconv.Itoa(clientAttempts),
}
retryDecision.Info = failureInfoFromError(err)
retryAction = "stop"
}
// The failover time budget also ends same-client retries.
if failoverTimeBudgetExceeded(executeStartedAt, maxFailoverDuration) {
retryDecision.Retry = false
retryDecision.Reason = "failover_time_budget_exceeded"
retryDecision.Match = policyRuleMatch{
Source: "gateway_runner_policies.failover_policy",
Policy: "failoverPolicy",
Rule: "maxDurationSeconds",
Value: strconv.Itoa(int(maxFailoverDuration.Seconds())),
}
retryDecision.Info = failureInfoFromError(err)
retryAction = "stop"
}
s.recordAttemptTrace(ctx, task.ID, attemptNo, retryTraceEntry(retryDecision, retryAction, clientAttempt, clientAttempts))
if !retryDecision.Retry {
break
}
if err := s.emit(ctx, task.ID, "task.retrying", "running", "retry", 0.45, "retrying same client", addPolicyTracePayload(map[string]any{
"attempt": attemptNo,
"clientAttempt": clientAttempt,
"clientId": candidate.ClientID,
"error": err.Error(),
"reason": retryDecision.Reason,
"scope": "same_client",
}, retryDecision.Match, retryDecision.Info), isSimulation(task, candidate)); err != nil {
return Result{}, err
}
}
// This is the last candidate that will be tried. candidateErr is only nil
// if the inner loop never ran. Record the terminal failover decision,
// rewriting a "retry" verdict into the reason no further platform is tried.
if candidateErr == nil || index+1 >= len(candidates) || index+1 >= maxPlatforms {
if candidateErr != nil {
s.applyPriorityDemotePolicy(ctx, task.ID, attemptNo, runnerPolicy, candidate, candidateErr, isSimulation(task, candidate))
decision := failoverDecisionForCandidate(runnerPolicy, candidate, candidateErr)
if decision.Retry {
decision.Retry = false
decision.Action = "stop"
decision.Reason = "no_next_platform"
decision.Match = policyRuleMatch{Source: "runner_candidates", Policy: "candidateSelection", Rule: "candidateCount", Value: strconv.Itoa(len(candidates))}
if index+1 >= maxPlatforms {
decision.Reason = "max_platforms_reached"
decision.Match = policyRuleMatch{Source: "gateway_runner_policies.failover_policy", Policy: "failoverPolicy", Rule: "maxPlatforms", Value: strconv.Itoa(maxPlatforms)}
}
}
s.recordAttemptTrace(ctx, task.ID, attemptNo, failoverTraceEntry(decision))
}
break
}
// More candidates remain: apply demotion, then check the time budget and
// the failover policy before moving to the next platform.
s.applyPriorityDemotePolicy(ctx, task.ID, attemptNo, runnerPolicy, candidate, candidateErr, isSimulation(task, candidate))
if failoverTimeBudgetExceeded(executeStartedAt, maxFailoverDuration) {
elapsedSeconds := int(time.Since(executeStartedAt).Seconds())
maxDurationSeconds := int(maxFailoverDuration.Seconds())
s.recordAttemptTrace(ctx, task.ID, attemptNo, failoverTimeBudgetTraceEntry(elapsedSeconds, maxDurationSeconds, failureInfoFromError(candidateErr)))
if err := s.emit(ctx, task.ID, "task.failover.stopped", "running", "retry", 0.55, "failover time budget exceeded", map[string]any{
"elapsedSeconds": elapsedSeconds,
"maxDurationSeconds": maxDurationSeconds,
"scope": "next_platform",
"statusCode": clients.ErrorResponseMetadata(candidateErr).StatusCode,
}, isSimulation(task, candidate)); err != nil {
return Result{}, err
}
break
}
decision := failoverDecisionForCandidate(runnerPolicy, candidate, candidateErr)
s.recordAttemptTrace(ctx, task.ID, attemptNo, failoverTraceEntry(decision))
if !decision.Retry {
break
}
s.applyFailoverAction(ctx, task.ID, candidate, decision, isSimulation(task, candidate))
if err := s.emit(ctx, task.ID, "task.retrying", "running", "retry", 0.55, "retrying next client", addPolicyTracePayload(map[string]any{
"attempt": attemptNo,
"action": decision.Action,
"error": candidateErr.Error(),
"reason": decision.Reason,
"scope": "next_platform",
}, decision.Match, decision.Info), isSimulation(task, candidate)); err != nil {
return Result{}, err
}
}
// No attempt succeeded. lastErr is nil when there were no candidates to try.
code := clients.ErrorCode(lastErr)
message := "task failed"
if lastErr != nil {
message = lastErr.Error()
}
// The caller's ctx ended: requeue async tasks on a non-cancellable context
// so the store write still happens.
if task.AsyncMode && ctx.Err() != nil {
queued, queueErr := s.requeueInterruptedAsyncTask(context.WithoutCancel(ctx), task)
if queueErr != nil {
return Result{}, queueErr
}
return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: 0}
}
// Retryable local rate limits on async tasks are requeued after the
// limiter's retry-after delay (see requeueRateLimitedTask).
if task.AsyncMode && errors.Is(lastErr, store.ErrRateLimited) && store.RateLimitRetryable(lastErr) {
queued, delay, queueErr := s.requeueRateLimitedTask(ctx, task, lastErr)
if queueErr != nil {
return Result{}, queueErr
}
return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: delay}
}
// NOTE(review): failTask runs on ctx. For non-async tasks whose ctx was
// cancelled, the failure write may itself fail and leave the task
// "running"; consider context.WithoutCancel(ctx) here.
failed, err := s.failTask(ctx, task.ID, code, message, task.RunMode == "simulation", lastErr)
if err != nil {
return Result{}, err
}
return Result{Task: failed, Output: failed.Result}, lastErr
}
// runCandidate performs one upstream attempt for task against candidate and
// records it as task attempt number attemptNo.
//
// Steps:
//  1. Emit "task.attempt.started" and create the attempt row with status
//     "running".
//  2. Reserve local rate limits and record the client assignment.
//  3. Resolve the HTTP client and call the protocol client.
//  4. Upload generated assets and emit any progress the client reported.
//  5. Commit the rate-limit reservations and finish the attempt as
//     "succeeded".
//
// Once the attempt row exists, every failure path finishes it as "failed", so
// the row is never left "running" by an error return. These writes are
// best-effort: a FinishTaskAttempt error is dropped so that the original cause
// is what the caller sees.
//
// Cleanup runs on context.WithoutCancel(ctx) so it survives cancellation:
//   - uncommitted rate-limit reservations are released with reason
//     "attempt_failed";
//   - concurrency leases are released;
//   - the client assignment is released.
//
// Errors returned:
//   - a *localRateLimitError when local limits reject the attempt;
//   - the client's own error when the upstream call or the asset upload fails;
//   - a wrapped bookkeeping error otherwise.
func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user *auth.User, body map[string]any, candidate store.RuntimeModelCandidate, attemptNo int, onDelta clients.StreamDelta) (clients.Response, error) {
	simulated := isSimulation(task, candidate)
	if err := s.emit(ctx, task.ID, "task.attempt.started", "running", "submitting", 0.25, "client attempt started", map[string]any{"attempt": attemptNo, "clientId": candidate.ClientID}, simulated); err != nil {
		return clients.Response{}, fmt.Errorf("emit attempt started: %w", err)
	}
	attemptID, err := s.store.CreateTaskAttempt(ctx, store.CreateTaskAttemptInput{
		TaskID:          task.ID,
		AttemptNo:       attemptNo,
		PlatformID:      candidate.PlatformID,
		PlatformModelID: candidate.PlatformModelID,
		ClientID:        candidate.ClientID,
		QueueKey:        candidate.QueueKey,
		Status:          "running",
		Simulated:       simulated,
		RequestSnapshot: body,
		Metrics:         attemptMetrics(candidate, attemptNo, simulated),
	})
	if err != nil {
		return clients.Response{}, fmt.Errorf("create task attempt: %w", err)
	}

	// failBeforeResponse finishes the attempt as failed for errors raised
	// before the upstream client returned anything.
	failBeforeResponse := func(cause error, retryable bool) {
		_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
			AttemptID:    attemptID,
			Status:       "failed",
			Retryable:    retryable,
			Metrics:      mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{"error": cause.Error(), "retryable": retryable, "trace": []any{failureTraceEntry(cause, retryable)}}),
			ErrorCode:    clients.ErrorCode(cause),
			ErrorMessage: cause.Error(),
		})
	}

	reservations := s.rateLimitReservations(ctx, user, candidate, body)
	limitResult, err := s.store.ReserveRateLimits(ctx, task.ID, attemptID, reservations)
	if err != nil {
		retryable := store.RateLimitRetryable(err)
		clientErr := &clients.ClientError{Code: "rate_limit", Message: err.Error(), Retryable: retryable}
		_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
			AttemptID:    attemptID,
			Status:       "failed",
			Retryable:    retryable,
			Metrics:      mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{"error": err.Error(), "retryable": retryable, "retryAfterMs": localRateLimitRetryAfter(err).Milliseconds(), "trace": []any{failureTraceEntry(clientErr, retryable)}}),
			ErrorCode:    "rate_limit",
			ErrorMessage: err.Error(),
		})
		return clients.Response{}, &localRateLimitError{clientErr: clientErr, cause: err, retryAfter: localRateLimitRetryAfter(err)}
	}
	// Reservations are committed only after a fully successful attempt; every
	// other exit releases them.
	rateReservationsFinalized := false
	defer func() {
		if !rateReservationsFinalized {
			_ = s.store.ReleaseRateLimitReservations(context.WithoutCancel(ctx), limitResult.Reservations, "attempt_failed")
		}
	}()
	defer s.store.ReleaseConcurrencyLeases(context.WithoutCancel(ctx), limitResult.LeaseIDs)
	if err := s.store.RecordClientAssignment(ctx, candidate); err != nil {
		failBeforeResponse(err, clients.IsRetryable(err))
		return clients.Response{}, fmt.Errorf("record client assignment: %w", err)
	}
	defer s.store.RecordClientRelease(context.WithoutCancel(ctx), candidate.ClientID, "")
	requestHTTPClient, err := s.httpClientForCandidate(candidate, simulated)
	if err != nil {
		failBeforeResponse(err, false)
		return clients.Response{}, fmt.Errorf("prepare http client: %w", err)
	}
	client := s.clientFor(candidate, simulated)
	callStartedAt := time.Now()
	response, err := client.Run(ctx, clients.Request{
		Kind:              task.Kind,
		ModelType:         candidate.ModelType,
		Model:             task.Model,
		Body:              body,
		Candidate:         candidate,
		HTTPClient:        requestHTTPClient,
		RemoteTaskID:      task.RemoteTaskID,
		RemoteTaskPayload: task.RemoteTaskPayload,
		OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error {
			if strings.TrimSpace(remoteTaskID) == "" {
				return nil
			}
			return s.store.SetTaskRemoteTask(context.WithoutCancel(ctx), task.ID, attemptID, remoteTaskID, payload)
		},
		Stream:      boolFromMap(body, "stream"),
		StreamDelta: onDelta,
	})
	callFinishedAt := time.Now()
	// Fill in any timing the client did not report from the local call window.
	if response.ResponseStartedAt.IsZero() {
		response.ResponseStartedAt = callStartedAt
	}
	if response.ResponseFinishedAt.IsZero() {
		response.ResponseFinishedAt = callFinishedAt
	}
	if response.ResponseDurationMS == 0 {
		response.ResponseDurationMS = response.ResponseFinishedAt.Sub(response.ResponseStartedAt).Milliseconds()
		if response.ResponseDurationMS < 0 {
			response.ResponseDurationMS = 0
		}
	}
	if err != nil {
		retryable := clients.IsRetryable(err)
		requestID, metrics, responseStartedAt, responseFinishedAt, responseDurationMS := failureMetrics(err, simulated)
		if responseStartedAt.IsZero() {
			responseStartedAt = callStartedAt
		}
		if responseFinishedAt.IsZero() {
			responseFinishedAt = callFinishedAt
		}
		if responseDurationMS == 0 {
			responseDurationMS = responseFinishedAt.Sub(responseStartedAt).Milliseconds()
			if responseDurationMS < 0 {
				responseDurationMS = 0
			}
		}
		metrics = mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), metrics)
		_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
			AttemptID:          attemptID,
			Status:             "failed",
			Retryable:          retryable,
			RequestID:          requestID,
			Metrics:            metrics,
			ResponseStartedAt:  responseStartedAt,
			ResponseFinishedAt: responseFinishedAt,
			ResponseDurationMS: responseDurationMS,
			ErrorCode:          clients.ErrorCode(err),
			ErrorMessage:       err.Error(),
		})
		// Best-effort event: the upstream error is what the caller must see.
		_ = s.emit(ctx, task.ID, "task.attempt.failed", "running", "attempt_failed", 0.45, err.Error(), map[string]any{"attempt": attemptNo, "retryable": retryable, "requestId": requestID, "statusCode": clients.ErrorResponseMetadata(err).StatusCode, "metrics": metrics}, simulated)
		s.applyCandidateFailurePolicies(ctx, task.ID, candidate, err, simulated)
		return clients.Response{}, err
	}

	// failAfterResponse finishes the attempt as failed for errors raised after
	// the upstream call succeeded, keeping the response details for diagnosis.
	failAfterResponse := func(cause error) {
		retryable := clients.IsRetryable(cause)
		_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
			AttemptID: attemptID,
			Status:    "failed",
			Retryable: retryable,
			RequestID: response.RequestID,
			Usage:     usageToMap(response.Usage),
			Metrics: mergeMetrics(taskMetrics(task, user, body, candidate, response, simulated), map[string]any{
				"error":     cause.Error(),
				"retryable": retryable,
				"trace":     []any{failureTraceEntry(cause, retryable)},
			}),
			ResponseSnapshot:   response.Result,
			ResponseStartedAt:  response.ResponseStartedAt,
			ResponseFinishedAt: response.ResponseFinishedAt,
			ResponseDurationMS: response.ResponseDurationMS,
			ErrorCode:          clients.ErrorCode(cause),
			ErrorMessage:       cause.Error(),
		})
	}

	uploadedResult, err := s.uploadGeneratedAssets(ctx, response.Result)
	if err != nil {
		failAfterResponse(err)
		return clients.Response{}, err
	}
	response.Result = uploadedResult
	for _, progress := range response.Progress {
		if err := s.emit(ctx, task.ID, "task.progress", "running", progress.Phase, progress.Progress, progress.Message, progress.Payload, simulated); err != nil {
			wrapped := fmt.Errorf("emit task progress: %w", err)
			failAfterResponse(wrapped)
			return clients.Response{}, wrapped
		}
	}
	if err := s.store.CommitRateLimitReservations(ctx, limitResult.Reservations, tokenUsageAmounts(response.Usage)); err != nil {
		wrapped := fmt.Errorf("commit rate limit reservations: %w", err)
		failAfterResponse(wrapped)
		return clients.Response{}, wrapped
	}
	rateReservationsFinalized = true
	if err := s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
		AttemptID:          attemptID,
		Status:             "succeeded",
		RequestID:          response.RequestID,
		Usage:              usageToMap(response.Usage),
		Metrics:            taskMetrics(task, user, body, candidate, response, simulated),
		ResponseSnapshot:   response.Result,
		ResponseStartedAt:  response.ResponseStartedAt,
		ResponseFinishedAt: response.ResponseFinishedAt,
		ResponseDurationMS: response.ResponseDurationMS,
	}); err != nil {
		return clients.Response{}, fmt.Errorf("finish task attempt: %w", err)
	}
	return response, nil
}
// clientFor picks the protocol client for candidate.
//
// Simulated attempts always get the "simulation" client. Otherwise the
// candidate's SpecType (trimmed, lower-cased) is the lookup key, with Provider
// used when SpecType is blank. Unknown or empty keys fall back to "openai".
func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated bool) clients.Client {
	if simulated {
		return s.clients["simulation"]
	}
	normalize := func(raw string) string {
		return strings.ToLower(strings.TrimSpace(raw))
	}
	lookup := normalize(candidate.SpecType)
	if lookup == "" {
		lookup = normalize(candidate.Provider)
	}
	chosen, found := s.clients[lookup]
	if !found {
		chosen = s.clients["openai"]
	}
	return chosen
}
// failTask marks taskID failed with code and message, then emits a
// "task.failed" event.
//
// Request ID, metrics and response timing come from failureMetrics(cause),
// and the task's attempt history is merged into the metrics. It returns the
// persisted task, or the first store or event error encountered (with a zero
// task).
func (s *Service) failTask(ctx context.Context, taskID string, code string, message string, simulated bool, cause error) (store.GatewayTask, error) {
	reqID, details, startedAt, finishedAt, durationMS := failureMetrics(cause, simulated)
	details = s.withAttemptHistory(ctx, taskID, details)
	input := store.FinishTaskFailureInput{
		TaskID:             taskID,
		Code:               code,
		Message:            message,
		RequestID:          reqID,
		Metrics:            details,
		ResponseStartedAt:  startedAt,
		ResponseFinishedAt: finishedAt,
		ResponseDurationMS: durationMS,
	}
	failed, err := s.store.FinishTaskFailure(ctx, input)
	if err != nil {
		return store.GatewayTask{}, err
	}
	payload := map[string]any{"code": code, "requestId": reqID, "metrics": details}
	if err := s.emit(ctx, taskID, "task.failed", "failed", "failed", 1, message, payload, simulated); err != nil {
		return store.GatewayTask{}, err
	}
	return failed, nil
}
// requeueRateLimitedTask puts task back on the queue after a local rate-limit
// rejection, then emits a "task.queued" event with phase "rate_limited".
//
// The delay is localRateLimitRetryAfter(cause), or 5s when that is not
// positive. It returns the requeued task and the delay used.
func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.GatewayTask, cause error) (store.GatewayTask, time.Duration, error) {
	const fallbackDelay = 5 * time.Second
	wait := localRateLimitRetryAfter(cause)
	if wait <= 0 {
		wait = fallbackDelay
	}
	queued, err := s.store.RequeueTask(ctx, task.ID, wait)
	if err != nil {
		return store.GatewayTask{}, 0, err
	}
	simulated := task.RunMode == "simulation"
	err = s.emit(ctx, task.ID, "task.queued", "queued", "rate_limited", 0.2, "task queued by local rate limit", map[string]any{
		"code":         "rate_limit",
		"message":      cause.Error(),
		"retryAfterMs": wait.Milliseconds(),
	}, simulated)
	if err != nil {
		return store.GatewayTask{}, 0, err
	}
	return queued, wait, nil
}
// requeueInterruptedAsyncTask requeues task with no delay after its worker
// was interrupted, then emits a "task.queued" event with phase
// "worker_interrupted". The task's RemoteTaskID, when set, is included in the
// event payload.
func (s *Service) requeueInterruptedAsyncTask(ctx context.Context, task store.GatewayTask) (store.GatewayTask, error) {
	queued, err := s.store.RequeueTask(ctx, task.ID, 0)
	if err != nil {
		return store.GatewayTask{}, err
	}
	payload := map[string]any{"code": "worker_interrupted"}
	if remoteID := task.RemoteTaskID; remoteID != "" {
		payload["remoteTaskId"] = remoteID
	}
	emitErr := s.emit(ctx, task.ID, "task.queued", "queued", "worker_interrupted", 0.2, "async task queued after worker interruption", payload, task.RunMode == "simulation")
	if emitErr != nil {
		return store.GatewayTask{}, emitErr
	}
	return queued, nil
}
// withAttemptHistory adds "attemptCount" and a summarized "attempts" list for
// taskID to metrics.
//
// When the attempts cannot be listed (a warning is logged) or there are none,
// metrics is returned unchanged. Otherwise the keys are written into the map
// returned by mergeMetrics(metrics). NOTE(review): presumably a fresh copy
// that tolerates a nil input — confirm.
func (s *Service) withAttemptHistory(ctx context.Context, taskID string, metrics map[string]any) map[string]any {
	attempts, err := s.store.ListTaskAttempts(ctx, taskID)
	switch {
	case err != nil:
		s.logger.Warn("list task attempts for metrics failed", "taskID", taskID, "error", err)
		return metrics
	case len(attempts) == 0:
		return metrics
	}
	enriched := mergeMetrics(metrics)
	enriched["attemptCount"] = len(attempts)
	enriched["attempts"] = summarizeAttempts(attempts)
	return enriched
}
// emit stores a task event. When progress callbacks are enabled in the
// config, it also queues a callback for that event to
// cfg.TaskProgressCallbackURL. The first error encountered is returned.
func (s *Service) emit(ctx context.Context, taskID string, eventType string, status string, phase string, progress float64, message string, payload map[string]any, simulated bool) error {
	event, err := s.store.AddTaskEvent(ctx, taskID, eventType, status, phase, progress, message, payload, simulated)
	switch {
	case err != nil:
		return err
	case !s.cfg.TaskProgressCallbackEnabled:
		return nil
	default:
		return s.store.QueueTaskCallback(ctx, event, s.cfg.TaskProgressCallbackURL)
	}
}
// modelTypeFromKind maps an API kind to the model type used for candidate
// lookup. Video generation becomes "image_to_video" when the body carries a
// reference image. Unknown kinds map to "task".
func modelTypeFromKind(kind string, body map[string]any) string {
	switch kind {
	case "chat.completions", "responses":
		return "text_generate"
	case "images.generations":
		return "image_generate"
	case "images.edits":
		return "image_edit"
	case "videos.generations":
		if videoRequestHasReferenceImage(body) {
			return "image_to_video"
		}
		return "video_generate"
	}
	return "task"
}
// videoRequestHasReferenceImage reports whether any known image field
// (snake_case or camelCase) in a video request body passes hasAnyString.
// NOTE(review): hasAnyString presumably checks for a non-empty string or
// string list at the key — confirm. A nil body has no reference image.
func videoRequestHasReferenceImage(body map[string]any) bool {
	if body == nil {
		return false
	}
	imageFields := []string{
		"image", "images", "image_url", "imageUrl", "image_urls", "imageUrls",
		"reference_image", "referenceImage", "first_frame", "firstFrame", "last_frame", "lastFrame",
	}
	for i := range imageFields {
		if hasAnyString(body, imageFields[i]) {
			return true
		}
	}
	return false
}
// isTextGenerationKind reports whether kind is one of the text APIs
// ("chat.completions" or "responses"). The match is exact and case-sensitive.
func isTextGenerationKind(kind string) bool {
	switch kind {
	case "chat.completions", "responses":
		return true
	default:
		return false
	}
}
// isSimulation reports whether an attempt must use the simulation client. It
// is true when any of these holds, checked in this order:
//   - the task runs in simulation mode;
//   - the candidate's credentials set mode "simulation";
//   - the candidate's platform config sets testMode.
func isSimulation(task store.GatewayTask, candidate store.RuntimeModelCandidate) bool {
	switch {
	case task.RunMode == "simulation":
		return true
	case stringFromMap(candidate.Credentials, "mode") == "simulation":
		return true
	default:
		return boolFromMap(candidate.PlatformConfig, "testMode")
	}
}
// retryEnabled reports whether same-client retries are on for candidate.
// Retries default to enabled unless the effective retry policy holds an
// explicit boolean "enabled" value.
func retryEnabled(candidate store.RuntimeModelCandidate) bool {
	enabled, isBool := effectiveRetryPolicy(candidate)["enabled"].(bool)
	return !isBool || enabled
}
// clientAttemptsForCandidate returns how many times the same client may be
// tried. The result is always at least 1.
//
// It is 1 when retries are disabled. Otherwise it is the retry policy's
// positive "maxAttempts", or 1 when that is unset or not positive.
func clientAttemptsForCandidate(candidate store.RuntimeModelCandidate) int {
	attempts := 1
	if retryEnabled(candidate) {
		if configured := intFromPolicy(effectiveRetryPolicy(candidate), "maxAttempts"); configured > 0 {
			attempts = configured
		}
	}
	return attempts
}
// maxPlatformsForCandidates bounds how many candidates one execution may try.
//
// With no candidates the result is 0. Otherwise it starts at len(candidates)
// and is adjusted as follows:
//   - it is lowered to the runner policy's positive "maxPlatforms";
//   - when no such global value exists, it is capped at 99;
//   - each candidate's effective failover policy may lower it further.
//
// Overrides can only lower the bound, never raise it, and the result is never
// below 1.
func maxPlatformsForCandidates(candidates []store.RuntimeModelCandidate, runnerPolicy store.RunnerPolicy) int {
	total := len(candidates)
	if total == 0 {
		return 0
	}
	limit := total
	switch configured := intFromPolicy(runnerPolicy.FailoverPolicy, "maxPlatforms"); {
	case configured > 0:
		limit = min(limit, configured)
	case limit > 99:
		limit = 99
	}
	for i := range candidates {
		policy := effectiveFailoverPolicy(runnerPolicy.FailoverPolicy, candidates[i].RuntimePolicyOverride)
		if override := intFromPolicy(policy, "maxPlatforms"); override > 0 {
			limit = min(limit, override)
		}
	}
	if limit < 1 {
		return 1
	}
	return limit
}
// maxFailoverDurationForCandidates returns the wall-clock budget for retries
// and failover in one execution.
//
// The budget is the runner policy's positive "maxDurationSeconds" (600s when
// unset). Any candidate whose effective failover policy sets a smaller
// positive value lowers it.
func maxFailoverDurationForCandidates(candidates []store.RuntimeModelCandidate, runnerPolicy store.RunnerPolicy) time.Duration {
	const defaultSeconds = 600
	limit := intFromPolicy(runnerPolicy.FailoverPolicy, "maxDurationSeconds")
	if limit <= 0 {
		limit = defaultSeconds
	}
	for i := range candidates {
		policy := effectiveFailoverPolicy(runnerPolicy.FailoverPolicy, candidates[i].RuntimePolicyOverride)
		if override := intFromPolicy(policy, "maxDurationSeconds"); override > 0 {
			limit = min(limit, override)
		}
	}
	return time.Second * time.Duration(limit)
}
// failoverTimeBudgetExceeded reports whether at least maxDuration has passed
// since start. A zero or negative maxDuration means "no budget", so the
// result is then always false.
func failoverTimeBudgetExceeded(start time.Time, maxDuration time.Duration) bool {
	if maxDuration <= 0 {
		return false
	}
	elapsed := time.Since(start)
	return elapsed >= maxDuration
}
// normalizeRequest returns a shallow copy of body; the result is never nil,
// even for a nil body.
//
// For the "responses" kind, when "messages" is absent or nil and "input" is
// set, it adds a single user message whose content is the input value.
// NOTE(review): an array-form Responses input is wrapped as-is into that one
// message's content — confirm downstream clients expect this.
func normalizeRequest(kind string, body map[string]any) map[string]any {
	normalized := make(map[string]any, len(body))
	for field, value := range body {
		normalized[field] = value
	}
	if kind != "responses" || normalized["messages"] != nil || normalized["input"] == nil {
		return normalized
	}
	normalized["messages"] = []any{map[string]any{"role": "user", "content": normalized["input"]}}
	return normalized
}
// validateRequest checks the minimum required fields for kind:
//   - "chat.completions" needs "messages";
//   - "responses" needs "input" or "messages";
//   - image kinds need a non-blank "prompt".
//
// Other kinds are not checked and always pass.
func validateRequest(kind string, body map[string]any) error {
	isImageKind := kind == "images.generations" || kind == "images.edits"
	switch {
	case kind == "chat.completions" && body["messages"] == nil:
		return errors.New("messages is required")
	case kind == "responses" && body["input"] == nil && body["messages"] == nil:
		return errors.New("input or messages is required")
	case isImageKind && strings.TrimSpace(stringFromMap(body, "prompt")) == "":
		return errors.New("prompt is required")
	}
	return nil
}