package runner

import (
	"context"
	"errors"
	"log/slog"
	"strconv"
	"strings"
	"time"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
	"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
	"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
	"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)

// Service executes gateway tasks against provider clients, applying the
// retry, failover, rate-limit and billing steps stored alongside each task.
type Service struct {
	cfg         config.Config
	store       *store.Store
	logger      *slog.Logger
	clients     map[string]clients.Client // keyed by lower-case spec type / provider
	httpClients *httpClientCache
}

// Result is the outcome of running a task: the persisted task row and the
// output payload (the provider result on success, the task's stored result
// on failure).
type Result struct {
	Task   store.GatewayTask
	Output map[string]any
}

// New builds a Service with the built-in clients "openai", "gemini",
// "volces" and "simulation". The HTTP-backed clients are constructed with the
// cache's `none` HTTP client; per-candidate HTTP clients are resolved per
// attempt by httpClientForCandidate.
func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service {
	httpClients := newHTTPClientCache()
	return &Service{
		cfg:    cfg,
		store:  db,
		logger: logger,
		clients: map[string]clients.Client{
			"openai":     clients.OpenAIClient{HTTPClient: httpClients.none},
			"gemini":     clients.GeminiClient{HTTPClient: httpClients.none},
			"volces":     clients.VolcesClient{HTTPClient: httpClients.none},
			"simulation": clients.SimulationClient{},
		},
		httpClients: httpClients,
	}
}

// Execute runs task to completion without streaming partial output.
func (s *Service) Execute(ctx context.Context, task store.GatewayTask, user *auth.User) (Result, error) {
	return s.execute(ctx, task, user, nil)
}

// ExecuteStream runs task like Execute, but hands onDelta to the provider
// client as its StreamDelta callback.
func (s *Service) ExecuteStream(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) {
	return s.execute(ctx, task, user, onDelta)
}

// execute drives a task through request validation, candidate lookup and a
// wallet-balance pre-check, and then through the retry/failover loop:
//
//   - candidates are tried in order, at most maxPlatforms of them;
//   - each candidate is tried up to clientAttemptsForCandidate times
//     ("same_client" retries), subject to retryDecisionForCandidate;
//   - once a candidate is exhausted, failoverDecisionForCandidate decides
//     whether to move on to the next candidate ("next_platform");
//   - both loops stop once the failover time budget has been spent.
//
// The first successful attempt is persisted, billed and returned. If every
// attempt fails, the task is marked failed and the last attempt error is
// returned alongside the failed task. Store/event errors abort immediately
// with an empty Result. onDelta may be nil for non-streaming execution.
func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) {
	executeStartedAt := time.Now()
	taskSimulated := task.RunMode == "simulation"

	// rejectTask marks the task failed before any candidate was attempted and
	// returns cause so the caller sees a non-nil error.
	rejectTask := func(code string, cause error) (Result, error) {
		failed, finishErr := s.failTask(ctx, task.ID, code, cause.Error(), taskSimulated, cause)
		if finishErr != nil {
			return Result{}, finishErr
		}
		return Result{Task: failed, Output: failed.Result}, cause
	}

	body := normalizeRequest(task.Kind, task.Request)
	modelType := modelTypeFromKind(task.Kind, body)
	if err := validateRequest(task.Kind, body); err != nil {
		return rejectTask("bad_request", err)
	}

	candidates, err := s.store.ListModelCandidates(ctx, task.Model, modelType, user)
	if err != nil {
		return rejectTask(store.ModelCandidateErrorCode(err), err)
	}

	// The balance pre-check uses the estimate for the first candidate only.
	if len(candidates) > 0 {
		estimatedBillings := s.estimatedBillings(ctx, user, task.Kind, body, candidates[0])
		if err := s.ensureWalletBalance(ctx, user, estimatedBillings); err != nil {
			if errors.Is(err, store.ErrInsufficientWalletBalance) {
				return rejectTask("insufficient_balance", err)
			}
			return Result{}, err
		}
	}

	if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil {
		return Result{}, err
	}
	if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": modelType}, taskSimulated); err != nil {
		return Result{}, err
	}

	runnerPolicy, err := s.store.GetActiveRunnerPolicy(ctx)
	if err != nil {
		return Result{}, err
	}
	maxPlatforms := maxPlatformsForCandidates(candidates, runnerPolicy)
	maxFailoverDuration := maxFailoverDurationForCandidates(candidates, runnerPolicy)

	attemptNo := 0 // global attempt counter across all candidates
	var lastErr error
	for index, candidate := range candidates {
		if index >= maxPlatforms {
			break
		}
		simulated := isSimulation(task, candidate)
		clientAttempts := clientAttemptsForCandidate(candidate)

		var candidateErr error
		for clientAttempt := 1; clientAttempt <= clientAttempts; clientAttempt++ {
			attemptNo++
			response, err := s.runCandidate(ctx, task, user, body, candidate, attemptNo, onDelta)
			if err == nil {
				billings := s.billings(ctx, user, task.Kind, body, candidate, response, simulated)
				record := buildSuccessRecord(task, user, body, candidate, response, billings, simulated)
				record.Metrics = s.withAttemptHistory(ctx, task.ID, record.Metrics)
				finished, finishErr := s.store.FinishTaskSuccess(ctx, store.FinishTaskSuccessInput{
					TaskID:             task.ID,
					Result:             response.Result,
					Billings:           billings,
					RequestID:          record.RequestID,
					ResolvedModel:      record.ResolvedModel,
					Usage:              record.Usage,
					Metrics:            record.Metrics,
					BillingSummary:     record.BillingSummary,
					FinalChargeAmount:  record.FinalChargeAmount,
					ResponseStartedAt:  record.ResponseStartedAt,
					ResponseFinishedAt: record.ResponseFinishedAt,
					ResponseDurationMS: record.ResponseDurationMS,
				})
				if finishErr != nil {
					return Result{}, finishErr
				}
				if settleErr := s.store.SettleTaskBilling(ctx, finished); settleErr != nil {
					return Result{}, settleErr
				}
				if finished.FinalChargeAmount > 0 {
					if emitErr := s.emit(ctx, task.ID, "task.billing.settled", "succeeded", "billing", 0.98, "task billing settled", map[string]any{
						"amount":   finished.FinalChargeAmount,
						"currency": stringFromAny(record.BillingSummary["currency"]),
					}, simulated); emitErr != nil {
						return Result{}, emitErr
					}
				}
				if emitErr := s.emit(ctx, task.ID, "task.completed", "succeeded", "completed", 1, "task completed", map[string]any{
					"result":         response.Result,
					"billings":       billings,
					"usage":          record.Usage,
					"metrics":        record.Metrics,
					"billingSummary": record.BillingSummary,
					"requestId":      record.RequestID,
				}, simulated); emitErr != nil {
					return Result{}, emitErr
				}
				return Result{Task: finished, Output: response.Result}, nil
			}

			lastErr = err
			candidateErr = err
			retryDecision := retryDecisionForCandidate(candidate, err)
			retryAction := "retry"
			if !retryDecision.Retry {
				retryAction = "stop"
			}
			// Exhausting the per-client attempt budget overrides the policy decision.
			if clientAttempt >= clientAttempts {
				retryDecision.Retry = false
				retryDecision.Reason = "same_client_max_attempts"
				retryDecision.Match = policyRuleMatch{
					Source: "model_runtime_policy_sets.retry_policy",
					Policy: "retryPolicy",
					Rule:   "maxAttempts",
					Value:  strconv.Itoa(clientAttempts),
				}
				retryDecision.Info = failureInfoFromError(err)
				retryAction = "stop"
			}
			// The overall time budget overrides everything else.
			if failoverTimeBudgetExceeded(executeStartedAt, maxFailoverDuration) {
				retryDecision.Retry = false
				retryDecision.Reason = "failover_time_budget_exceeded"
				retryDecision.Match = policyRuleMatch{
					Source: "gateway_runner_policies.failover_policy",
					Policy: "failoverPolicy",
					Rule:   "maxDurationSeconds",
					Value:  strconv.Itoa(int(maxFailoverDuration.Seconds())),
				}
				retryDecision.Info = failureInfoFromError(err)
				retryAction = "stop"
			}
			s.recordAttemptTrace(ctx, task.ID, attemptNo, retryTraceEntry(retryDecision, retryAction, clientAttempt, clientAttempts))
			if !retryDecision.Retry {
				break
			}
			if emitErr := s.emit(ctx, task.ID, "task.retrying", "running", "retry", 0.45, "retrying same client", addPolicyTracePayload(map[string]any{
				"attempt":       attemptNo,
				"clientAttempt": clientAttempt,
				"clientId":      candidate.ClientID,
				"error":         candidateErr.Error(),
				"reason":        retryDecision.Reason,
				"scope":         "same_client",
			}, retryDecision.Match, retryDecision.Info), simulated); emitErr != nil {
				return Result{}, emitErr
			}
		}

		// No further candidate may be tried: record why failover stops here.
		if candidateErr == nil || index+1 >= len(candidates) || index+1 >= maxPlatforms {
			if candidateErr != nil {
				s.applyPriorityDemotePolicy(ctx, task.ID, attemptNo, runnerPolicy, candidate, candidateErr, simulated)
				decision := failoverDecisionForCandidate(runnerPolicy, candidate, candidateErr)
				if decision.Retry {
					decision.Retry = false
					decision.Action = "stop"
					decision.Reason = "no_next_platform"
					decision.Match = policyRuleMatch{Source: "runner_candidates", Policy: "candidateSelection", Rule: "candidateCount", Value: strconv.Itoa(len(candidates))}
					if index+1 >= maxPlatforms {
						decision.Reason = "max_platforms_reached"
						decision.Match = policyRuleMatch{Source: "gateway_runner_policies.failover_policy", Policy: "failoverPolicy", Rule: "maxPlatforms", Value: strconv.Itoa(maxPlatforms)}
					}
				}
				s.recordAttemptTrace(ctx, task.ID, attemptNo, failoverTraceEntry(decision))
			}
			break
		}

		s.applyPriorityDemotePolicy(ctx, task.ID, attemptNo, runnerPolicy, candidate, candidateErr, simulated)
		if failoverTimeBudgetExceeded(executeStartedAt, maxFailoverDuration) {
			elapsedSeconds := int(time.Since(executeStartedAt).Seconds())
			maxDurationSeconds := int(maxFailoverDuration.Seconds())
			s.recordAttemptTrace(ctx, task.ID, attemptNo, failoverTimeBudgetTraceEntry(elapsedSeconds, maxDurationSeconds, failureInfoFromError(candidateErr)))
			if emitErr := s.emit(ctx, task.ID, "task.failover.stopped", "running", "retry", 0.55, "failover time budget exceeded", map[string]any{
				"elapsedSeconds":     elapsedSeconds,
				"maxDurationSeconds": maxDurationSeconds,
				"scope":              "next_platform",
				"statusCode":         clients.ErrorResponseMetadata(candidateErr).StatusCode,
			}, simulated); emitErr != nil {
				return Result{}, emitErr
			}
			break
		}

		decision := failoverDecisionForCandidate(runnerPolicy, candidate, candidateErr)
		s.recordAttemptTrace(ctx, task.ID, attemptNo, failoverTraceEntry(decision))
		if !decision.Retry {
			break
		}
		s.applyFailoverAction(ctx, task.ID, candidate, decision, simulated)
		if emitErr := s.emit(ctx, task.ID, "task.retrying", "running", "retry", 0.55, "retrying next client", addPolicyTracePayload(map[string]any{
			"attempt": attemptNo,
			"action":  decision.Action,
			"error":   candidateErr.Error(),
			"reason":  decision.Reason,
			"scope":   "next_platform",
		}, decision.Match, decision.Info), simulated); emitErr != nil {
			return Result{}, emitErr
		}
	}

	code := clients.ErrorCode(lastErr)
	message := "task failed"
	if lastErr != nil {
		message = lastErr.Error()
	}
	failed, err := s.failTask(ctx, task.ID, code, message, taskSimulated, lastErr)
	if err != nil {
		return Result{}, err
	}
	if lastErr == nil {
		// No candidate was attempted (e.g. an empty candidate list). The task
		// is already recorded as failed above; report that to the caller as
		// well instead of pairing a failed task with a nil error.
		lastErr = errors.New("task failed: no model candidate was attempted")
	}
	return Result{Task: failed, Output: failed.Result}, lastErr
}

