// Package runner executes gateway tasks against upstream model clients and
// records attempts, events, and final task state through the store.
package runner

import (
	"context"
	"errors"
	"log/slog"
	"net/http"
	"strings"
	"time"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
	"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
	"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
	"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)

// Service runs gateway tasks: it validates the request, walks the model
// candidates returned by the store, dispatches to the matching client, and
// persists the outcome.
type Service struct {
	cfg     config.Config
	store   *store.Store
	logger  *slog.Logger
	clients map[string]clients.Client
}

// Result is the outcome of a task run: the stored task row plus the output
// payload handed back to the caller.
type Result struct {
	Task   store.GatewayTask
	Output map[string]any
}

// New builds a Service whose HTTP-backed clients share a single http.Client
// with a 120 second timeout. The "simulation" client needs no HTTP client.
func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service {
	httpClient := &http.Client{Timeout: 120 * time.Second}
	return &Service{
		cfg:    cfg,
		store:  db,
		logger: logger,
		clients: map[string]clients.Client{
			"openai":     clients.OpenAIClient{HTTPClient: httpClient},
			"gemini":     clients.GeminiClient{HTTPClient: httpClient},
			"volces":     clients.VolcesClient{HTTPClient: httpClient},
			"simulation": clients.SimulationClient{},
		},
	}
}

// Execute runs task without a stream callback.
func (s *Service) Execute(ctx context.Context, task store.GatewayTask, user *auth.User) (Result, error) {
	return s.execute(ctx, task, user, nil)
}

// ExecuteStream runs task and passes onDelta to the client as its
// StreamDelta callback.
func (s *Service) ExecuteStream(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) {
	return s.execute(ctx, task, user, onDelta)
}

// execute is the shared implementation behind Execute and ExecuteStream.
//
// It normalizes and validates the request, lists model candidates, marks the
// task running, and then tries candidates in order until one succeeds, a
// non-retryable error occurs, retries are disabled for the candidate, or the
// attempt budget from maxAttemptsForCandidates is used up.
//
// When the task ends in failure the returned Result carries the failed task
// and the error is non-nil. A store error while recording the outcome is
// returned with an empty Result.
func (s *Service) execute(ctx context.Context, task store.GatewayTask, user *auth.User, onDelta clients.StreamDelta) (Result, error) {
	taskSimulated := task.RunMode == "simulation"
	modelType := modelTypeFromKind(task.Kind)
	body := normalizeRequest(task.Kind, task.Request)

	if err := validateRequest(task.Kind, body); err != nil {
		failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), taskSimulated, err)
		if finishErr != nil {
			return Result{}, finishErr
		}
		return Result{Task: failed, Output: failed.Result}, err
	}

	candidates, err := s.store.ListModelCandidates(ctx, task.Model, modelType, user)
	if err != nil {
		failed, finishErr := s.failTask(ctx, task.ID, "no_model_candidate", err.Error(), taskSimulated, err)
		if finishErr != nil {
			return Result{}, finishErr
		}
		return Result{Task: failed, Output: failed.Result}, err
	}

	if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil {
		return Result{}, err
	}
	if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": modelType}, taskSimulated); err != nil {
		return Result{}, err
	}

	maxAttempts := maxAttemptsForCandidates(candidates)
	var lastErr error
	for index, candidate := range candidates {
		if index >= maxAttempts {
			break
		}
		attemptNo := index + 1
		simulated := isSimulation(task, candidate)

		response, err := s.runCandidate(ctx, task, user, body, candidate, attemptNo, onDelta)
		if err == nil {
			billings := s.billings(ctx, user, task.Kind, body, candidate, response, simulated)
			record := buildSuccessRecord(task, user, body, candidate, response, billings, simulated)
			finished, finishErr := s.store.FinishTaskSuccess(ctx, store.FinishTaskSuccessInput{
				TaskID:             task.ID,
				Result:             response.Result,
				Billings:           billings,
				RequestID:          record.RequestID,
				ResolvedModel:      record.ResolvedModel,
				Usage:              record.Usage,
				Metrics:            record.Metrics,
				BillingSummary:     record.BillingSummary,
				FinalChargeAmount:  record.FinalChargeAmount,
				ResponseStartedAt:  record.ResponseStartedAt,
				ResponseFinishedAt: record.ResponseFinishedAt,
				ResponseDurationMS: record.ResponseDurationMS,
			})
			if finishErr != nil {
				return Result{}, finishErr
			}
			if err := s.emit(ctx, task.ID, "task.completed", "succeeded", "completed", 1, "task completed", map[string]any{
				"result":         response.Result,
				"billings":       billings,
				"usage":          record.Usage,
				"metrics":        record.Metrics,
				"billingSummary": record.BillingSummary,
				"requestId":      record.RequestID,
			}, simulated); err != nil {
				return Result{}, err
			}
			return Result{Task: finished, Output: response.Result}, nil
		}

		lastErr = err
		if !clients.IsRetryable(err) || !retryEnabled(candidate) || attemptNo >= maxAttempts {
			break
		}
		if emitErr := s.emit(ctx, task.ID, "task.retrying", "running", "retry", 0.55, "retrying next client", map[string]any{"attempt": attemptNo, "error": err.Error()}, simulated); emitErr != nil {
			return Result{}, emitErr
		}
	}

	code := clients.ErrorCode(lastErr)
	message := "task failed"
	if lastErr != nil {
		message = lastErr.Error()
	}
	failed, err := s.failTask(ctx, task.ID, code, message, taskSimulated, lastErr)
	if err != nil {
		return Result{}, err
	}
	if lastErr == nil {
		// Only reachable when the candidate list was empty: no attempt ran, yet
		// the task has just been recorded as failed. Report that to the caller
		// instead of returning a failed task alongside a nil error.
		lastErr = errors.New(message)
	}
	return Result{Task: failed, Output: failed.Result}, lastErr
}

