217 lines
6.0 KiB
Go
217 lines
6.0 KiB
Go
package runner
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
"github.com/riverqueue/river"
|
|
"github.com/riverqueue/river/riverdriver/riverpgxv5"
|
|
"github.com/riverqueue/river/rivermigrate"
|
|
"github.com/riverqueue/river/rivertype"
|
|
)
|
|
|
|
// asyncTaskQueueName is the River queue that gateway task jobs are inserted
// into and worked from.
const asyncTaskQueueName = "gateway_tasks"

// asyncTaskArgs is the River job payload for running one gateway task. Only the
// task ID is carried; the worker reloads the full task from the store. The
// `river:"unique"` tag limits ByArgs uniqueness checks to this field.
type asyncTaskArgs struct {
	TaskID string `json:"task_id" river:"unique"`
}

// Kind returns the job kind River uses to route gateway task jobs to
// asyncTaskWorker. Jobs already in the queue are matched by this string, so it
// should not change casually.
func (a asyncTaskArgs) Kind() string {
	const kind = "gateway_task_run"
	return kind
}
|
|
|
|
type asyncTaskWorker struct {
|
|
river.WorkerDefaults[asyncTaskArgs]
|
|
|
|
service *Service
|
|
}
|
|
|
|
func (w *asyncTaskWorker) Work(ctx context.Context, job *river.Job[asyncTaskArgs]) error {
|
|
task, err := w.service.store.GetTask(ctx, job.Args.TaskID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if task.Status == "succeeded" || task.Status == "failed" || task.Status == "cancelled" {
|
|
return nil
|
|
}
|
|
result, runErr := w.service.Execute(ctx, task, authUserFromTask(task))
|
|
if runErr == nil {
|
|
w.service.logger.Debug("river async task completed", "taskID", task.ID, "status", result.Task.Status, "riverJobID", job.ID)
|
|
return nil
|
|
}
|
|
var queuedErr *TaskQueuedError
|
|
if errors.As(runErr, &queuedErr) {
|
|
return river.JobSnooze(queuedErr.Delay)
|
|
}
|
|
if ctx.Err() != nil {
|
|
queued, queueErr := w.service.requeueInterruptedAsyncTask(context.WithoutCancel(ctx), task)
|
|
if queueErr != nil {
|
|
return queueErr
|
|
}
|
|
w.service.logger.Debug("river async task interrupted and requeued", "taskID", task.ID, "status", queued.Status, "riverJobID", job.ID)
|
|
return river.JobSnooze(0)
|
|
}
|
|
w.service.logger.Warn("river async task completed with failure", "taskID", task.ID, "error", runErr, "riverJobID", job.ID)
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) StartAsyncQueueWorker(ctx context.Context) {
|
|
if err := s.startRiverQueue(ctx); err != nil {
|
|
s.logger.Error("start river async queue failed", "error", err)
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
func (s *Service) startRiverQueue(ctx context.Context) error {
|
|
driver := riverpgxv5.New(s.store.Pool())
|
|
migrator, err := rivermigrate.New(driver, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, err := migrator.Migrate(ctx, rivermigrate.DirectionUp, nil); err != nil {
|
|
return err
|
|
}
|
|
|
|
workers := river.NewWorkers()
|
|
if err := river.AddWorkerSafely(workers, &asyncTaskWorker{service: s}); err != nil {
|
|
return err
|
|
}
|
|
riverClient, err := river.NewClient(driver, &river.Config{
|
|
ID: asyncWorkerID(),
|
|
JobTimeout: -1,
|
|
Logger: s.logger,
|
|
CompletedJobRetentionPeriod: 24 * time.Hour,
|
|
Queues: map[string]river.QueueConfig{
|
|
asyncTaskQueueName: {MaxWorkers: 32},
|
|
},
|
|
RescueStuckJobsAfter: 30 * time.Second,
|
|
TestOnly: s.cfg.AppEnv == "test",
|
|
Workers: workers,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.riverClient = riverClient
|
|
if err := riverClient.Start(ctx); err != nil {
|
|
return err
|
|
}
|
|
if err := s.recoverAsyncRiverJobs(ctx); err != nil {
|
|
return err
|
|
}
|
|
go func() {
|
|
<-ctx.Done()
|
|
stopCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
if err := riverClient.StopAndCancel(stopCtx); err != nil {
|
|
s.logger.Warn("stop river async queue failed", "error", err)
|
|
}
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) EnqueueAsyncTask(ctx context.Context, task store.GatewayTask) error {
|
|
if s.riverClient == nil {
|
|
return errors.New("river async queue is not started")
|
|
}
|
|
result, err := s.riverClient.Insert(ctx, asyncTaskArgs{TaskID: task.ID}, asyncTaskInsertOpts(task))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if result.Job != nil {
|
|
return s.store.SetTaskRiverJobID(ctx, task.ID, result.Job.ID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) WakeAsyncQueueAfter(ctx context.Context, delay time.Duration) {
|
|
}
|
|
|
|
func (s *Service) RunAsyncTask(ctx context.Context, task store.GatewayTask, user *auth.User) {
|
|
if err := s.EnqueueAsyncTask(ctx, task); err != nil {
|
|
s.logger.Warn("enqueue river async task failed", "taskID", task.ID, "error", err)
|
|
}
|
|
}
|
|
|
|
func (s *Service) recoverAsyncRiverJobs(ctx context.Context) error {
|
|
items, err := s.store.ListRecoverableAsyncTasks(ctx, 1000)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, item := range items {
|
|
task := store.GatewayTask{ID: item.ID}
|
|
result, err := s.riverClient.Insert(ctx, asyncTaskArgs{TaskID: item.ID}, asyncTaskInsertOpts(task))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if result.Job != nil {
|
|
if err := s.store.SetTaskRiverJobID(ctx, item.ID, result.Job.ID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
if len(items) > 0 {
|
|
s.logger.Info("river async queue recovered persisted tasks", "count", len(items))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func asyncTaskInsertOpts(task store.GatewayTask) *river.InsertOpts {
|
|
priority := 2
|
|
if task.ID == "" {
|
|
priority = 3
|
|
}
|
|
return &river.InsertOpts{
|
|
MaxAttempts: 1000,
|
|
Priority: priority,
|
|
Queue: asyncTaskQueueName,
|
|
Tags: []string{"gateway-task"},
|
|
UniqueOpts: river.UniqueOpts{
|
|
ByArgs: true,
|
|
ByQueue: true,
|
|
ByState: []rivertype.JobState{
|
|
rivertype.JobStateAvailable,
|
|
rivertype.JobStatePending,
|
|
rivertype.JobStateRetryable,
|
|
rivertype.JobStateRunning,
|
|
rivertype.JobStateScheduled,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func authUserFromTask(task store.GatewayTask) *auth.User {
|
|
roles := []string{"user"}
|
|
if strings.TrimSpace(task.UserID) == "" {
|
|
roles = nil
|
|
}
|
|
return &auth.User{
|
|
ID: firstNonEmptyString(task.GatewayUserID, task.UserID),
|
|
Roles: roles,
|
|
TenantID: task.TenantID,
|
|
GatewayTenantID: task.GatewayTenantID,
|
|
TenantKey: task.TenantKey,
|
|
Source: firstNonEmptyString(task.UserSource, "gateway"),
|
|
GatewayUserID: task.GatewayUserID,
|
|
UserGroupID: task.UserGroupID,
|
|
UserGroupKey: task.UserGroupKey,
|
|
APIKeyID: task.APIKeyID,
|
|
APIKeyName: task.APIKeyName,
|
|
APIKeyPrefix: task.APIKeyPrefix,
|
|
}
|
|
}
|
|
|
|
// asyncWorkerID returns a River client ID of the form "host:pid:unixnanos".
// The host is the trimmed OS hostname, or "localhost" when it is blank or the
// lookup fails. The pid and start timestamp keep IDs distinct across processes
// on the same machine.
func asyncWorkerID() string {
	name, _ := os.Hostname() // a lookup failure yields "" and falls back below
	if name = strings.TrimSpace(name); name == "" {
		name = "localhost"
	}
	pid := os.Getpid()
	stamp := time.Now().UnixNano()
	return fmt.Sprintf("%s:%d:%d", name, pid, stamp)
}
|
|
|
|
// Compile-time assertion that *asyncTaskWorker implements River's Worker
// interface for asyncTaskArgs jobs.
var (
	_ river.Worker[asyncTaskArgs] = (*asyncTaskWorker)(nil)
)
|