// runCandidate performs one client attempt against candidate. It records
// the attempt, reserves rate limits and concurrency leases, calls the
// provider client, uploads generated assets and finalizes the attempt
// record. A non-nil error means the attempt failed; execute decides whether
// to retry or fail over. Rate-limit reservations are released on any failure
// and committed only on success; leases and the client assignment are
// released when the attempt ends.
func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user *auth.User, body map[string]any, candidate store.RuntimeModelCandidate, attemptNo int, onDelta clients.StreamDelta) (clients.Response, error) {
	simulated := isSimulation(task, candidate)
	if err := s.emit(ctx, task.ID, "task.attempt.started", "running", "submitting", 0.25, "client attempt started", map[string]any{"attempt": attemptNo, "clientId": candidate.ClientID}, simulated); err != nil {
		return clients.Response{}, err
	}
	attemptID, err := s.store.CreateTaskAttempt(ctx, store.CreateTaskAttemptInput{
		TaskID:          task.ID,
		AttemptNo:       attemptNo,
		PlatformID:      candidate.PlatformID,
		PlatformModelID: candidate.PlatformModelID,
		ClientID:        candidate.ClientID,
		QueueKey:        candidate.QueueKey,
		Status:          "running",
		Simulated:       simulated,
		RequestSnapshot: body,
		Metrics:         attemptMetrics(candidate, attemptNo, simulated),
	})
	if err != nil {
		return clients.Response{}, err
	}

	// failAttempt marks the attempt failed for gateway-side errors (anything
	// other than the provider call itself) so the record is never left in
	// "running". Best-effort: a store error here must not mask cause.
	failAttempt := func(cause error) {
		_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
			AttemptID: attemptID,
			Status:    "failed",
			Retryable: false,
			Metrics: mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{
				"error":     cause.Error(),
				"retryable": false,
				"trace":     []any{failureTraceEntry(cause, false)},
			}),
			ErrorCode:    clients.ErrorCode(cause),
			ErrorMessage: cause.Error(),
		})
	}

	reservations := s.rateLimitReservations(ctx, user, candidate, body)
	limitResult, err := s.store.ReserveRateLimits(ctx, task.ID, attemptID, reservations)
	if err != nil {
		clientErr := &clients.ClientError{Code: "rate_limit", Message: err.Error(), Retryable: false}
		_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
			AttemptID: attemptID,
			Status:    "failed",
			Retryable: false,
			Metrics: mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), map[string]any{
				"error":     err.Error(),
				"retryable": false,
				"trace":     []any{failureTraceEntry(clientErr, false)},
			}),
			ErrorCode:    "rate_limit",
			ErrorMessage: err.Error(),
		})
		return clients.Response{}, clientErr
	}

	// Cleanup uses WithoutCancel so it still runs after the request context
	// has been cancelled.
	rateReservationsFinalized := false
	defer func() {
		if !rateReservationsFinalized {
			_ = s.store.ReleaseRateLimitReservations(context.WithoutCancel(ctx), limitResult.Reservations, "attempt_failed")
		}
	}()
	defer s.store.ReleaseConcurrencyLeases(context.WithoutCancel(ctx), limitResult.LeaseIDs)

	if err := s.store.RecordClientAssignment(ctx, candidate); err != nil {
		failAttempt(err)
		return clients.Response{}, err
	}
	defer s.store.RecordClientRelease(context.WithoutCancel(ctx), candidate.ClientID, "")

	requestHTTPClient, err := s.httpClientForCandidate(candidate, simulated)
	if err != nil {
		failAttempt(err)
		return clients.Response{}, err
	}

	client := s.clientFor(candidate, simulated)
	callStartedAt := time.Now()
	response, err := client.Run(ctx, clients.Request{
		Kind:        task.Kind,
		ModelType:   candidate.ModelType,
		Model:       task.Model,
		Body:        body,
		Candidate:   candidate,
		HTTPClient:  requestHTTPClient,
		Stream:      boolFromMap(body, "stream"),
		StreamDelta: onDelta,
	})
	callFinishedAt := time.Now()

	// Fill in any timing the client did not report with the wall-clock
	// bounds of the call; durations are clamped at zero.
	if response.ResponseStartedAt.IsZero() {
		response.ResponseStartedAt = callStartedAt
	}
	if response.ResponseFinishedAt.IsZero() {
		response.ResponseFinishedAt = callFinishedAt
	}
	if response.ResponseDurationMS == 0 {
		response.ResponseDurationMS = max(response.ResponseFinishedAt.Sub(response.ResponseStartedAt).Milliseconds(), 0)
	}

	if err != nil {
		retryable := clients.IsRetryable(err)
		requestID, metrics, responseStartedAt, responseFinishedAt, responseDurationMS := failureMetrics(err, simulated)
		if responseStartedAt.IsZero() {
			responseStartedAt = callStartedAt
		}
		if responseFinishedAt.IsZero() {
			responseFinishedAt = callFinishedAt
		}
		if responseDurationMS == 0 {
			responseDurationMS = max(responseFinishedAt.Sub(responseStartedAt).Milliseconds(), 0)
		}
		metrics = mergeMetrics(attemptMetrics(candidate, attemptNo, simulated), metrics)
		_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
			AttemptID:          attemptID,
			Status:             "failed",
			Retryable:          retryable,
			RequestID:          requestID,
			Metrics:            metrics,
			ResponseStartedAt:  responseStartedAt,
			ResponseFinishedAt: responseFinishedAt,
			ResponseDurationMS: responseDurationMS,
			ErrorCode:          clients.ErrorCode(err),
			ErrorMessage:       err.Error(),
		})
		// Best-effort notification; the attempt error is what matters to the caller.
		_ = s.emit(ctx, task.ID, "task.attempt.failed", "running", "attempt_failed", 0.45, err.Error(), map[string]any{
			"attempt":    attemptNo,
			"retryable":  retryable,
			"requestId":  requestID,
			"statusCode": clients.ErrorResponseMetadata(err).StatusCode,
			"metrics":    metrics,
		}, simulated)
		s.applyCandidateFailurePolicies(ctx, task.ID, candidate, err, simulated)
		return clients.Response{}, err
	}

	uploadedResult, err := s.uploadGeneratedAssets(ctx, response.Result)
	if err != nil {
		retryable := clients.IsRetryable(err)
		metrics := mergeMetrics(taskMetrics(task, user, body, candidate, response, simulated), map[string]any{
			"error":     err.Error(),
			"retryable": retryable,
			"trace":     []any{failureTraceEntry(err, retryable)},
		})
		_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
			AttemptID:          attemptID,
			Status:             "failed",
			Retryable:          retryable,
			RequestID:          response.RequestID,
			Usage:              usageToMap(response.Usage),
			Metrics:            metrics,
			ResponseSnapshot:   response.Result,
			ResponseStartedAt:  response.ResponseStartedAt,
			ResponseFinishedAt: response.ResponseFinishedAt,
			ResponseDurationMS: response.ResponseDurationMS,
			ErrorCode:          clients.ErrorCode(err),
			ErrorMessage:       err.Error(),
		})
		return clients.Response{}, err
	}
	response.Result = uploadedResult

	for _, progress := range response.Progress {
		if err := s.emit(ctx, task.ID, "task.progress", "running", progress.Phase, progress.Progress, progress.Message, progress.Payload, simulated); err != nil {
			failAttempt(err)
			return clients.Response{}, err
		}
	}
	if err := s.store.CommitRateLimitReservations(ctx, limitResult.Reservations, tokenUsageAmounts(response.Usage)); err != nil {
		failAttempt(err)
		return clients.Response{}, err
	}
	rateReservationsFinalized = true

	if err := s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
		AttemptID:          attemptID,
		Status:             "succeeded",
		RequestID:          response.RequestID,
		Usage:              usageToMap(response.Usage),
		Metrics:            taskMetrics(task, user, body, candidate, response, simulated),
		ResponseSnapshot:   response.Result,
		ResponseStartedAt:  response.ResponseStartedAt,
		ResponseFinishedAt: response.ResponseFinishedAt,
		ResponseDurationMS: response.ResponseDurationMS,
	}); err != nil {
		return clients.Response{}, err
	}
	return response, nil
}

