package runner

import (
	"context"
	"errors"
	"fmt"
	"os"
	"strings"
	"time"

	"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
	"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
	"github.com/riverqueue/river"
	"github.com/riverqueue/river/riverdriver/riverpgxv5"
	"github.com/riverqueue/river/rivermigrate"
	"github.com/riverqueue/river/rivertype"
)

// asyncTaskQueueName is the River queue that gateway task jobs are inserted
// into and that the client started by startRiverQueue works.
const asyncTaskQueueName = "gateway_tasks"

// asyncTaskArgs is the River job payload for running one gateway task. Only
// the task ID is stored in the job; the worker reloads the task from the store.
type asyncTaskArgs struct {
	TaskID string `json:"task_id" river:"unique"`
}

// Kind returns the River job kind used for asyncTaskArgs jobs.
func (asyncTaskArgs) Kind() string { return "gateway_task_run" }

// asyncTaskWorker is the River worker that executes gateway tasks through
// the owning Service.
type asyncTaskWorker struct {
	river.WorkerDefaults[asyncTaskArgs]
	service *Service
}

// Work loads the task referenced by the job and, unless it is already in a
// terminal status ("succeeded", "failed", "cancelled"), runs it with
// Service.Execute using a user rebuilt from the task record.
//
// Outcomes:
//   - Execute succeeds: the job completes.
//   - Execute returns a *TaskQueuedError: the job is snoozed for its Delay.
//   - Execute fails after the job context is done: the task is requeued via
//     requeueInterruptedAsyncTask and the job is snoozed so it runs again.
//   - any other Execute error: it is logged and the job completes, so River
//     does not retry it.
//
// Failures to load or requeue the task are returned, which lets River retry
// the job.
func (w *asyncTaskWorker) Work(ctx context.Context, job *river.Job[asyncTaskArgs]) error {
	task, err := w.service.store.GetTask(ctx, job.Args.TaskID)
	if err != nil {
		return fmt.Errorf("load task %q: %w", job.Args.TaskID, err)
	}

	switch task.Status {
	case "succeeded", "failed", "cancelled":
		// Terminal tasks are never executed again; just complete the job.
		return nil
	}

	result, runErr := w.service.Execute(ctx, task, authUserFromTask(task))
	if runErr == nil {
		w.service.logger.Debug("river async task completed", "taskID", task.ID, "status", result.Task.Status, "riverJobID", job.ID)
		return nil
	}

	var queuedErr *TaskQueuedError
	if errors.As(runErr, &queuedErr) {
		return river.JobSnooze(queuedErr.Delay)
	}

	if ctx.Err() != nil {
		// ctx is already done, so persist the requeue with a context that keeps
		// ctx's values but drops its cancellation.
		queued, queueErr := w.service.requeueInterruptedAsyncTask(context.WithoutCancel(ctx), task)
		if queueErr != nil {
			return fmt.Errorf("requeue interrupted task %q: %w", task.ID, queueErr)
		}
		w.service.logger.Debug("river async task interrupted and requeued", "taskID", task.ID, "status", queued.Status, "riverJobID", job.ID)
		return river.JobSnooze(0)
	}

	w.service.logger.Warn("river async task completed with failure", "taskID", task.ID, "error", runErr, "riverJobID", job.ID)
	return nil
}

// StartAsyncQueueWorker starts the River-backed async task queue, which runs
// until ctx is done. If the queue cannot be started it logs the error and
// panics.
func (s *Service) StartAsyncQueueWorker(ctx context.Context) {
	if err := s.startRiverQueue(ctx); err != nil {
		s.logger.Error("start river async queue failed", "error", err)
		panic(err)
	}
}

