700 lines
26 KiB
Go
700 lines
26 KiB
Go
package runner
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/riverqueue/river"
|
|
)
|
|
|
|
type Service struct {
|
|
cfg config.Config
|
|
store *store.Store
|
|
logger *slog.Logger
|
|
clients map[string]clients.Client
|
|
httpClients *httpClientCache
|
|
riverClient *river.Client[pgx.Tx]
|
|
}
|
|
|
|
type Result struct {
|
|
Task store.GatewayTask
|
|
Output map[string]any
|
|
}
|
|
|
|
var ErrTaskQueued = errors.New("task queued")
|
|
|
|
type TaskQueuedError struct {
|
|
Delay time.Duration
|
|
}
|
|
|
|
func (e *TaskQueuedError) Error() string {
|
|
return ErrTaskQueued.Error()
|
|
}
|
|
|
|
func (e *TaskQueuedError) Is(target error) bool {
|
|
return target == ErrTaskQueued
|
|
}
|
|
|
|
// New builds a Service with the "openai", "gemini" and "volces" provider
// clients plus the "simulation" client registered. The provider clients all
// share the HTTP client cache's `none` client; riverClient is left unset.
func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service {
	cache := newHTTPClientCache()
	sharedHTTP := cache.none
	registry := map[string]clients.Client{
		"openai":     clients.OpenAIClient{HTTPClient: sharedHTTP},
		"gemini":     clients.GeminiClient{HTTPClient: sharedHTTP},
		"volces":     clients.VolcesClient{HTTPClient: sharedHTTP},
		"simulation": clients.SimulationClient{},
	}
	return &Service{
		cfg:         cfg,
		store:       db,
		logger:      logger,
		clients:     registry,
		httpClients: cache,
	}
}
|
|
|
|
func (s *Service) Execute(ctx context.Context, task store.GatewayTask, user *auth.User) (Result, error) {
|
|
return s.execute(ctx, task, user, nil)
|
|
}
|
|
|
|
// ExecuteStream runs task like Execute, but hands onDelta to the provider
// client (as clients.Request.StreamDelta) so streamed output can be relayed.
func (s *Service) ExecuteStream(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) {
	result, err := s.execute(ctx, task, user, onDelta)
	return result, err
}
|
|
|
|
// execute drives one run of a task.
//
// It normalizes the request, derives the model type, marks the task running,
// validates the request, lists model candidates and checks the wallet balance
// against the first candidate's estimated billing. It then tries candidates in
// order. Each candidate gets up to clientAttemptsForCandidate attempts
// (same-client retries). After a candidate is exhausted, the runner failover
// policy decides whether to move to the next one. At most maxPlatforms
// candidates are tried, and retries/failover stop once the failover time
// budget has elapsed or ctx is done.
//
// On success the task is finished, billing is settled, and a "task.completed"
// event is emitted. When all attempts fail, an async task is requeued if ctx
// was cancelled or the last error is a retryable local rate limit; the
// returned error is then a *TaskQueuedError. Otherwise the task is marked
// failed and returned together with the last attempt error (never nil).
// Store or event write failures are returned with a zero Result.
func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) {
	executeStartedAt := time.Now()
	taskSimulated := task.RunMode == "simulation"
	body := normalizeRequest(task.Kind, task.Request)
	modelType := modelTypeFromKind(task.Kind, body)
	if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil {
		return Result{}, err
	}
	// Only announce the transition when the task was not already running
	// (e.g. on a re-run of a requeued task it may already be running).
	if task.Status != "running" {
		if err := s.emit(ctx, task.ID, "task.running", "running", "starting", 0.12, "task pulled from queue and started", map[string]any{"modelType": modelType}, taskSimulated); err != nil {
			return Result{}, err
		}
	}
	if err := validateRequest(task.Kind, body); err != nil {
		failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), taskSimulated, err)
		if finishErr != nil {
			return Result{}, finishErr
		}
		return Result{Task: failed, Output: failed.Result}, err
	}
	candidates, err := s.store.ListModelCandidates(ctx, task.Model, modelType, user)
	if err != nil {
		failed, finishErr := s.failTask(ctx, task.ID, store.ModelCandidateErrorCode(err), err.Error(), taskSimulated, err)
		if finishErr != nil {
			return Result{}, finishErr
		}
		return Result{Task: failed, Output: failed.Result}, err
	}
	if len(candidates) > 0 {
		// Balance is pre-checked against the first (preferred) candidate only.
		estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, body, candidates[0])
		if err := s.ensureWalletBalance(ctx, user, estimatedBillings); err != nil {
			if errors.Is(err, store.ErrInsufficientWalletBalance) {
				failed, finishErr := s.failTask(ctx, task.ID, "insufficient_balance", err.Error(), taskSimulated, err)
				if finishErr != nil {
					return Result{}, finishErr
				}
				return Result{Task: failed, Output: failed.Result}, err
			}
			return Result{}, err
		}
	}
	if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": modelType}, taskSimulated); err != nil {
		return Result{}, err
	}

	runnerPolicy, err := s.store.GetActiveRunnerPolicy(ctx)
	if err != nil {
		return Result{}, err
	}
	maxPlatforms := maxPlatformsForCandidates(candidates, runnerPolicy)
	maxFailoverDuration := maxFailoverDurationForCandidates(candidates, runnerPolicy)
	// Attempt numbers continue from previous runs of this task.
	attemptNo := task.AttemptCount
	var lastErr error