// clientFor picks the provider client for candidate. It returns the
// simulation client when simulated; otherwise it returns the client keyed by
// the trimmed, lower-cased SpecType, or by Provider when SpecType is empty.
// When no client matches, it falls back to the "openai" client.
func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated bool) clients.Client {
	if simulated {
		return s.clients["simulation"]
	}
	key := strings.ToLower(strings.TrimSpace(candidate.SpecType))
	if key == "" {
		key = strings.ToLower(strings.TrimSpace(candidate.Provider))
	}
	if client, ok := s.clients[key]; ok {
		return client
	}
	return s.clients["openai"]
}

// failTask persists the task as failed with code and message, attaching
// failure metrics derived from cause plus the attempt history, and emits a
// "task.failed" event. cause may be nil.
func (s *Service) failTask(ctx context.Context, taskID string, code string, message string, simulated bool, cause error) (store.GatewayTask, error) {
	requestID, metrics, responseStartedAt, responseFinishedAt, responseDurationMS := failureMetrics(cause, simulated)
	metrics = s.withAttemptHistory(ctx, taskID, metrics)
	failed, err := s.store.FinishTaskFailure(ctx, store.FinishTaskFailureInput{
		TaskID:             taskID,
		Code:               code,
		Message:            message,
		RequestID:          requestID,
		Metrics:            metrics,
		ResponseStartedAt:  responseStartedAt,
		ResponseFinishedAt: responseFinishedAt,
		ResponseDurationMS: responseDurationMS,
	})
	if err != nil {
		return store.GatewayTask{}, err
	}
	if eventErr := s.emit(ctx, taskID, "task.failed", "failed", "failed", 1, message, map[string]any{"code": code, "requestId": requestID, "metrics": metrics}, simulated); eventErr != nil {
		return store.GatewayTask{}, eventErr
	}
	return failed, nil
}