// runCandidate performs one attempt of task against candidate.
//
// It records the attempt, reserves rate limits, runs the client, uploads
// generated assets, replays the client's progress events, and finally marks
// the attempt succeeded. On any error after the attempt row exists the
// attempt is marked failed on a best-effort basis and the error is returned.
func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user *auth.User, body map[string]any, candidate store.RuntimeModelCandidate, attemptNo int, onDelta clients.StreamDelta) (clients.Response, error) {
	simulated := isSimulation(task, candidate)
	if err := s.emit(ctx, task.ID, "task.attempt.started", "running", "submitting", 0.25, "client attempt started", map[string]any{"attempt": attemptNo, "clientId": candidate.ClientID}, simulated); err != nil {
		return clients.Response{}, err
	}

	attemptID, err := s.store.CreateTaskAttempt(ctx, store.CreateTaskAttemptInput{
		TaskID:          task.ID,
		AttemptNo:       attemptNo,
		PlatformID:      candidate.PlatformID,
		PlatformModelID: candidate.PlatformModelID,
		ClientID:        candidate.ClientID,
		QueueKey:        candidate.QueueKey,
		Status:          "running",
		Simulated:       simulated,
		RequestSnapshot: body,
	})
	if err != nil {
		return clients.Response{}, err
	}

	// Cleanup and failure bookkeeping must still run when the request context
	// has been cancelled, otherwise the attempt would stay "running".
	bookkeepingCtx := context.WithoutCancel(ctx)

	// logBestEffort reports errors from bookkeeping calls whose failure must
	// not mask the error that is actually being returned.
	logBestEffort := func(op string, opErr error) {
		if opErr == nil || s.logger == nil {
			return
		}
		s.logger.Warn("runner: best-effort bookkeeping failed", "op", op, "taskId", task.ID, "attempt", attemptNo, "error", opErr)
	}
	// recordFailure stores the terminal "failed" state of this attempt.
	recordFailure := func(input store.FinishTaskAttemptInput) {
		logBestEffort("finish task attempt", s.store.FinishTaskAttempt(bookkeepingCtx, input))
	}
	// abortAttempt marks the attempt failed for an error raised by gateway-side
	// bookkeeping rather than by the upstream client.
	abortAttempt := func(cause error) {
		recordFailure(store.FinishTaskAttemptInput{
			AttemptID:    attemptID,
			Status:       "failed",
			Retryable:    clients.IsRetryable(cause),
			Metrics:      map[string]any{"error": cause.Error(), "candidateModel": candidate.ModelName, "clientId": candidate.ClientID},
			ErrorCode:    clients.ErrorCode(cause),
			ErrorMessage: cause.Error(),
		})
	}

	reservations := s.rateLimitReservations(ctx, user, candidate, body)
	limitResult, err := s.store.ReserveRateLimits(ctx, task.ID, reservations)
	if err != nil {
		recordFailure(store.FinishTaskAttemptInput{
			AttemptID:    attemptID,
			Status:       "failed",
			Retryable:    false,
			Metrics:      map[string]any{"error": err.Error(), "candidateModel": candidate.ModelName, "clientId": candidate.ClientID},
			ErrorCode:    "rate_limit",
			ErrorMessage: err.Error(),
		})
		return clients.Response{}, &clients.ClientError{Code: "rate_limit", Message: err.Error(), Retryable: false}
	}
	defer s.store.ReleaseConcurrencyLeases(bookkeepingCtx, limitResult.LeaseIDs)

	if err := s.store.RecordClientAssignment(ctx, candidate); err != nil {
		abortAttempt(err)
		return clients.Response{}, err
	}
	defer s.store.RecordClientRelease(bookkeepingCtx, candidate.ClientID, "")

	client := s.clientFor(candidate, simulated)
	callStartedAt := time.Now()
	response, err := client.Run(ctx, clients.Request{
		Kind:        task.Kind,
		ModelType:   candidate.ModelType,
		Model:       task.Model,
		Body:        body,
		Candidate:   candidate,
		Stream:      boolFromMap(body, "stream"),
		StreamDelta: onDelta,
	})
	callFinishedAt := time.Now()

	// withTimingFallback fills in timing values that were left at their zero
	// value, using the wall-clock bounds of the client call. The duration is
	// clamped at zero.
	withTimingFallback := func(startedAt, finishedAt time.Time, durationMS int64) (time.Time, time.Time, int64) {
		if startedAt.IsZero() {
			startedAt = callStartedAt
		}
		if finishedAt.IsZero() {
			finishedAt = callFinishedAt
		}
		if durationMS == 0 {
			durationMS = finishedAt.Sub(startedAt).Milliseconds()
			if durationMS < 0 {
				durationMS = 0
			}
		}
		return startedAt, finishedAt, durationMS
	}
	response.ResponseStartedAt, response.ResponseFinishedAt, response.ResponseDurationMS = withTimingFallback(response.ResponseStartedAt, response.ResponseFinishedAt, response.ResponseDurationMS)

	if err != nil {
		retryable := clients.IsRetryable(err)
		requestID, metrics, startedAt, finishedAt, durationMS := failureMetrics(err, simulated)
		startedAt, finishedAt, durationMS = withTimingFallback(startedAt, finishedAt, durationMS)
		recordFailure(store.FinishTaskAttemptInput{
			AttemptID:          attemptID,
			Status:             "failed",
			Retryable:          retryable,
			RequestID:          requestID,
			Metrics:            metrics,
			ResponseStartedAt:  startedAt,
			ResponseFinishedAt: finishedAt,
			ResponseDurationMS: durationMS,
			ErrorCode:          clients.ErrorCode(err),
			ErrorMessage:       err.Error(),
		})
		// The event is informational; the client error is what gets returned.
		logBestEffort("emit task.attempt.failed", s.emit(ctx, task.ID, "task.attempt.failed", "running", "attempt_failed", 0.45, err.Error(), map[string]any{"attempt": attemptNo, "retryable": retryable, "requestId": requestID, "metrics": metrics}, simulated))
		return clients.Response{}, err
	}

	uploadedResult, err := s.uploadGeneratedAssets(ctx, response.Result)
	if err != nil {
		retryable := clients.IsRetryable(err)
		metrics := taskMetrics(task, user, body, candidate, response, simulated)
		metrics["error"] = err.Error()
		metrics["retryable"] = retryable
		recordFailure(store.FinishTaskAttemptInput{
			AttemptID:          attemptID,
			Status:             "failed",
			Retryable:          retryable,
			RequestID:          response.RequestID,
			Usage:              usageToMap(response.Usage),
			Metrics:            metrics,
			ResponseSnapshot:   response.Result,
			ResponseStartedAt:  response.ResponseStartedAt,
			ResponseFinishedAt: response.ResponseFinishedAt,
			ResponseDurationMS: response.ResponseDurationMS,
			ErrorCode:          clients.ErrorCode(err),
			ErrorMessage:       err.Error(),
		})
		return clients.Response{}, err
	}
	response.Result = uploadedResult

	for _, progress := range response.Progress {
		if err := s.emit(ctx, task.ID, "task.progress", "running", progress.Phase, progress.Progress, progress.Message, progress.Payload, simulated); err != nil {
			abortAttempt(err)
			return clients.Response{}, err
		}
	}

	if err := s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
		AttemptID:          attemptID,
		Status:             "succeeded",
		RequestID:          response.RequestID,
		Usage:              usageToMap(response.Usage),
		Metrics:            taskMetrics(task, user, body, candidate, response, simulated),
		ResponseSnapshot:   response.Result,
		ResponseStartedAt:  response.ResponseStartedAt,
		ResponseFinishedAt: response.ResponseFinishedAt,
		ResponseDurationMS: response.ResponseDurationMS,
	}); err != nil {
		return clients.Response{}, err
	}
	return response, nil
}

