easyai-ai-gateway/apps/api/internal/runner/queue_worker.go

217 lines
6.0 KiB
Go

package runner
import (
"context"
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
"github.com/riverqueue/river"
"github.com/riverqueue/river/riverdriver/riverpgxv5"
"github.com/riverqueue/river/rivermigrate"
"github.com/riverqueue/river/rivertype"
)
// asyncTaskQueueName is the River queue that gateway task jobs are inserted
// into (see asyncTaskInsertOpts) and that startRiverQueue configures workers for.
const asyncTaskQueueName = "gateway_tasks"
// asyncTaskArgs is the River job payload for running one gateway task. Only
// the task ID travels with the job; the worker reloads the task from the
// store. The `river:"unique"` tag marks TaskID as the field River compares
// when insert options request uniqueness by args.
type asyncTaskArgs struct {
	TaskID string `json:"task_id" river:"unique"`
}

// Kind returns the stable job kind string River uses to match these jobs
// with their registered worker.
func (asyncTaskArgs) Kind() string {
	const kind = "gateway_task_run"
	return kind
}
// asyncTaskWorker executes asyncTaskArgs jobs pulled from River by delegating
// to the owning Service. Embedding river.WorkerDefaults supplies River's
// default implementations for the optional Worker methods; Work itself is
// implemented on *asyncTaskWorker.
type asyncTaskWorker struct {
	river.WorkerDefaults[asyncTaskArgs]
	service *Service
}
// Work runs a single gateway task on behalf of River.
//
// Outcomes:
//   - the task is already terminal (succeeded/failed/cancelled): the job completes without running it.
//   - Execute succeeds: the job completes.
//   - Execute returns a *TaskQueuedError: the job is snoozed for the requested delay.
//   - the job context is done (e.g. the client is stopping): the task is
//     requeued using a context detached from cancellation and the job is
//     snoozed with zero delay.
//   - any other Execute error is logged and the job completes; presumably
//     Execute records the failure on the task itself — confirm.
//
// Only store lookup and requeue errors are returned to River, which treats
// them as a failed attempt.
func (w *asyncTaskWorker) Work(ctx context.Context, job *river.Job[asyncTaskArgs]) error {
	svc := w.service

	task, err := svc.store.GetTask(ctx, job.Args.TaskID)
	if err != nil {
		return err
	}
	switch task.Status {
	case "succeeded", "failed", "cancelled":
		// Already finished; a duplicate or recovered job must not rerun it.
		return nil
	}

	result, execErr := svc.Execute(ctx, task, authUserFromTask(task))
	if execErr == nil {
		svc.logger.Debug("river async task completed", "taskID", task.ID, "status", result.Task.Status, "riverJobID", job.ID)
		return nil
	}

	var queued *TaskQueuedError
	if errors.As(execErr, &queued) {
		return river.JobSnooze(queued.Delay)
	}

	if ctx.Err() == nil {
		svc.logger.Warn("river async task completed with failure", "taskID", task.ID, "error", execErr, "riverJobID", job.ID)
		return nil
	}

	// The job context is done; detach from its cancellation so the requeue
	// write can still reach the store.
	requeued, err := svc.requeueInterruptedAsyncTask(context.WithoutCancel(ctx), task)
	if err != nil {
		return err
	}
	svc.logger.Debug("river async task interrupted and requeued", "taskID", task.ID, "status", requeued.Status, "riverJobID", job.ID)
	return river.JobSnooze(0)
}
// StartAsyncQueueWorker starts the River-backed async task queue, whose
// lifetime is bound to ctx. A startup failure is treated as fatal: the error
// is logged and then raised as a panic, because the method has no error
// return for callers to check.
func (s *Service) StartAsyncQueueWorker(ctx context.Context) {
	err := s.startRiverQueue(ctx)
	if err == nil {
		return
	}
	s.logger.Error("start river async queue failed", "error", err)
	panic(err)
}
// startRiverQueue brings up the async task queue. It does the following, in order:
//   - migrates the River schema;
//   - registers asyncTaskWorker;
//   - creates and starts a River client on the store's connection pool;
//   - re-enqueues persisted tasks via recoverAsyncRiverJobs;
//   - arranges for the client to stop when ctx is cancelled.
//
// The client is stored on s.riverClient before it is started because
// recoverAsyncRiverJobs inserts through it. If recovery fails, the
// already-started client is stopped before returning, so it does not keep
// working jobs with no shutdown hook attached.
//
// Every returned error is wrapped with the step that failed.
func (s *Service) startRiverQueue(ctx context.Context) error {
	driver := riverpgxv5.New(s.store.Pool())

	migrator, err := rivermigrate.New(driver, nil)
	if err != nil {
		return fmt.Errorf("create river migrator: %w", err)
	}
	if _, err := migrator.Migrate(ctx, rivermigrate.DirectionUp, nil); err != nil {
		return fmt.Errorf("migrate river schema: %w", err)
	}

	workers := river.NewWorkers()
	if err := river.AddWorkerSafely(workers, &asyncTaskWorker{service: s}); err != nil {
		return fmt.Errorf("register river async task worker: %w", err)
	}

	riverClient, err := river.NewClient(driver, &river.Config{
		ID: asyncWorkerID(),
		// -1 disables River's per-job timeout: job contexts are only
		// cancelled when the client shuts down.
		JobTimeout:                  -1,
		Logger:                      s.logger,
		CompletedJobRetentionPeriod: 24 * time.Hour,
		Queues: map[string]river.QueueConfig{
			asyncTaskQueueName: {MaxWorkers: 32},
		},
		// NOTE(review): River's rescuer treats a job that has been running
		// longer than this as stuck and makes it retryable, independent of
		// JobTimeout. With timeouts disabled, a task that legitimately runs
		// for more than 30s may be picked up again while still executing.
		// Confirm Execute tolerates that, or raise this value.
		RescueStuckJobsAfter: 30 * time.Second,
		TestOnly:             s.cfg.AppEnv == "test",
		Workers:              workers,
	})
	if err != nil {
		return fmt.Errorf("create river client: %w", err)
	}
	s.riverClient = riverClient

	if err := riverClient.Start(ctx); err != nil {
		return fmt.Errorf("start river client: %w", err)
	}

	// stop cancels running jobs and waits up to 10s for the client to shut
	// down. It uses a fresh context because ctx may already be done.
	stop := func() {
		stopCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		if err := riverClient.StopAndCancel(stopCtx); err != nil {
			s.logger.Warn("stop river async queue failed", "error", err)
		}
	}

	if err := s.recoverAsyncRiverJobs(ctx); err != nil {
		// The client is already running; shut it down rather than leak it.
		stop()
		return fmt.Errorf("recover river async jobs: %w", err)
	}

	go func() {
		<-ctx.Done()
		stop()
	}()
	return nil
}
// EnqueueAsyncTask inserts a River job that will run task on the async queue.
// When River returns a job row, its ID is recorded on the task via
// store.SetTaskRiverJobID. The insert options make jobs unique per task ID
// while a job is available, pending, retryable, running or scheduled.
// Presumably, re-enqueueing an in-flight task therefore yields the existing
// job rather than a second one — confirm against River's unique-insert docs.
// It returns an error if startRiverQueue has not created the River client.
func (s *Service) EnqueueAsyncTask(ctx context.Context, task store.GatewayTask) error {
	client := s.riverClient
	if client == nil {
		return errors.New("river async queue is not started")
	}

	inserted, err := client.Insert(ctx, asyncTaskArgs{TaskID: task.ID}, asyncTaskInsertOpts(task))
	switch {
	case err != nil:
		return err
	case inserted.Job == nil:
		return nil
	default:
		return s.store.SetTaskRiverJobID(ctx, task.ID, inserted.Job.ID)
	}
}
// WakeAsyncQueueAfter is a no-op: its body is empty and it ignores ctx and
// delay. Delayed re-runs are expressed in this file through river.JobSnooze
// in asyncTaskWorker.Work.
// NOTE(review): presumably kept so existing callers still compile after the
// queue moved to River — confirm before removing.
func (s *Service) WakeAsyncQueueAfter(ctx context.Context, delay time.Duration) {
}
// RunAsyncTask hands task to the River queue for background execution and
// returns immediately. Enqueue errors are logged rather than returned.
//
// The user argument is not used here. The worker rebuilds the caller
// identity from the task's stored fields (authUserFromTask).
func (s *Service) RunAsyncTask(ctx context.Context, task store.GatewayTask, user *auth.User) {
	err := s.EnqueueAsyncTask(ctx, task)
	if err == nil {
		return
	}
	s.logger.Warn("enqueue river async task failed", "taskID", task.ID, "error", err)
}
// recoverAsyncRiverJobs re-enqueues the persisted tasks reported by
// store.ListRecoverableAsyncTasks, so they get a River job again after a
// restart.
//
// It goes through EnqueueAsyncTask. Recovered tasks therefore get the same
// insert options (uniqueness, priority) and River job ID bookkeeping as
// freshly submitted ones. A task that already has a live job is presumably
// deduplicated by River's unique insert rather than scheduled twice.
//
// At most recoverableAsyncTaskLimit tasks are fetched per call. When the
// store returns a full page, a warning is logged because more tasks may
// still be waiting. The first enqueue failure aborts recovery and is
// returned with the offending task ID.
func (s *Service) recoverAsyncRiverJobs(ctx context.Context) error {
	const recoverableAsyncTaskLimit = 1000

	items, err := s.store.ListRecoverableAsyncTasks(ctx, recoverableAsyncTaskLimit)
	if err != nil {
		return fmt.Errorf("list recoverable async tasks: %w", err)
	}
	for _, item := range items {
		// Only the ID is needed: it is the job payload and drives the insert options.
		if err := s.EnqueueAsyncTask(ctx, store.GatewayTask{ID: item.ID}); err != nil {
			return fmt.Errorf("recover async task %q: %w", item.ID, err)
		}
	}
	if len(items) > 0 {
		s.logger.Info("river async queue recovered persisted tasks", "count", len(items))
	}
	if len(items) >= recoverableAsyncTaskLimit {
		s.logger.Warn("river async queue recovery hit the task limit; some persisted tasks may not be recovered", "limit", recoverableAsyncTaskLimit)
	}
	return nil
}
// asyncTaskInsertOpts builds the River insert options for a gateway task job.
// The job goes on asyncTaskQueueName with up to 1000 attempts and the tag
// "gateway-task".
//
// Jobs are unique by args and queue while an existing job for the same task
// is available, pending, retryable, running or scheduled.
//
// Tasks with an ID get priority 2; a task with an empty ID gets 3. River
// treats lower numbers as higher priority.
// NOTE(review): the reason empty-ID tasks are deprioritized is not visible here.
func asyncTaskInsertOpts(task store.GatewayTask) *river.InsertOpts {
	priority := 3
	if task.ID != "" {
		priority = 2
	}

	uniqueStates := []rivertype.JobState{
		rivertype.JobStateAvailable,
		rivertype.JobStatePending,
		rivertype.JobStateRetryable,
		rivertype.JobStateRunning,
		rivertype.JobStateScheduled,
	}

	opts := &river.InsertOpts{
		MaxAttempts: 1000,
		Priority:    priority,
		Queue:       asyncTaskQueueName,
		Tags:        []string{"gateway-task"},
		UniqueOpts: river.UniqueOpts{
			ByArgs:  true,
			ByQueue: true,
			ByState: uniqueStates,
		},
	}
	return opts
}
// authUserFromTask rebuilds the auth.User a task was submitted under from the
// identity fields persisted on the task, so the worker can execute it without
// the original request.
//
// Field rules:
//   - ID is the first non-empty of GatewayUserID and UserID (via firstNonEmptyString).
//   - Source falls back to "gateway".
//   - The "user" role is granted only when UserID is non-blank; otherwise Roles is nil.
func authUserFromTask(task store.GatewayTask) *auth.User {
	var roles []string
	if strings.TrimSpace(task.UserID) != "" {
		roles = []string{"user"}
	}

	user := &auth.User{
		ID:              firstNonEmptyString(task.GatewayUserID, task.UserID),
		Roles:           roles,
		Source:          firstNonEmptyString(task.UserSource, "gateway"),
		TenantID:        task.TenantID,
		GatewayTenantID: task.GatewayTenantID,
		TenantKey:       task.TenantKey,
		GatewayUserID:   task.GatewayUserID,
		UserGroupID:     task.UserGroupID,
		UserGroupKey:    task.UserGroupKey,
		APIKeyID:        task.APIKeyID,
		APIKeyName:      task.APIKeyName,
		APIKeyPrefix:    task.APIKeyPrefix,
	}
	return user
}
// asyncWorkerID returns a River client ID of the form "host:pid:unixnanos".
// The host is the trimmed OS hostname, or "localhost" when it is unavailable
// or blank. The pid and start time keep IDs from separate processes on the
// same machine from colliding.
func asyncWorkerID() string {
	// A hostname lookup error is deliberately ignored: an empty result falls
	// back to "localhost" below.
	hostname, _ := os.Hostname()
	if hostname = strings.TrimSpace(hostname); hostname == "" {
		hostname = "localhost"
	}
	pid := os.Getpid()
	startedAt := time.Now().UnixNano()
	return fmt.Sprintf("%s:%d:%d", hostname, pid, startedAt)
}
// Compile-time check that *asyncTaskWorker implements river.Worker for asyncTaskArgs jobs.
var _ river.Worker[asyncTaskArgs] = (*asyncTaskWorker)(nil)