// startRiverQueue migrates the River schema, registers asyncTaskWorker,
// creates and starts a River client (stored in s.riverClient), and
// re-enqueues persisted tasks with recoverAsyncRiverJobs.
//
// After a successful start, a goroutine stops the client with StopAndCancel
// once ctx is done. If recovery fails after the client has started, the
// client is stopped before the error is returned, so a failed start does not
// leave a running client with nothing to stop it.
func (s *Service) startRiverQueue(ctx context.Context) error {
	// stopTimeout bounds how long StopAndCancel may take during shutdown.
	const stopTimeout = 10 * time.Second

	driver := riverpgxv5.New(s.store.Pool())

	migrator, err := rivermigrate.New(driver, nil)
	if err != nil {
		return fmt.Errorf("create river migrator: %w", err)
	}
	if _, err := migrator.Migrate(ctx, rivermigrate.DirectionUp, nil); err != nil {
		return fmt.Errorf("migrate river schema: %w", err)
	}

	workers := river.NewWorkers()
	if err := river.AddWorkerSafely(workers, &asyncTaskWorker{service: s}); err != nil {
		return fmt.Errorf("register river async task worker: %w", err)
	}

	riverClient, err := river.NewClient(driver, &river.Config{
		ID: asyncWorkerID(),
		// -1 disables River's per-job timeout.
		JobTimeout:                  -1,
		Logger:                      s.logger,
		CompletedJobRetentionPeriod: 24 * time.Hour,
		Queues: map[string]river.QueueConfig{
			asyncTaskQueueName: {MaxWorkers: 32},
		},
		// NOTE(review): River treats jobs that have been running longer than
		// RescueStuckJobsAfter as stuck and makes them eligible to run again.
		// With JobTimeout disabled, a task that legitimately runs over 30s may
		// be picked up a second time — confirm Execute tolerates that.
		RescueStuckJobsAfter: 30 * time.Second,
		TestOnly:             s.cfg.AppEnv == "test",
		Workers:              workers,
	})
	if err != nil {
		return fmt.Errorf("create river client: %w", err)
	}
	s.riverClient = riverClient

	if err := riverClient.Start(ctx); err != nil {
		return fmt.Errorf("start river client: %w", err)
	}

	stop := func() {
		stopCtx, cancel := context.WithTimeout(context.Background(), stopTimeout)
		defer cancel()
		if err := riverClient.StopAndCancel(stopCtx); err != nil {
			s.logger.Warn("stop river async queue failed", "error", err)
		}
	}

	if err := s.recoverAsyncRiverJobs(ctx); err != nil {
		// The client is already running; do not leave it behind unmanaged.
		stop()
		return fmt.Errorf("recover persisted async tasks: %w", err)
	}

	go func() {
		<-ctx.Done()
		stop()
	}()
	return nil
}

// EnqueueAsyncTask inserts a River job that will run task on the async queue.
// When River returns a job (including an existing one that the uniqueness
// options in asyncTaskInsertOpts matched), its ID is recorded on the task
// with SetTaskRiverJobID. It returns an error if the queue has not been
// started.
func (s *Service) EnqueueAsyncTask(ctx context.Context, task store.GatewayTask) error {
	if s.riverClient == nil {
		return errors.New("river async queue is not started")
	}
	result, err := s.riverClient.Insert(ctx, asyncTaskArgs{TaskID: task.ID}, asyncTaskInsertOpts(task))
	if err != nil {
		return fmt.Errorf("insert river job for task %q: %w", task.ID, err)
	}
	if result.Job == nil {
		return nil
	}
	if err := s.store.SetTaskRiverJobID(ctx, task.ID, result.Job.ID); err != nil {
		return fmt.Errorf("record river job id for task %q: %w", task.ID, err)
	}
	return nil
}

// WakeAsyncQueueAfter is a no-op in this River-backed implementation; both
// ctx and delay are ignored and job scheduling is left to River.
func (s *Service) WakeAsyncQueueAfter(ctx context.Context, delay time.Duration) {
}

// RunAsyncTask enqueues task on the async queue, logging (not returning) any
// enqueue failure. user is not used: the worker rebuilds the user from the
// task record with authUserFromTask.
func (s *Service) RunAsyncTask(ctx context.Context, task store.GatewayTask, user *auth.User) {
	if err := s.EnqueueAsyncTask(ctx, task); err != nil {
		s.logger.Warn("enqueue river async task failed", "taskID", task.ID, "error", err)
	}
}