// clientFor picks the client for candidate. Simulated runs always use the
// "simulation" client. Otherwise the lookup key is the lower-cased, trimmed
// SpecType, or Provider when SpecType is blank; unknown keys fall back to the
// "openai" client.
func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated bool) clients.Client {
	if simulated {
		return s.clients["simulation"]
	}
	key := strings.ToLower(strings.TrimSpace(candidate.SpecType))
	if key == "" {
		key = strings.ToLower(strings.TrimSpace(candidate.Provider))
	}
	if client, ok := s.clients[key]; ok {
		return client
	}
	return s.clients["openai"]
}

// failTask records the task as failed with code and message, then emits a
// "task.failed" event. Request ID, metrics, and timings are derived from
// cause via failureMetrics. It returns the stored failed task, or the store
// or event error if either step fails.
//
// NOTE(review): this runs on the caller's ctx, so a cancelled request cannot
// record the failure. Whether that is intended (for example so the task can be
// recovered later) is not visible from this file; confirm before changing.
func (s *Service) failTask(ctx context.Context, taskID string, code string, message string, simulated bool, cause error) (store.GatewayTask, error) {
	requestID, metrics, responseStartedAt, responseFinishedAt, responseDurationMS := failureMetrics(cause, simulated)
	failed, err := s.store.FinishTaskFailure(ctx, store.FinishTaskFailureInput{
		TaskID:             taskID,
		Code:               code,
		Message:            message,
		RequestID:          requestID,
		Metrics:            metrics,
		ResponseStartedAt:  responseStartedAt,
		ResponseFinishedAt: responseFinishedAt,
		ResponseDurationMS: responseDurationMS,
	})
	if err != nil {
		return store.GatewayTask{}, err
	}
	if eventErr := s.emit(ctx, taskID, "task.failed", "failed", "failed", 1, message, map[string]any{"code": code, "requestId": requestID, "metrics": metrics}, simulated); eventErr != nil {
		return store.GatewayTask{}, eventErr
	}
	return failed, nil
}