// withAttemptHistory returns metrics extended with "attemptCount" and an
// "attempts" summary for the task. If the attempts cannot be listed it logs
// a warning and returns metrics unchanged; with no attempts it also returns
// metrics unchanged.
func (s *Service) withAttemptHistory(ctx context.Context, taskID string, metrics map[string]any) map[string]any {
	attempts, err := s.store.ListTaskAttempts(ctx, taskID)
	if err != nil {
		s.logger.Warn("list task attempts for metrics failed", "taskID", taskID, "error", err)
		return metrics
	}
	if len(attempts) == 0 {
		return metrics
	}
	// NOTE(review): presumably mergeMetrics returns a fresh map, so the
	// caller's map is not mutated below — confirm.
	metrics = mergeMetrics(metrics)
	metrics["attemptCount"] = len(attempts)
	metrics["attempts"] = summarizeAttempts(attempts)
	return metrics
}

// emit stores a task event and, when TaskProgressCallbackEnabled is set,
// queues a callback for it to TaskProgressCallbackURL. Errors from either
// step are returned.
func (s *Service) emit(ctx context.Context, taskID string, eventType string, status string, phase string, progress float64, message string, payload map[string]any, simulated bool) error {
	event, err := s.store.AddTaskEvent(ctx, taskID, eventType, status, phase, progress, message, payload, simulated)
	if err != nil {
		return err
	}
	if s.cfg.TaskProgressCallbackEnabled {
		return s.store.QueueTaskCallback(ctx, event, s.cfg.TaskProgressCallbackURL)
	}
	return nil
}

