296 lines
10 KiB
Go
296 lines
10 KiB
Go
package runner
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/clients"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
)
|
|
|
|
type Service struct {
|
|
cfg config.Config
|
|
store *store.Store
|
|
logger *slog.Logger
|
|
clients map[string]clients.Client
|
|
}
|
|
|
|
type Result struct {
|
|
Task store.GatewayTask
|
|
Output map[string]any
|
|
}
|
|
|
|
func New(cfg config.Config, db *store.Store, logger *slog.Logger) *Service {
|
|
httpClient := &http.Client{Timeout: 120 * time.Second}
|
|
return &Service{
|
|
cfg: cfg,
|
|
store: db,
|
|
logger: logger,
|
|
clients: map[string]clients.Client{
|
|
"openai": clients.OpenAIClient{HTTPClient: httpClient},
|
|
"gemini": clients.GeminiClient{HTTPClient: httpClient},
|
|
"simulation": clients.SimulationClient{},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (s *Service) Execute(ctx context.Context, task store.GatewayTask, user *auth.User) (Result, error) {
|
|
modelType := modelTypeFromKind(task.Kind)
|
|
body := normalizeRequest(task.Kind, task.Request)
|
|
if err := validateRequest(task.Kind, body); err != nil {
|
|
failed, finishErr := s.failTask(ctx, task.ID, "bad_request", err.Error(), task.RunMode == "simulation")
|
|
if finishErr != nil {
|
|
return Result{}, finishErr
|
|
}
|
|
return Result{Task: failed, Output: failed.Result}, err
|
|
}
|
|
candidates, err := s.store.ListModelCandidates(ctx, task.Model, modelType)
|
|
if err != nil {
|
|
failed, finishErr := s.failTask(ctx, task.ID, "no_model_candidate", err.Error(), task.RunMode == "simulation")
|
|
if finishErr != nil {
|
|
return Result{}, finishErr
|
|
}
|
|
return Result{Task: failed, Output: failed.Result}, err
|
|
}
|
|
if err := s.store.MarkTaskRunning(ctx, task.ID, modelType, body); err != nil {
|
|
return Result{}, err
|
|
}
|
|
if err := s.emit(ctx, task.ID, "task.progress", "running", "normalizing", 0.15, "request normalized", map[string]any{"modelType": modelType}, task.RunMode == "simulation"); err != nil {
|
|
return Result{}, err
|
|
}
|
|
|
|
maxAttempts := maxAttemptsForCandidates(candidates)
|
|
var lastErr error
|
|
for index, candidate := range candidates {
|
|
if index >= maxAttempts {
|
|
break
|
|
}
|
|
attemptNo := index + 1
|
|
response, err := s.runCandidate(ctx, task, user, body, candidate, attemptNo)
|
|
if err == nil {
|
|
billings := s.billings(ctx, user, task.Kind, body, candidate, response, isSimulation(task, candidate))
|
|
finished, finishErr := s.store.FinishTaskSuccess(ctx, task.ID, response.Result, billings)
|
|
if finishErr != nil {
|
|
return Result{}, finishErr
|
|
}
|
|
if err := s.emit(ctx, task.ID, "task.completed", "succeeded", "completed", 1, "task completed", map[string]any{"result": response.Result, "billings": billings}, isSimulation(task, candidate)); err != nil {
|
|
return Result{}, err
|
|
}
|
|
return Result{Task: finished, Output: response.Result}, nil
|
|
}
|
|
lastErr = err
|
|
retryable := clients.IsRetryable(err)
|
|
if !retryable || !retryEnabled(candidate) || attemptNo >= maxAttempts {
|
|
break
|
|
}
|
|
if err := s.emit(ctx, task.ID, "task.retrying", "running", "retry", 0.55, "retrying next client", map[string]any{"attempt": attemptNo, "error": err.Error()}, isSimulation(task, candidate)); err != nil {
|
|
return Result{}, err
|
|
}
|
|
}
|
|
code := clients.ErrorCode(lastErr)
|
|
message := "task failed"
|
|
if lastErr != nil {
|
|
message = lastErr.Error()
|
|
}
|
|
failed, err := s.failTask(ctx, task.ID, code, message, task.RunMode == "simulation")
|
|
if err != nil {
|
|
return Result{}, err
|
|
}
|
|
return Result{Task: failed, Output: failed.Result}, lastErr
|
|
}
|
|
|
|
func (s *Service) runCandidate(ctx context.Context, task store.GatewayTask, user *auth.User, body map[string]any, candidate store.RuntimeModelCandidate, attemptNo int) (clients.Response, error) {
|
|
simulated := isSimulation(task, candidate)
|
|
if err := s.emit(ctx, task.ID, "task.attempt.started", "running", "submitting", 0.25, "client attempt started", map[string]any{"attempt": attemptNo, "clientId": candidate.ClientID}, simulated); err != nil {
|
|
return clients.Response{}, err
|
|
}
|
|
attemptID, err := s.store.CreateTaskAttempt(ctx, store.CreateTaskAttemptInput{
|
|
TaskID: task.ID,
|
|
AttemptNo: attemptNo,
|
|
PlatformID: candidate.PlatformID,
|
|
PlatformModelID: candidate.PlatformModelID,
|
|
ClientID: candidate.ClientID,
|
|
QueueKey: candidate.QueueKey,
|
|
Status: "running",
|
|
Simulated: simulated,
|
|
RequestSnapshot: body,
|
|
})
|
|
if err != nil {
|
|
return clients.Response{}, err
|
|
}
|
|
reservations := s.rateLimitReservations(ctx, user, candidate, body)
|
|
limitResult, err := s.store.ReserveRateLimits(ctx, task.ID, reservations)
|
|
if err != nil {
|
|
_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{AttemptID: attemptID, Status: "failed", Retryable: false, ErrorCode: "rate_limit", ErrorMessage: err.Error()})
|
|
return clients.Response{}, &clients.ClientError{Code: "rate_limit", Message: err.Error(), Retryable: false}
|
|
}
|
|
defer s.store.ReleaseConcurrencyLeases(context.WithoutCancel(ctx), limitResult.LeaseIDs)
|
|
|
|
if err := s.store.RecordClientAssignment(ctx, candidate); err != nil {
|
|
return clients.Response{}, err
|
|
}
|
|
defer s.store.RecordClientRelease(context.WithoutCancel(ctx), candidate.ClientID, "")
|
|
|
|
client := s.clientFor(candidate, simulated)
|
|
response, err := client.Run(ctx, clients.Request{
|
|
Kind: task.Kind,
|
|
ModelType: candidate.ModelType,
|
|
Model: task.Model,
|
|
Body: body,
|
|
Candidate: candidate,
|
|
Stream: boolFromMap(body, "stream"),
|
|
})
|
|
if err != nil {
|
|
retryable := clients.IsRetryable(err)
|
|
_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
|
|
AttemptID: attemptID,
|
|
Status: "failed",
|
|
Retryable: retryable,
|
|
ErrorCode: clients.ErrorCode(err),
|
|
ErrorMessage: err.Error(),
|
|
})
|
|
_ = s.emit(ctx, task.ID, "task.attempt.failed", "running", "attempt_failed", 0.45, err.Error(), map[string]any{"attempt": attemptNo, "retryable": retryable}, simulated)
|
|
return clients.Response{}, err
|
|
}
|
|
uploadedResult, err := s.uploadGeneratedAssets(ctx, response.Result)
|
|
if err != nil {
|
|
_ = s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
|
|
AttemptID: attemptID,
|
|
Status: "failed",
|
|
Retryable: clients.IsRetryable(err),
|
|
ErrorCode: clients.ErrorCode(err),
|
|
ErrorMessage: err.Error(),
|
|
})
|
|
return clients.Response{}, err
|
|
}
|
|
response.Result = uploadedResult
|
|
for _, progress := range response.Progress {
|
|
if err := s.emit(ctx, task.ID, "task.progress", "running", progress.Phase, progress.Progress, progress.Message, progress.Payload, simulated); err != nil {
|
|
return clients.Response{}, err
|
|
}
|
|
}
|
|
if err := s.store.FinishTaskAttempt(ctx, store.FinishTaskAttemptInput{
|
|
AttemptID: attemptID,
|
|
Status: "succeeded",
|
|
ResponseSnapshot: response.Result,
|
|
}); err != nil {
|
|
return clients.Response{}, err
|
|
}
|
|
return response, nil
|
|
}
|
|
|
|
func (s *Service) clientFor(candidate store.RuntimeModelCandidate, simulated bool) clients.Client {
|
|
if simulated {
|
|
return s.clients["simulation"]
|
|
}
|
|
key := strings.ToLower(candidate.Provider)
|
|
if client, ok := s.clients[key]; ok {
|
|
return client
|
|
}
|
|
return s.clients["openai"]
|
|
}
|
|
|
|
func (s *Service) failTask(ctx context.Context, taskID string, code string, message string, simulated bool) (store.GatewayTask, error) {
|
|
failed, err := s.store.FinishTaskFailure(ctx, taskID, code, message)
|
|
if err != nil {
|
|
return store.GatewayTask{}, err
|
|
}
|
|
if eventErr := s.emit(ctx, taskID, "task.failed", "failed", "failed", 1, message, map[string]any{"code": code}, simulated); eventErr != nil {
|
|
return store.GatewayTask{}, eventErr
|
|
}
|
|
return failed, nil
|
|
}
|
|
|
|
func (s *Service) emit(ctx context.Context, taskID string, eventType string, status string, phase string, progress float64, message string, payload map[string]any, simulated bool) error {
|
|
event, err := s.store.AddTaskEvent(ctx, taskID, eventType, status, phase, progress, message, payload, simulated)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if s.cfg.TaskProgressCallbackEnabled {
|
|
return s.store.QueueTaskCallback(ctx, event, s.cfg.TaskProgressCallbackURL)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// modelTypeFromKind maps a task kind to the model type used for candidate
// lookup: chat-style kinds yield "chat", image kinds yield "image", and
// anything else yields "task".
func modelTypeFromKind(kind string) string {
	if kind == "chat.completions" || kind == "responses" {
		return "chat"
	}
	if kind == "images.generations" || kind == "images.edits" {
		return "image"
	}
	return "task"
}
|
|
|
|
func isSimulation(task store.GatewayTask, candidate store.RuntimeModelCandidate) bool {
|
|
if task.RunMode == "simulation" {
|
|
return true
|
|
}
|
|
return stringFromMap(candidate.Credentials, "mode") == "simulation" || boolFromMap(candidate.PlatformConfig, "testMode")
|
|
}
|
|
|
|
func retryEnabled(candidate store.RuntimeModelCandidate) bool {
|
|
if enabled, ok := candidate.ModelRetryPolicy["enabled"].(bool); ok {
|
|
return enabled
|
|
}
|
|
if enabled, ok := candidate.PlatformRetryPolicy["enabled"].(bool); ok {
|
|
return enabled
|
|
}
|
|
return true
|
|
}
|
|
|
|
func maxAttemptsForCandidates(candidates []store.RuntimeModelCandidate) int {
|
|
if len(candidates) == 0 {
|
|
return 0
|
|
}
|
|
maxAttempts := len(candidates)
|
|
for _, candidate := range candidates {
|
|
if value := intFromPolicy(candidate.ModelRetryPolicy, "maxAttempts"); value > 0 && value < maxAttempts {
|
|
maxAttempts = value
|
|
}
|
|
if value := intFromPolicy(candidate.PlatformRetryPolicy, "maxAttempts"); value > 0 && value < maxAttempts {
|
|
maxAttempts = value
|
|
}
|
|
}
|
|
if maxAttempts <= 0 {
|
|
return 1
|
|
}
|
|
return maxAttempts
|
|
}
|
|
|
|
// normalizeRequest returns a shallow copy of body, leaving the input map
// untouched. For the "responses" kind, when "messages" is absent or nil and
// "input" is set, the copy gains a "messages" entry holding a single user
// message whose content is the input value.
func normalizeRequest(kind string, body map[string]any) map[string]any {
	normalized := make(map[string]any, len(body))
	for field, val := range body {
		normalized[field] = val
	}

	if kind != "responses" {
		return normalized
	}
	input := normalized["input"]
	if input == nil || normalized["messages"] != nil {
		return normalized
	}

	userMessage := map[string]any{"role": "user", "content": input}
	normalized["messages"] = []any{userMessage}
	return normalized
}
|
|
|
|
func validateRequest(kind string, body map[string]any) error {
|
|
switch kind {
|
|
case "chat.completions":
|
|
if body["messages"] == nil {
|
|
return errors.New("messages is required")
|
|
}
|
|
case "responses":
|
|
if body["input"] == nil && body["messages"] == nil {
|
|
return errors.New("input or messages is required")
|
|
}
|
|
case "images.generations", "images.edits":
|
|
if strings.TrimSpace(stringFromMap(body, "prompt")) == "" {
|
|
return errors.New("prompt is required")
|
|
}
|
|
}
|
|
return nil
|
|
}
|