// emit stores a task event and, when progress callbacks are enabled in the
// configuration, queues a callback for it to the configured URL.
func (s *Service) emit(ctx context.Context, taskID string, eventType string, status string, phase string, progress float64, message string, payload map[string]any, simulated bool) error {
	event, err := s.store.AddTaskEvent(ctx, taskID, eventType, status, phase, progress, message, payload, simulated)
	if err != nil {
		return err
	}
	if s.cfg.TaskProgressCallbackEnabled {
		return s.store.QueueTaskCallback(ctx, event, s.cfg.TaskProgressCallbackURL)
	}
	return nil
}

// modelTypeFromKind maps an API task kind to the model type used for
// candidate lookup. Unknown kinds map to "task".
func modelTypeFromKind(kind string) string {
	switch kind {
	case "chat.completions", "responses":
		return "text_generate"
	case "images.edits":
		return "image_edit"
	case "images.generations":
		return "image_generate"
	case "videos.generations":
		return "video_generate"
	default:
		return "task"
	}
}

// isTextGenerationKind reports whether kind is one of the text generation
// task kinds.
func isTextGenerationKind(kind string) bool {
	return kind == "chat.completions" || kind == "responses"
}

// isSimulation reports whether the attempt must run against the simulation
// client: the task's run mode is "simulation", the candidate credentials have
// mode "simulation", or the platform config sets testMode.
func isSimulation(task store.GatewayTask, candidate store.RuntimeModelCandidate) bool {
	if task.RunMode == "simulation" {
		return true
	}
	return stringFromMap(candidate.Credentials, "mode") == "simulation" || boolFromMap(candidate.PlatformConfig, "testMode")
}