// modelTypeFromKind maps a task kind to the model type used for candidate
// lookup. Video generation becomes "image_to_video" when the body carries a
// reference image; unknown kinds map to "task".
func modelTypeFromKind(kind string, body map[string]any) string {
	switch kind {
	case "chat.completions", "responses":
		return "text_generate"
	case "images.generations":
		return "image_generate"
	case "images.edits":
		return "image_edit"
	case "videos.generations":
		if videoRequestHasReferenceImage(body) {
			return "image_to_video"
		}
		return "video_generate"
	default:
		return "task"
	}
}

// videoRequestHasReferenceImage reports whether body carries any of the
// recognized reference-image fields (snake_case or camelCase), as judged by
// hasAnyString.
func videoRequestHasReferenceImage(body map[string]any) bool {
	if body == nil {
		return false
	}
	for _, key := range []string{
		"image", "images",
		"image_url", "imageUrl",
		"image_urls", "imageUrls",
		"reference_image", "referenceImage",
		"first_frame", "firstFrame",
		"last_frame", "lastFrame",
	} {
		if hasAnyString(body, key) {
			return true
		}
	}
	return false
}

// isTextGenerationKind reports whether kind is a text-generation endpoint.
func isTextGenerationKind(kind string) bool {
	return kind == "chat.completions" || kind == "responses"
}