// recoverAsyncRiverJobs re-enqueues up to recoverLimit persisted tasks
// reported by ListRecoverableAsyncTasks, recording each task's River job ID.
// It stops at the first failure and returns it.
func (s *Service) recoverAsyncRiverJobs(ctx context.Context) error {
	// recoverLimit caps how many persisted tasks are re-enqueued per start.
	const recoverLimit = 1000

	items, err := s.store.ListRecoverableAsyncTasks(ctx, recoverLimit)
	if err != nil {
		return fmt.Errorf("list recoverable async tasks: %w", err)
	}
	for _, item := range items {
		// Only the ID is needed: it is the job payload and determines the
		// insert options.
		if err := s.EnqueueAsyncTask(ctx, store.GatewayTask{ID: item.ID}); err != nil {
			return err
		}
	}
	if len(items) >= recoverLimit {
		s.logger.Warn("river async queue recovery reached its batch limit; additional persisted tasks may not have been re-enqueued", "limit", recoverLimit)
	}
	if len(items) > 0 {
		s.logger.Info("river async queue recovered persisted tasks", "count", len(items))
	}
	return nil
}

// asyncTaskInsertOpts returns the River insert options for a gateway task
// job: the gateway_tasks queue, a "gateway-task" tag, and uniqueness by args
// and queue across every non-finalized state listed below, so a task has at
// most one live job. Tasks with an empty ID get priority 3 instead of 2
// (River runs priority 1 first).
func asyncTaskInsertOpts(task store.GatewayTask) *river.InsertOpts {
	priority := 2
	if task.ID == "" {
		priority = 3
	}
	return &river.InsertOpts{
		// NOTE(review): a high ceiling, presumably so repeatedly snoozed or
		// retried tasks are not discarded early — confirm.
		MaxAttempts: 1000,
		Priority:    priority,
		Queue:       asyncTaskQueueName,
		Tags:        []string{"gateway-task"},
		UniqueOpts: river.UniqueOpts{
			ByArgs:  true,
			ByQueue: true,
			ByState: []rivertype.JobState{
				rivertype.JobStateAvailable,
				rivertype.JobStatePending,
				rivertype.JobStateRetryable,
				rivertype.JobStateRunning,
				rivertype.JobStateScheduled,
			},
		},
	}
}

// authUserFromTask reconstructs the auth.User a task was submitted as from
// the identity fields stored on the task. The ID prefers GatewayUserID over
// UserID, Source defaults to "gateway", and Roles is ["user"] unless UserID
// is blank, in which case Roles is nil.
func authUserFromTask(task store.GatewayTask) *auth.User {
	var roles []string
	if strings.TrimSpace(task.UserID) != "" {
		roles = []string{"user"}
	}
	return &auth.User{
		ID:              firstNonEmptyString(task.GatewayUserID, task.UserID),
		Roles:           roles,
		TenantID:        task.TenantID,
		GatewayTenantID: task.GatewayTenantID,
		TenantKey:       task.TenantKey,
		Source:          firstNonEmptyString(task.UserSource, "gateway"),
		GatewayUserID:   task.GatewayUserID,
		UserGroupID:     task.UserGroupID,
		UserGroupKey:    task.UserGroupKey,
		APIKeyID:        task.APIKeyID,
		APIKeyName:      task.APIKeyName,
		APIKeyPrefix:    task.APIKeyPrefix,
	}
}

// asyncWorkerID returns a River client ID of the form
// "<hostname>:<pid>:<unix-nanos>", so that each process start gets a
// distinct ID.
func asyncWorkerID() string {
	// The lookup error is deliberately ignored: a failed lookup yields an
	// empty name, which falls back to "localhost" below.
	host, _ := os.Hostname()
	host = strings.TrimSpace(host)
	if host == "" {
		host = "localhost"
	}
	return fmt.Sprintf("%s:%d:%d", host, os.Getpid(), time.Now().UnixNano())
}

// Compile-time check that asyncTaskWorker implements river.Worker.
var _ river.Worker[asyncTaskArgs] = (*asyncTaskWorker)(nil)