// retryEnabled reports whether a failed attempt on candidate may be followed
// by another one. A boolean "enabled" in the model retry policy wins over the
// platform retry policy; when neither sets it, retries are enabled.
func retryEnabled(candidate store.RuntimeModelCandidate) bool {
	if enabled, ok := candidate.ModelRetryPolicy["enabled"].(bool); ok {
		return enabled
	}
	if enabled, ok := candidate.PlatformRetryPolicy["enabled"].(bool); ok {
		return enabled
	}
	return true
}

// maxAttemptsForCandidates returns the attempt budget for a run: the number
// of candidates, lowered to the smallest positive "maxAttempts" found in any
// candidate's model or platform retry policy. It returns 0 when there are no
// candidates.
func maxAttemptsForCandidates(candidates []store.RuntimeModelCandidate) int {
	if len(candidates) == 0 {
		return 0
	}
	maxAttempts := len(candidates)
	for _, candidate := range candidates {
		if value := intFromPolicy(candidate.ModelRetryPolicy, "maxAttempts"); value > 0 && value < maxAttempts {
			maxAttempts = value
		}
		if value := intFromPolicy(candidate.PlatformRetryPolicy, "maxAttempts"); value > 0 && value < maxAttempts {
			maxAttempts = value
		}
	}
	if maxAttempts <= 0 {
		return 1
	}
	return maxAttempts
}

// normalizeRequest returns a shallow copy of body. For "responses" requests
// that carry "input" but no "messages", it adds a single user message whose
// content is that input. The result is never nil.
func normalizeRequest(kind string, body map[string]any) map[string]any {
	out := make(map[string]any, len(body))
	for key, value := range body {
		out[key] = value
	}
	if kind == "responses" && out["messages"] == nil && out["input"] != nil {
		out["messages"] = []any{map[string]any{"role": "user", "content": out["input"]}}
	}
	return out
}

// validateRequest checks the fields required for kind and returns an error
// describing the first missing one. Kinds without a rule are accepted.
func validateRequest(kind string, body map[string]any) error {
	switch kind {
	case "chat.completions":
		if body["messages"] == nil {
			return errors.New("messages is required")
		}
	case "responses":
		if body["input"] == nil && body["messages"] == nil {
			return errors.New("input or messages is required")
		}
	case "images.generations", "images.edits":
		if strings.TrimSpace(stringFromMap(body, "prompt")) == "" {
			return errors.New("prompt is required")
		}
	}
	return nil
}