// isSimulation reports whether an attempt should use the simulation client:
// the task runs in simulation mode, the candidate's credentials set
// mode=simulation, or its platform config enables testMode.
func isSimulation(task store.GatewayTask, candidate store.RuntimeModelCandidate) bool {
	if task.RunMode == "simulation" {
		return true
	}
	return stringFromMap(candidate.Credentials, "mode") == "simulation" || boolFromMap(candidate.PlatformConfig, "testMode")
}

// retryEnabled reports whether same-client retries are enabled for
// candidate. It defaults to true unless the effective retry policy has a
// boolean "enabled" key.
func retryEnabled(candidate store.RuntimeModelCandidate) bool {
	policy := effectiveRetryPolicy(candidate)
	if enabled, ok := policy["enabled"].(bool); ok {
		return enabled
	}
	return true
}

// clientAttemptsForCandidate returns how many times one candidate may be
// tried: 1 when retries are disabled, otherwise the retry policy's positive
// maxAttempts (default 1).
func clientAttemptsForCandidate(candidate store.RuntimeModelCandidate) int {
	if !retryEnabled(candidate) {
		return 1
	}
	if value := intFromPolicy(effectiveRetryPolicy(candidate), "maxAttempts"); value > 0 {
		return value
	}
	return 1
}

// maxPlatformsForCandidates returns how many candidates execute may try.
// It returns 0 for no candidates. Otherwise it starts from the candidate
// count, capped by the runner policy's maxPlatforms, or by 99 when the
// policy sets none. Each candidate's effective failover policy may only
// lower the limit further. The result is at least 1.
func maxPlatformsForCandidates(candidates []store.RuntimeModelCandidate, runnerPolicy store.RunnerPolicy) int {
	if len(candidates) == 0 {
		return 0
	}
	maxPlatforms := len(candidates)
	if value := intFromPolicy(runnerPolicy.FailoverPolicy, "maxPlatforms"); value > 0 {
		maxPlatforms = min(maxPlatforms, value)
	} else {
		maxPlatforms = min(maxPlatforms, 99)
	}
	for _, candidate := range candidates {
		if value := intFromPolicy(effectiveFailoverPolicy(runnerPolicy.FailoverPolicy, candidate.RuntimePolicyOverride), "maxPlatforms"); value > 0 {
			maxPlatforms = min(maxPlatforms, value)
		}
	}
	if maxPlatforms <= 0 {
		return 1
	}
	return maxPlatforms
}