candidateLoop:
	for index, candidate := range candidates {
		if index >= maxPlatforms {
			break
		}
		simulated := isSimulation(task, candidate)
		clientAttempts := clientAttemptsForCandidate(candidate)
		var candidateErr error
		for clientAttempt := 1; clientAttempt <= clientAttempts; clientAttempt++ {
			attemptNo++
			response, err := s.runCandidate(ctx, task, user, body, candidate, attemptNo, onDelta)
			if err == nil {
				billings := s.billings(ctx, user, task.Kind, body, candidate, response, simulated)
				record := buildSuccessRecord(task, user, body, candidate, response, billings, simulated)
				record.Metrics = s.withAttemptHistory(ctx, task.ID, record.Metrics)
				finished, finishErr := s.store.FinishTaskSuccess(ctx, store.FinishTaskSuccessInput{
					TaskID:             task.ID,
					Result:             response.Result,
					Billings:           billings,
					RequestID:          record.RequestID,
					ResolvedModel:      record.ResolvedModel,
					Usage:              record.Usage,
					Metrics:            record.Metrics,
					BillingSummary:     record.BillingSummary,
					FinalChargeAmount:  record.FinalChargeAmount,
					ResponseStartedAt:  record.ResponseStartedAt,
					ResponseFinishedAt: record.ResponseFinishedAt,
					ResponseDurationMS: record.ResponseDurationMS,
				})
				if finishErr != nil {
					return Result{}, finishErr
				}
				if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil {
					return Result{}, settleErr
				}
				if finished.FinalChargeAmount > 0 {
					if err := s.emit(ctx, task.ID, "task.billing.settled", "succeeded", "billing", 0.98, "task billing settled", map[string]any{
						"amount":   finished.FinalChargeAmount,
						"currency": stringFromAny(record.BillingSummary["currency"]),
					}, simulated); err != nil {
						return Result{}, err
					}
				}
				if err := s.emit(ctx, task.ID, "task.completed", "succeeded", "completed", 1, "task completed", map[string]any{
					"result":         response.Result,
					"billings":       billings,
					"usage":          record.Usage,
					"metrics":        record.Metrics,
					"billingSummary": record.BillingSummary,
					"requestId":      record.RequestID,
				}, simulated); err != nil {
					return Result{}, err
				}
				return Result{Task: finished, Output: response.Result}, nil
			}
			lastErr = err
			candidateErr = err
			if ctx.Err() != nil {
				// The caller's context is done. No further attempt can succeed,
				// and retry/failover event writes would fail on the dead ctx and
				// return before the requeue/fail handling below. Stop here so an
				// async task is requeued and a sync task is failed.
				break candidateLoop
			}
			retryDecision := retryDecisionForCandidate(candidate, err)
			retryAction := "retry"
			if !retryDecision.Retry {
				retryAction = "stop"
			}
			if clientAttempt >= clientAttempts {
				retryDecision.Retry = false
				retryDecision.Reason = "same_client_max_attempts"
				retryDecision.Match = policyRuleMatch{
					Source: "model_runtime_policy_sets.retry_policy",
					Policy: "retryPolicy",
					Rule:   "maxAttempts",
					Value:  strconv.Itoa(clientAttempts),
				}
				retryDecision.Info = failureInfoFromError(err)
				retryAction = "stop"
			}
			if failoverTimeBudgetExceeded(executeStartedAt, maxFailoverDuration) {
				retryDecision.Retry = false
				retryDecision.Reason = "failover_time_budget_exceeded"
				retryDecision.Match = policyRuleMatch{
					Source: "gateway_runner_policies.failover_policy",
					Policy: "failoverPolicy",
					Rule:   "maxDurationSeconds",
					Value:  strconv.Itoa(int(maxFailoverDuration.Seconds())),
				}
				retryDecision.Info = failureInfoFromError(err)
				retryAction = "stop"
			}
			s.recordAttemptTrace(ctx, task.ID, attemptNo, retryTraceEntry(retryDecision, retryAction, clientAttempt, clientAttempts))
			if !retryDecision.Retry {
				break
			}
			if err := s.emit(ctx, task.ID, "task.retrying", "running", "retry", 0.45, "retrying same client", addPolicyTracePayload(map[string]any{
				"attempt":       attemptNo,
				"clientAttempt": clientAttempt,
				"clientId":      candidate.ClientID,
				"error":         err.Error(),
				"reason":        retryDecision.Reason,
				"scope":         "same_client",
			}, retryDecision.Match, retryDecision.Info), simulated); err != nil {
				return Result{}, err
			}
		}

		// Last candidate we are allowed to try: record why failover stops here.
		if candidateErr == nil || index+1 >= len(candidates) || index+1 >= maxPlatforms {
			if candidateErr != nil {
				s.applyPriorityDemotePolicy(ctx, task.ID, attemptNo, runnerPolicy, candidate, candidateErr, simulated)
				decision := failoverDecisionForCandidate(runnerPolicy, candidate, candidateErr)
				if decision.Retry {
					decision.Retry = false
					decision.Action = "stop"
					decision.Reason = "no_next_platform"
					decision.Match = policyRuleMatch{Source: "runner_candidates", Policy: "candidateSelection", Rule: "candidateCount", Value: strconv.Itoa(len(candidates))}
					if index+1 >= maxPlatforms {
						decision.Reason = "max_platforms_reached"
						decision.Match = policyRuleMatch{Source: "gateway_runner_policies.failover_policy", Policy: "failoverPolicy", Rule: "maxPlatforms", Value: strconv.Itoa(maxPlatforms)}
					}
				}
				s.recordAttemptTrace(ctx, task.ID, attemptNo, failoverTraceEntry(decision))
			}
			break
		}

		s.applyPriorityDemotePolicy(ctx, task.ID, attemptNo, runnerPolicy, candidate, candidateErr, simulated)
		if failoverTimeBudgetExceeded(executeStartedAt, maxFailoverDuration) {
			elapsedSeconds := int(time.Since(executeStartedAt).Seconds())
			maxDurationSeconds := int(maxFailoverDuration.Seconds())
			s.recordAttemptTrace(ctx, task.ID, attemptNo, failoverTimeBudgetTraceEntry(elapsedSeconds, maxDurationSeconds, failureInfoFromError(candidateErr)))
			if err := s.emit(ctx, task.ID, "task.failover.stopped", "running", "retry", 0.55, "failover time budget exceeded", map[string]any{
				"elapsedSeconds":     elapsedSeconds,
				"maxDurationSeconds": maxDurationSeconds,
				"scope":              "next_platform",
				"statusCode":         clients.ErrorResponseMetadata(candidateErr).StatusCode,
			}, simulated); err != nil {
				return Result{}, err
			}
			break
		}
		decision := failoverDecisionForCandidate(runnerPolicy, candidate, candidateErr)
		s.recordAttemptTrace(ctx, task.ID, attemptNo, failoverTraceEntry(decision))
		if !decision.Retry {
			break
		}
		s.applyFailoverAction(ctx, task.ID, candidate, decision, simulated)
		if err := s.emit(ctx, task.ID, "task.retrying", "running", "retry", 0.55, "retrying next client", addPolicyTracePayload(map[string]any{
			"attempt": attemptNo,
			"action":  decision.Action,
			"error":   candidateErr.Error(),
			"reason":  decision.Reason,
			"scope":   "next_platform",
		}, decision.Match, decision.Info), simulated); err != nil {
			return Result{}, err
		}
	}

	if lastErr == nil {
		// Only reachable when the store returned no candidates without an
		// error: nothing was attempted. Surface a real error so callers never
		// receive a failed task together with a nil error.
		lastErr = errors.New("no model candidates available")
	}

	if task.AsyncMode && ctx.Err() != nil {
		// The worker was interrupted; put the task back on the queue using a
		// context that outlives the cancelled one.
		queued, queueErr := s.requeueInterruptedAsyncTask(context.WithoutCancel(ctx), task)
		if queueErr != nil {
			return Result{}, queueErr
		}
		return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: 0}
	}
	if task.AsyncMode && errors.Is(lastErr, store.ErrRateLimited) && store.RateLimitRetryable(lastErr) {
		queued, delay, queueErr := s.requeueRateLimitedTask(ctx, task, lastErr)
		if queueErr != nil {
			return Result{}, queueErr
		}
		return Result{Task: queued, Output: queued.Result}, &TaskQueuedError{Delay: delay}
	}

	code := clients.ErrorCode(lastErr)
	message := lastErr.Error()
	failed, err := s.failTask(ctx, task.ID, code, message, taskSimulated, lastErr)
	if err != nil {
		return Result{}, err
	}
	return Result{Task: failed, Output: failed.Result}, lastErr
}
|
|
|
|
// runCandidate performs a single attempt of task against one model candidate.
//
// It emits "task.attempt.started", creates the attempt row, reserves rate
// limits (and concurrency leases), records the client assignment, prepares the
// HTTP client and calls the provider client. On success it uploads generated
// assets, replays provider progress events, commits the rate-limit
// reservations and marks the attempt succeeded. On a rate-limit reservation,
// client-assignment, HTTP-client, provider or upload failure, the attempt row
// is marked failed before the error is returned.
//
// Cleanup: uncommitted rate-limit reservations are released, concurrency
// leases are released, and the client assignment (once recorded) is released
// on return. These use a context that survives cancellation of ctx.
func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user *auth.User, body map[string]any, candidate store.RuntimeModelCandidate, attemptNo int, onDelta clients.StreamDelta) (clients.Response, error) {
	simulated := isSimulation(task, candidate)
	if err := s.emit(ctx, task.ID, "task.attempt.started", "running", "submitting", 0.25, "client attempt started", map[string]any{"attempt": attemptNo, "clientId": candidate.ClientID}, simulated); err != nil {
		return clients.Response{}, fmt.Errorf("emit attempt started: %w", err)
	}
	attemptID, err := s.store.CreateTaskAttempt(ctx, store.CreateTaskAttemptInput{
		TaskID:          task.ID,
		AttemptNo:       attemptNo,
		PlatformID:      candidate.PlatformID,
		PlatformModelID: candidate.PlatformModelID,
		ClientID:        candidate.ClientID,
		QueueKey:        candidate.QueueKey,
		Status:          "running",
		Simulated:       simulated,
		RequestSnapshot: body,
		Metrics:         attemptMetrics(candidate, attemptNo, simulated),
	})
	if err != nil {
		return clients.Response{}, fmt.Errorf("create task attempt: %w", err)
	}

	// finishFailedAttempt marks the attempt row failed. It is best effort: the
	// caller is already returning an error, so a store failure here is logged
	// rather than replacing that error. Like the deferred cleanups below, it
	// does not depend on ctx still being live.
	finishFailedAttempt := func(input store.FinishTaskAttemptInput) {
		input.AttemptID = attemptID
		input.Status = "failed"
		if finishErr := s.store.FinishTaskAttempt(context.WithoutCancel(ctx), input); finishErr != nil {
			s.logger.Warn("finish failed task attempt", "taskID", task.ID, "attemptID", attemptID, "error", finishErr)
		}
	}
	// finishSetupFailure records a non-retryable failure that happened before
	// the provider was called.
	finishSetupFailure := func(cause error) {
		finishFailedAttempt(store.FinishTaskAttemptInput{
			Retryable: false,
			Metrics: mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{
				"error":     cause.Error(),
				"retryable": false,
				"trace":     []any{failureTraceEntry(cause, false)},
			}),
			ErrorCode:    clients.ErrorCode(cause),
			ErrorMessage: cause.Error(),
		})
	}

	reservations := s.rateLimitReservations(ctx, user, candidate, body)
	limitResult, err := s.store.ReserveRateLimits(ctx, task.ID, attemptID, reservations)
	if err != nil {
		retryable := store.RateLimitRetryable(err)
		retryAfter := localRateLimitRetryAfter(err)
		clientErr := &clients.ClientError{Code: "rate_limit", Message: err.Error(), Retryable: retryable}
		finishFailedAttempt(store.FinishTaskAttemptInput{
			Retryable: retryable,
			Metrics: mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{
				"error":        err.Error(),
				"retryable":    retryable,
				"retryAfterMs": retryAfter.Milliseconds(),
				"trace":        []any{failureTraceEntry(clientErr, retryable)},
			}),
			ErrorCode:    "rate_limit",
			ErrorMessage: err.Error(),
		})
		return clients.Response{}, &localRateLimitError{clientErr: clientErr, cause: err, retryAfter: retryAfter}
	}
	// Reservations are released on every exit unless they were committed on
	// the success path below.
	rateReservationsFinalized := false
	defer func() {
		if !rateReservationsFinalized {
			_ = s.store.ReleaseRateLimitReservations(context.WithoutCancel(ctx), limitResult.Reservations, "attempt_failed")
		}
	}()
	defer s.store.ReleaseConcurrencyLeases(context.WithoutCancel(ctx), limitResult.LeaseIDs)

	if err := s.store.RecordClientAssignment(ctx, candidate); err != nil {
		finishSetupFailure(err)
		return clients.Response{}, fmt.Errorf("record client assignment: %w", err)
	}
	defer s.store.RecordClientRelease(context.WithoutCancel(ctx), candidate.ClientID, "")

	requestHTTPClient, err := s.httpClientForCandidate(candidate, simulated)
	if err != nil {
		finishSetupFailure(err)
		return clients.Response{}, fmt.Errorf("prepare http client: %w", err)
	}

	client := s.clientFor(candidate, simulated)
	callStartedAt := time.Now()
	response, err := client.Run(ctx, clients.Request{
		Kind:              task.Kind,
		ModelType:         candidate.ModelType,
		Model:             task.Model,
		Body:              body,
		Candidate:         candidate,
		HTTPClient:        requestHTTPClient,
		RemoteTaskID:      task.RemoteTaskID,
		RemoteTaskPayload: task.RemoteTaskPayload,
		OnRemoteTaskSubmitted: func(remoteTaskID string, payload map[string]any) error {
			if strings.TrimSpace(remoteTaskID) == "" {
				return nil
			}
			return s.store.SetTaskRemoteTask(context.WithoutCancel(ctx), task.ID, attemptID, remoteTaskID, payload)
		},
		Stream:      boolFromMap(body, "stream"),
		StreamDelta: onDelta,
	})
	callFinishedAt := time.Now()

	// fillTiming substitutes the locally measured call window for any timing
	// the provider (or its error) did not report, and derives a non-negative
	// duration when none was given.
	fillTiming := func(startedAt, finishedAt time.Time, durationMS int64) (time.Time, time.Time, int64) {
		if startedAt.IsZero() {
			startedAt = callStartedAt
		}
		if finishedAt.IsZero() {
			finishedAt = callFinishedAt
		}
		if durationMS == 0 {
			durationMS = finishedAt.Sub(startedAt).Milliseconds()
			if durationMS < 0 {
				durationMS = 0
			}
		}
		return startedAt, finishedAt, durationMS
	}

	if err != nil {
		retryable := clients.IsRetryable(err)
		requestID, metrics, responseStartedAt, responseFinishedAt, responseDurationMS := failureMetrics(err, simulated)
		responseStartedAt, responseFinishedAt, responseDurationMS = fillTiming(responseStartedAt, responseFinishedAt, responseDurationMS)
		metrics = mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), metrics)
		finishFailedAttempt(store.FinishTaskAttemptInput{
			Retryable:          retryable,
			RequestID:          requestID,
			Metrics:            metrics,
			ResponseStartedAt:  responseStartedAt,
			ResponseFinishedAt: responseFinishedAt,
			ResponseDurationMS: responseDurationMS,
			ErrorCode:          clients.ErrorCode(err),
			ErrorMessage:       err.Error(),
		})
		// The event write result is ignored so the caller sees the provider
		// error, not an event-store error.
		_ = s.emit(ctx, task.ID, "task.attempt.failed", "running", "attempt_failed", 0.45, err.Error(), map[string]any{"attempt": attemptNo, "retryable": retryable, "requestId": requestID, "statusCode": clients.ErrorResponseMetadata(err).StatusCode, "metrics": metrics}, simulated)
		s.applyCandidateFailurePolicies(ctx, task.ID, candidate, err, simulated)
		return clients.Response{}, err
	}
	response.ResponseStartedAt, response.ResponseFinishedAt, response.ResponseDurationMS = fillTiming(response.ResponseStartedAt, response.ResponseFinishedAt, response.ResponseDurationMS)

	uploadedResult, err := s.uploadGeneratedAssets(ctx, response.Result)
	if err != nil {
		retryable := clients.IsRetryable(err)
		finishFailedAttempt(store.FinishTaskAttemptInput{
			Retryable: retryable,
			RequestID: response.RequestID,
			Usage:     usageToMap(response.Usage),
			Metrics: mergeMetrics(taskMetrics(task, user, body, candidate, response, simulated), map[string]any{
				"error":     err.Error(),
				"retryable": retryable,
				"trace":     []any{failureTraceEntry(err, retryable)},
			}),
			ResponseSnapshot:   response.Result,
			ResponseStartedAt:  response.ResponseStartedAt,
			ResponseFinishedAt: response.ResponseFinishedAt,
			ResponseDurationMS: response.ResponseDurationMS,
			ErrorCode:          clients.ErrorCode(err),
			ErrorMessage:       err.Error(),
		})
		return clients.Response{}, err
	}
	response.Result = uploadedResult

	for _, progress := range response.Progress {
		if err := s.emit(ctx, task.ID, "task.progress", "running", progress.Phase, progress.Progress, progress.Message, progress.Payload, simulated); err != nil {
			return clients.Response{}, fmt.Errorf("emit task progress: %w", err)
		}
	}
	if err := s.store.CommitRateLimitReservations(ctx, limitResult.Reservations, tokenUsageAmounts(response.Usage)); err != nil {
		return clients.Response{}, fmt.Errorf("commit rate limit reservations: %w", err)
	}
	rateReservationsFinalized = true
	if err := s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
		AttemptID:          attemptID,
		Status:             "succeeded",
		RequestID:          response.RequestID,
		Usage:              usageToMap(response.Usage),
		Metrics:            taskMetrics(task, user, body, candidate, response, simulated),
		ResponseSnapshot:   response.Result,
		ResponseStartedAt:  response.ResponseStartedAt,
		ResponseFinishedAt: response.ResponseFinishedAt,
		ResponseDurationMS: response.ResponseDurationMS,
	}); err != nil {
		return clients.Response{}, fmt.Errorf("finish task attempt: %w", err)
	}
	return response, nil
}
|
|
|
|
// clientFor picks the client implementation for a candidate.
//
// Simulated runs always use the "simulation" client. Otherwise the key is the
// candidate's SpecType, or its Provider when SpecType is blank, lower-cased
// and trimmed. Unknown keys fall back to the "openai" client.
func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated bool) clients.Client {
	if simulated {
		return s.clients["simulation"]
	}
	normalize := func(value string) string {
		return strings.ToLower(strings.TrimSpace(value))
	}
	key := normalize(candidate.SpecType)
	if key == "" {
		key = normalize(candidate.Provider)
	}
	chosen, found := s.clients[key]
	if !found {
		chosen = s.clients["openai"]
	}
	return chosen
}
|
|
|
|
// failTask marks the task failed with code and message and emits a
// "task.failed" event.
//
// The request ID, timing and metrics come from failureMetrics(cause), and the
// metrics are enriched with the task's attempt history. It returns the failed
// task, or the error from persisting the failure or the event.
func (s *Service) failTask(ctx context.Context, taskID string, code string, message string, simulated bool, cause error) (store.GatewayTask, error) {
	requestID, metrics, startedAt, finishedAt, durationMS := failureMetrics(cause, simulated)
	metrics = s.withAttemptHistory(ctx, taskID, metrics)
	input := store.FinishTaskFailureInput{
		TaskID:             taskID,
		Code:               code,
		Message:            message,
		RequestID:          requestID,
		Metrics:            metrics,
		ResponseStartedAt:  startedAt,
		ResponseFinishedAt: finishedAt,
		ResponseDurationMS: durationMS,
	}
	failed, err := s.store.FinishTaskFailure(ctx, input)
	if err != nil {
		return store.GatewayTask{}, err
	}
	eventPayload := map[string]any{"code": code, "requestId": requestID, "metrics": metrics}
	if err := s.emit(ctx, taskID, "task.failed", "failed", "failed", 1, message, eventPayload, simulated); err != nil {
		return store.GatewayTask{}, err
	}
	return failed, nil
}
|
|
|
|
func (s *Service) requeueRateLimitedTask(ctx context.Context, task store.GatewayTask, cause error) (store.GatewayTask, time.Duration, error) {
|
|
delay := localRateLimitRetryAfter(cause)
|
|
if delay <= 0 {
|
|
delay = 5 * time.Second
|
|
}
|
|
queued, err := s.store.RequeueTask(ctx, task.ID, delay)
|
|
if err != nil {
|
|
return store.GatewayTask{}, 0, err
|
|
}
|
|
payload := map[string]any{
|
|
"code": "rate_limit",
|
|
"message": cause.Error(),
|
|
"retryAfterMs": delay.Milliseconds(),
|
|
}
|
|
if eventErr := s.emit(ctx, task.ID, "task.queued", "queued", "rate_limited", 0.2, "task queued by local rate limit", payload, task.RunMode == "simulation"); eventErr != nil {
|
|
return store.GatewayTask{}, 0, eventErr
|
|
}
|
|
return queued, delay, nil
|
|
}
|
|
|
|
// requeueInterruptedAsyncTask immediately requeues an async task whose worker
// was interrupted and emits a "task.queued" event. The event carries the
// task's remote task ID when one was recorded.
func (s *Service) requeueInterruptedAsyncTask(ctx context.Context, task store.GatewayTask) (store.GatewayTask, error) {
	queued, err := s.store.RequeueTask(ctx, task.ID, 0)
	if err != nil {
		return store.GatewayTask{}, err
	}
	eventPayload := map[string]any{"code": "worker_interrupted"}
	if remoteID := task.RemoteTaskID; remoteID != "" {
		eventPayload["remoteTaskId"] = remoteID
	}
	simulated := task.RunMode == "simulation"
	if err := s.emit(ctx, task.ID, "task.queued", "queued", "worker_interrupted", 0.2, "async task queued after worker interruption", eventPayload, simulated); err != nil {
		return store.GatewayTask{}, err
	}
	return queued, nil
}
|
|
|
|
// withAttemptHistory adds "attemptCount" and "attempts" (a summary of the
// task's attempts) to metrics.
//
// It returns metrics unchanged when listing attempts fails (logged as a
// warning) or when there are none. Otherwise the keys are written into the
// map returned by mergeMetrics(metrics). NOTE(review): presumably a copy, so
// the caller's map is untouched — confirm.
func (s *Service) withAttemptHistory(ctx context.Context, taskID string, metrics map[string]any) map[string]any {
	attempts, err := s.store.ListTaskAttempts(ctx, taskID)
	switch {
	case err != nil:
		s.logger.Warn("list task attempts for metrics failed", "taskID", taskID, "error", err)
		return metrics
	case len(attempts) == 0:
		return metrics
	}
	enriched := mergeMetrics(metrics)
	enriched["attemptCount"] = len(attempts)
	enriched["attempts"] = summarizeAttempts(attempts)
	return enriched
}
|
|
|
|
// emit stores a task event. When task progress callbacks are enabled in the
// config, it also queues a callback for that event to
// cfg.TaskProgressCallbackURL. Errors from either store call are returned.
func (s *Service) emit(ctx context.Context, taskID string, eventType string, status string, phase string, progress float64, message string, payload map[string]any, simulated bool) error {
	event, err := s.store.AddTaskEvent(ctx, taskID, eventType, status, phase, progress, message, payload, simulated)
	if err != nil {
		return err
	}
	if !s.cfg.TaskProgressCallbackEnabled {
		return nil
	}
	return s.store.QueueTaskCallback(ctx, event, s.cfg.TaskProgressCallbackURL)
}
|
|
|
|
func modelTypeFromKind(kind string, body map[string]any) string {
|
|
switch kind {
|
|
case "chat.completions", "responses":
|
|
return "text_generate"
|
|
case "images.generations", "images.edits":
|
|
if kind == "images.edits" {
|
|
return "image_edit"
|
|
}
|
|
return "image_generate"
|
|
case "videos.generations":
|
|
if videoRequestHasReferenceImage(body) {
|
|
return "image_to_video"
|
|
}
|
|
return "video_generate"
|
|
default:
|
|
return "task"
|
|
}
|
|
}
|
|
|
|
func videoRequestHasReferenceImage(body map[string]any) bool {
|
|
if body == nil {
|
|
return false
|
|
}
|
|
for _, key := range []string{
|
|
"image", "images", "image_url", "imageUrl", "image_urls", "imageUrls",
|
|
"reference_image", "referenceImage", "first_frame", "firstFrame", "last_frame", "lastFrame",
|
|
} {
|
|
if hasAnyString(body, key) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// isTextGenerationKind reports whether kind is one of the text-generation
// surfaces ("chat.completions" or "responses").
func isTextGenerationKind(kind string) bool {
	switch kind {
	case "chat.completions", "responses":
		return true
	default:
		return false
	}
}
|
|
|
|
func isSimulation(task store.GatewayTask, candidate store.RuntimeModelCandidate) bool {
|
|
if task.RunMode == "simulation" {
|
|
return true
|
|
}
|
|
return stringFromMap(candidate.Credentials, "mode") == "simulation" || boolFromMap(candidate.PlatformConfig, "testMode")
|
|
}
|
|
|
|
func retryEnabled(candidate store.RuntimeModelCandidate) bool {
|
|
policy := effectiveRetryPolicy(candidate)
|
|
if enabled, ok := policy["enabled"].(bool); ok {
|
|
return enabled
|
|
}
|
|
return true
|
|
}
|
|
|
|
func clientAttemptsForCandidate(candidate store.RuntimeModelCandidate) int {
|
|
if !retryEnabled(candidate) {
|
|
return 1
|
|
}
|
|
if value := intFromPolicy(effectiveRetryPolicy(candidate), "maxAttempts"); value > 0 {
|
|
return value
|
|
}
|
|
return 1
|
|
}
|
|
|
|
func maxPlatformsForCandidates(candidates []store.RuntimeModelCandidate, runnerPolicy store.RunnerPolicy) int {
|
|
if len(candidates) == 0 {
|
|
return 0
|
|
}
|
|
maxPlatforms := len(candidates)
|
|
if value := intFromPolicy(runnerPolicy.FailoverPolicy, "maxPlatforms"); value > 0 {
|
|
if value < maxPlatforms {
|
|
maxPlatforms = value
|
|
}
|
|
} else if maxPlatforms > 99 {
|
|
maxPlatforms = 99
|
|
}
|
|
for _, candidate := range candidates {
|
|
if value := intFromPolicy(effectiveFailoverPolicy(runnerPolicy.FailoverPolicy, candidate.RuntimePolicyOverride), "maxPlatforms"); value > 0 && value < maxPlatforms {
|
|
maxPlatforms = value
|
|
}
|
|
}
|
|
if maxPlatforms <= 0 {
|
|
return 1
|
|
}
|
|
return maxPlatforms
|
|
}
|
|
|
|
func maxFailoverDurationForCandidates(candidates []store.RuntimeModelCandidate, runnerPolicy store.RunnerPolicy) time.Duration {
|
|
seconds := intFromPolicy(runnerPolicy.FailoverPolicy, "maxDurationSeconds")
|
|
if seconds <= 0 {
|
|
seconds = 600
|
|
}
|
|
for _, candidate := range candidates {
|
|
if value := intFromPolicy(effectiveFailoverPolicy(runnerPolicy.FailoverPolicy, candidate.RuntimePolicyOverride), "maxDurationSeconds"); value > 0 && value < seconds {
|
|
seconds = value
|
|
}
|
|
}
|
|
return time.Duration(seconds) * time.Second
|
|
}
|
|
|
|
// failoverTimeBudgetExceeded reports whether at least maxDuration has elapsed
// since start. A non-positive maxDuration means "no budget" and never expires.
func failoverTimeBudgetExceeded(start time.Time, maxDuration time.Duration) bool {
	if maxDuration <= 0 {
		return false
	}
	return time.Since(start) >= maxDuration
}
|
|
|
|
// normalizeRequest returns a shallow copy of body (never nil, body is not
// modified).
//
// For the "responses" kind, when "messages" is absent/nil but "input" is set,
// the copy gains "messages" = a single user message whose content is the
// input value.
func normalizeRequest(kind string, body map[string]any) map[string]any {
	normalized := make(map[string]any, len(body))
	for key, value := range body {
		normalized[key] = value
	}
	needsMessages := kind == "responses" && body["messages"] == nil && body["input"] != nil
	if needsMessages {
		normalized["messages"] = []any{map[string]any{"role": "user", "content": body["input"]}}
	}
	return normalized
}
|
|
|
|
func validateRequest(kind string, body map[string]any) error {
|
|
switch kind {
|
|
case "chat.completions":
|
|
if body["messages"] == nil {
|
|
return errors.New("messages is required")
|
|
}
|
|
case "responses":
|
|
if body["input"] == nil && body["messages"] == nil {
|
|
return errors.New("input or messages is required")
|
|
}
|
|
case "images.generations", "images.edits":
|
|
if strings.TrimSpace(stringFromMap(body, "prompt")) == "" {
|
|
return errors.New("prompt is required")
|
|
}
|
|
}
|
|
return nil
|
|
}
|