// File: easyai-ai-gateway/apps/api/internal/runner/service.go
//
// Listing metadata: 296 lines, 10 KiB, Go.

package runner
import (
"context"
"errors"
"log/slog"
"net/http"
"strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// Service executes gateway tasks: it validates the request, tries the model
// candidates returned by the store, and records attempts and task events.
type Service struct {
// cfg is the configuration passed to New; emit reads the task-progress
// callback settings from it.
cfg config.Config
// store is the persistence layer used for tasks, attempts, events,
// rate limits and callbacks.
store *store.Store
// logger is the structured logger passed to New.
logger *slog.Logger
// clients maps a lower-cased provider key (for example "openai") to the
// client that serves it; see clientFor for the lookup rules.
clients map[string]clients.Client
}
// Result is the value returned by Execute.
type Result struct {
// Task is the task record returned by the store when the task was finished.
Task store.GatewayTask
// Output is the provider response result on success, or the Result field of
// the failed task record on failure.
Output map[string]any
}
// New builds a Service around cfg, db and logger and registers the built-in
// provider clients under the keys "openai", "gemini" and "simulation".
func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service {
	// The OpenAI and Gemini clients share one HTTP client with a two-minute
	// overall timeout.
	shared := &http.Client{Timeout: 2 * time.Minute}

	registry := make(map[string]clients.Client, 3)
	registry["openai"] = clients.OpenAIClient{HTTPClient: shared}
	registry["gemini"] = clients.GeminiClient{HTTPClient: shared}
	registry["simulation"] = clients.SimulationClient{}

	svc := &Service{
		cfg:     cfg,
		store:   db,
		logger:  logger,
		clients: registry,
	}
	return svc
}
// Execute drives task to a terminal state on behalf of user.
//
// The request is normalized and validated, the candidates for task.Model are
// loaded from the store, and the candidates are tried in order until one
// succeeds, an error is not retryable, retries are disabled for the failing
// candidate, or the attempt budget from maxAttemptsForCandidates is used up.
//
// Returns:
//   - on success, the finished task with the provider output and a nil error;
//   - on task failure, the failed task record together with the error that
//     caused the failure (never nil);
//   - an empty Result and the store/event error when the outcome itself could
//     not be recorded.
func (s *Service) Execute(ctx context.Context, task store.GatewayTask, user *auth.User) (Result, error) {
	simulatedTask := task.RunMode == "simulation"

	// fail stores the task as failed and pairs the stored record with cause.
	fail := func(code, message string, cause error) (Result, error) {
		failed, finishErr := s.failTask(ctx, task.ID, code, message, simulatedTask)
		if finishErr != nil {
			return Result{}, finishErr
		}
		return Result{Task: failed, Output: failed.Result}, cause
	}

	modelType := modelTypeFromKind(task.Kind)
	body := normalizeRequest(task.Kind, task.Request)
	if err := validateRequest(task.Kind, body); err != nil {
		return fail("bad_request", err.Error(), err)
	}

	candidates, err := s.store.ListModelCandidates(ctx, task.Model, modelType)
	if err != nil {
		return fail("no_model_candidate", err.Error(), err)
	}
	if len(candidates) == 0 {
		// An empty list with a nil error would otherwise skip the attempt loop
		// and report the failed task with a nil error, which callers would
		// read as success.
		const code, message = "no_model_candidate", "no model candidate available"
		return fail(code, message, &clients.ClientError{Code: code, Message: message, Retryable: false})
	}

	if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil {
		return Result{}, err
	}
	if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": modelType}, simulatedTask); err != nil {
		return Result{}, err
	}

	maxAttempts := maxAttemptsForCandidates(candidates)
	var lastErr error
	for index, candidate := range candidates {
		if index >= maxAttempts {
			break
		}
		attemptNo := index + 1
		simulated := isSimulation(task, candidate)

		response, err := s.runCandidate(ctx, task, user, body, candidate, attemptNo)
		if err == nil {
			billings := s.billings(ctx, user, task.Kind, body, candidate, response, simulated)
			finished, finishErr := s.store.FinishTaskSuccess(ctx, task.ID, response.Result, billings)
			if finishErr != nil {
				return Result{}, finishErr
			}
			if emitErr := s.emit(ctx, task.ID, "task.completed", "succeeded", "completed", 1, "task completed", map[string]any{"result": response.Result, "billings": billings}, simulated); emitErr != nil {
				return Result{}, emitErr
			}
			return Result{Task: finished, Output: response.Result}, nil
		}

		lastErr = err
		// Stop when the error is final, the candidate opted out of retries, or
		// the attempt budget is exhausted.
		if !clients.IsRetryable(err) || !retryEnabled(candidate) || attemptNo >= maxAttempts {
			break
		}
		if emitErr := s.emit(ctx, task.ID, "task.retrying", "running", "retry", 0.55, "retrying next client", map[string]any{"attempt": attemptNo, "error": err.Error()}, simulated); emitErr != nil {
			return Result{}, emitErr
		}
	}

	// At least one attempt ran, so lastErr is expected to be set; the fallback
	// message is kept as a defensive default.
	message := "task failed"
	if lastErr != nil {
		message = lastErr.Error()
	}
	return fail(clients.ErrorCode(lastErr), message, lastErr)
}
// runCandidate performs one attempt of task against candidate.
//
// It emits the attempt-started event, creates the attempt record, reserves
// rate limits, records the client assignment, calls the provider client,
// uploads generated assets, replays the provider's progress events and finally
// marks the attempt as succeeded.
//
// Every failure after the attempt record exists marks that attempt as failed
// before returning, so an attempt is not left in the "running" status. A
// rate-limit reservation failure is returned as a non-retryable
// *clients.ClientError with code "rate_limit"; other errors are returned as
// received.
func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user *auth.User, body map[string]any, candidate store.RuntimeModelCandidate, attemptNo int) (clients.Response, error) {
	simulated := isSimulation(task, candidate)
	if err := s.emit(ctx, task.ID, "task.attempt.started", "running", "submitting", 0.25, "client attempt started", map[string]any{"attempt": attemptNo, "clientId": candidate.ClientID}, simulated); err != nil {
		return clients.Response{}, err
	}

	attemptID, err := s.store.CreateTaskAttempt(ctx, store.CreateTaskAttemptInput{
		TaskID:          task.ID,
		AttemptNo:       attemptNo,
		PlatformID:      candidate.PlatformID,
		PlatformModelID: candidate.PlatformModelID,
		ClientID:        candidate.ClientID,
		QueueKey:        candidate.QueueKey,
		Status:          "running",
		Simulated:       simulated,
		RequestSnapshot: body,
	})
	if err != nil {
		return clients.Response{}, err
	}

	// Cleanup must still run when ctx has been cancelled (for example because
	// the caller went away), otherwise leases and attempts stay open.
	cleanupCtx := context.WithoutCancel(ctx)

	// failAttempt marks the attempt as failed. It is best effort: the caller
	// already holds the more specific error it is about to return.
	failAttempt := func(code string, retryable bool, cause error) {
		_ = s.store.FinishTaskAttempt(cleanupCtx, store.FinishTaskAttemptInput{
			AttemptID:    attemptID,
			Status:       "failed",
			Retryable:    retryable,
			ErrorCode:    code,
			ErrorMessage: cause.Error(),
		})
	}

	reservations := s.rateLimitReservations(ctx, user, candidate, body)
	limitResult, err := s.store.ReserveRateLimits(ctx, task.ID, reservations)
	if err != nil {
		failAttempt("rate_limit", false, err)
		return clients.Response{}, &clients.ClientError{Code: "rate_limit", Message: err.Error(), Retryable: false}
	}
	defer s.store.ReleaseConcurrencyLeases(cleanupCtx, limitResult.LeaseIDs)

	if err := s.store.RecordClientAssignment(ctx, candidate); err != nil {
		// Previously this path returned with the attempt still "running".
		failAttempt(clients.ErrorCode(err), clients.IsRetryable(err), err)
		return clients.Response{}, err
	}
	defer s.store.RecordClientRelease(cleanupCtx, candidate.ClientID, "")

	client := s.clientFor(candidate, simulated)
	response, err := client.Run(ctx, clients.Request{
		Kind:      task.Kind,
		ModelType: candidate.ModelType,
		Model:     task.Model,
		Body:      body,
		Candidate: candidate,
		Stream:    boolFromMap(body, "stream"),
	})
	if err != nil {
		retryable := clients.IsRetryable(err)
		failAttempt(clients.ErrorCode(err), retryable, err)
		// The event is informational; the provider error is what gets returned.
		_ = s.emit(ctx, task.ID, "task.attempt.failed", "running", "attempt_failed", 0.45, err.Error(), map[string]any{"attempt": attemptNo, "retryable": retryable}, simulated)
		return clients.Response{}, err
	}

	uploadedResult, err := s.uploadGeneratedAssets(ctx, response.Result)
	if err != nil {
		failAttempt(clients.ErrorCode(err), clients.IsRetryable(err), err)
		return clients.Response{}, err
	}
	response.Result = uploadedResult

	for _, progress := range response.Progress {
		if err := s.emit(ctx, task.ID, "task.progress", "running", progress.Phase, progress.Progress, progress.Message, progress.Payload, simulated); err != nil {
			// Previously this path returned with the attempt still "running".
			failAttempt(clients.ErrorCode(err), clients.IsRetryable(err), err)
			return clients.Response{}, err
		}
	}

	if err := s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
		AttemptID:        attemptID,
		Status:           "succeeded",
		ResponseSnapshot: response.Result,
	}); err != nil {
		return clients.Response{}, err
	}
	return response, nil
}
// clientFor selects the client that should serve candidate.
//
// Simulated runs always get the "simulation" client. Otherwise the client
// registered under the lower-cased candidate.Provider is used, and the
// "openai" client is the fallback for providers without a registration.
func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated bool) clients.Client {
	if !simulated {
		if registered, found := s.clients[strings.ToLower(candidate.Provider)]; found {
			return registered
		}
		return s.clients["openai"]
	}
	return s.clients["simulation"]
}
// failTask stores the task as failed with code and message and then emits the
// "task.failed" event. It returns the failed task record, or an empty task and
// the error when either the store update or the event emission fails.
func (s *Service) failTask(ctx context.Context, taskID string, code string, message string, simulated bool) (store.GatewayTask, error) {
	record, storeErr := s.store.FinishTaskFailure(ctx, taskID, code, message)
	if storeErr != nil {
		return store.GatewayTask{}, storeErr
	}

	details := map[string]any{"code": code}
	emitErr := s.emit(ctx, taskID, "task.failed", "failed", "failed", 1, message, details, simulated)
	if emitErr == nil {
		return record, nil
	}
	return store.GatewayTask{}, emitErr
}
// emit appends a task event to the store and, when task-progress callbacks are
// enabled in the configuration, queues a callback for that event to the
// configured callback URL. It returns the first error encountered.
func (s *Service) emit(ctx context.Context, taskID string, eventType string, status string, phase string, progress float64, message string, payload map[string]any, simulated bool) error {
	event, err := s.store.AddTaskEvent(ctx, taskID, eventType, status, phase, progress, message, payload, simulated)
	switch {
	case err != nil:
		return err
	case !s.cfg.TaskProgressCallbackEnabled:
		return nil
	default:
		return s.store.QueueTaskCallback(ctx, event, s.cfg.TaskProgressCallbackURL)
	}
}
// modelTypeFromKind maps a task kind to the model type used for candidate
// lookup: "chat" for chat completions and responses, "image" for image
// generations and edits, and "task" for every other kind.
func modelTypeFromKind(kind string) string {
	if kind == "chat.completions" || kind == "responses" {
		return "chat"
	}
	if kind == "images.generations" || kind == "images.edits" {
		return "image"
	}
	return "task"
}
// isSimulation reports whether an attempt of task against candidate should be
// simulated. That is the case when the task's run mode is "simulation", when
// the candidate's credentials carry mode "simulation", or when the candidate's
// platform config has testMode set.
func isSimulation(task store.GatewayTask, candidate store.RuntimeModelCandidate) bool {
	switch {
	case task.RunMode == "simulation":
		return true
	case stringFromMap(candidate.Credentials, "mode") == "simulation":
		return true
	default:
		return boolFromMap(candidate.PlatformConfig, "testMode")
	}
}
// retryEnabled reports whether a failed attempt on candidate may be followed
// by another attempt. A boolean "enabled" entry in the model retry policy
// wins; otherwise a boolean "enabled" entry in the platform retry policy is
// used; with neither present, retries are enabled.
func retryEnabled(candidate store.RuntimeModelCandidate) bool {
	modelSetting, modelHasSetting := candidate.ModelRetryPolicy["enabled"].(bool)
	if modelHasSetting {
		return modelSetting
	}
	platformSetting, platformHasSetting := candidate.PlatformRetryPolicy["enabled"].(bool)
	return !platformHasSetting || platformSetting
}
// maxAttemptsForCandidates returns how many candidates may be attempted.
//
// The budget starts at the number of candidates and is lowered to the smallest
// positive "maxAttempts" value found in any candidate's model or platform
// retry policy. Non-positive policy values are ignored. The result is 0 for an
// empty list and at least 1 otherwise.
func maxAttemptsForCandidates(candidates []store.RuntimeModelCandidate) int {
	limit := len(candidates)
	for _, candidate := range candidates {
		if configured := intFromPolicy(candidate.ModelRetryPolicy, "maxAttempts"); configured > 0 && configured < limit {
			limit = configured
		}
		if configured := intFromPolicy(candidate.PlatformRetryPolicy, "maxAttempts"); configured > 0 && configured < limit {
			limit = configured
		}
	}
	// limit only ever decreases to a positive value, so no lower-bound clamp
	// is needed here.
	return limit
}
// normalizeRequest returns a shallow copy of body prepared for execution.
//
// For kind "responses", a request that has "input" but no "messages" gets a
// "messages" entry holding a single user message whose content is the input.
// The caller's map is never modified, and the result is never nil.
func normalizeRequest(kind string, body map[string]any) map[string]any {
	normalized := make(map[string]any, len(body))
	for field := range body {
		normalized[field] = body[field]
	}
	if kind != "responses" {
		return normalized
	}

	input, messages := normalized["input"], normalized["messages"]
	if messages == nil && input != nil {
		userMessage := map[string]any{"role": "user", "content": input}
		normalized["messages"] = []any{userMessage}
	}
	return normalized
}
// validateRequest checks that body carries the fields required for kind.
//
//   - "chat.completions" needs "messages".
//   - "responses" needs "input" or "messages".
//   - "images.generations" and "images.edits" need a non-blank "prompt".
//
// Any other kind is accepted as is. The returned error names the missing field.
func validateRequest(kind string, body map[string]any) error {
	if kind == "chat.completions" {
		if body["messages"] == nil {
			return errors.New("messages is required")
		}
		return nil
	}
	if kind == "responses" {
		if body["input"] == nil && body["messages"] == nil {
			return errors.New("input or messages is required")
		}
		return nil
	}
	if kind == "images.generations" || kind == "images.edits" {
		prompt := strings.TrimSpace(stringFromMap(body, "prompt"))
		if prompt == "" {
			return errors.New("prompt is required")
		}
	}
	return nil
}