// maxFailoverDurationForCandidates returns the total time budget for retries
// and failover: the runner policy's maxDurationSeconds (default 600s),
// lowered by any smaller positive per-candidate override.
func maxFailoverDurationForCandidates(candidates []store.RuntimeModelCandidate, runnerPolicy store.RunnerPolicy) time.Duration {
	seconds := intFromPolicy(runnerPolicy.FailoverPolicy, "maxDurationSeconds")
	if seconds <= 0 {
		seconds = 600
	}
	for _, candidate := range candidates {
		if value := intFromPolicy(effectiveFailoverPolicy(runnerPolicy.FailoverPolicy, candidate.RuntimePolicyOverride), "maxDurationSeconds"); value > 0 {
			seconds = min(seconds, value)
		}
	}
	return time.Duration(seconds) * time.Second
}

// failoverTimeBudgetExceeded reports whether at least maxDuration has
// elapsed since start. A non-positive maxDuration disables the budget.
func failoverTimeBudgetExceeded(start time.Time, maxDuration time.Duration) bool {
	return maxDuration > 0 && time.Since(start) >= maxDuration
}

// normalizeRequest returns a shallow copy of body; the input map is never
// modified. For "responses" requests that have "input" but no "messages",
// the input is wrapped as a single user message.
func normalizeRequest(kind string, body map[string]any) map[string]any {
	out := make(map[string]any, len(body))
	for key, value := range body {
		out[key] = value
	}
	if kind == "responses" && out["messages"] == nil && out["input"] != nil {
		out["messages"] = []any{map[string]any{"role": "user", "content": out["input"]}}
	}
	return out
}

// validateRequest checks the minimum required fields for each kind:
// chat completions need "messages", responses need "input" or "messages",
// and image requests need a non-blank "prompt". Other kinds always pass.
func validateRequest(kind string, body map[string]any) error {
	switch kind {
	case "chat.completions":
		if body["messages"] == nil {
			return errors.New("messages is required")
		}
	case "responses":
		if body["input"] == nil && body["messages"] == nil {
			return errors.New("input or messages is required")
		}
	case "images.generations", "images.edits":
		if strings.TrimSpace(stringFromMap(body, "prompt")) == "" {
			return errors.New("prompt is required")
		}
	}
	return nil
}