151 lines
4.8 KiB
Go
151 lines
4.8 KiB
Go
package runner
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"strings"
|
|
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
|
"github.com/riverqueue/river/rivertype"
|
|
)
|
|
|
|
// ErrTaskAccessDenied is the sentinel error returned when the requesting user
// is not permitted to access the requested task (as decided by
// taskAccessibleToUser). Compare against it with errors.Is.
var (
	ErrTaskAccessDenied = errors.New("task access denied")
)
|
|
|
|
// TaskCancelResult describes the outcome of a CancelTask request. The json
// tags define its wire form (camelCase keys).
type TaskCancelResult struct {
	// TaskID is the ID of the task the request targeted.
	TaskID string `json:"taskId"`
	// Cancelled reports whether this request actually cancelled the task.
	Cancelled bool `json:"cancelled"`
	// Cancellable reports whether the task was in a state that allowed
	// cancellation at the time of the request.
	Cancellable bool `json:"cancellable"`
	// Submitted reports whether the task had already been handed to an
	// upstream provider (it carries a remote task ID).
	Submitted bool `json:"submitted"`
	// Message is a human-readable explanation of the outcome.
	Message string `json:"message"`
}
|
|
|
|
func (s *Service) CancelTask(ctx context.Context, taskID string, user *auth.User) (TaskCancelResult, error) {
|
|
task, err := s.store.GetTask(ctx, taskID)
|
|
if err != nil {
|
|
return TaskCancelResult{}, err
|
|
}
|
|
if !taskAccessibleToUser(task, user) {
|
|
return TaskCancelResult{}, ErrTaskAccessDenied
|
|
}
|
|
if taskCancelTerminalStatus(task.Status) {
|
|
return taskCancelUnavailable(task, "任务已结束,无法取消"), nil
|
|
}
|
|
if strings.TrimSpace(task.RemoteTaskID) != "" {
|
|
return taskCancelUnavailable(task, "任务已提交上游,当前不可取消,请继续查询结果"), nil
|
|
}
|
|
if strings.TrimSpace(task.Status) != "queued" {
|
|
return taskCancelUnavailable(task, "任务已开始执行,当前阶段不可取消,请继续查询结果"), nil
|
|
}
|
|
if task.RiverJobID > 0 {
|
|
if s.riverClient == nil {
|
|
return taskCancelUnavailable(task, "任务取消队列未就绪,请继续查询结果"), nil
|
|
}
|
|
job, err := s.riverClient.JobGet(ctx, task.RiverJobID)
|
|
if errors.Is(err, rivertype.ErrNotFound) {
|
|
return taskCancelUnavailable(task, "任务已不在本地排队队列,可能已提交上游,当前不可取消,请继续查询结果"), nil
|
|
}
|
|
if err != nil {
|
|
return TaskCancelResult{}, err
|
|
}
|
|
if job == nil || !riverJobStateCancellable(job.State) {
|
|
return taskCancelUnavailable(task, "任务已不在可取消队列状态,请继续查询结果"), nil
|
|
}
|
|
if _, err := s.riverClient.JobDelete(ctx, task.RiverJobID); err != nil {
|
|
if errors.Is(err, rivertype.ErrJobRunning) || errors.Is(err, rivertype.ErrNotFound) {
|
|
return taskCancelUnavailable(task, "任务已被工作进程领取,当前不可取消,请继续查询结果"), nil
|
|
}
|
|
return TaskCancelResult{}, err
|
|
}
|
|
}
|
|
|
|
cancelledTask, cancelled, err := s.store.CancelQueuedTask(ctx, task.ID, "任务已取消")
|
|
if err != nil {
|
|
return TaskCancelResult{}, err
|
|
}
|
|
if !cancelled {
|
|
latest, latestErr := s.store.GetTask(ctx, task.ID)
|
|
if latestErr == nil {
|
|
return taskCancelUnavailable(latest, "任务状态已变化,当前不可取消,请继续查询结果"), nil
|
|
}
|
|
return taskCancelUnavailable(task, "任务状态已变化,当前不可取消,请继续查询结果"), nil
|
|
}
|
|
if err := s.emit(ctx, cancelledTask.ID, "task.cancelled", "cancelled", "cancelled", 1, "任务已取消", map[string]any{
|
|
"taskId": cancelledTask.ID,
|
|
"reason": "manual_cancel",
|
|
}, cancelledTask.RunMode == "simulation"); err != nil {
|
|
return TaskCancelResult{}, err
|
|
}
|
|
return TaskCancelResult{
|
|
TaskID: cancelledTask.ID,
|
|
Cancelled: true,
|
|
Cancellable: true,
|
|
Submitted: false,
|
|
Message: "任务已取消",
|
|
}, nil
|
|
}
|
|
|
|
func taskCancelUnavailable(task store.GatewayTask, message string) TaskCancelResult {
|
|
return TaskCancelResult{
|
|
TaskID: task.ID,
|
|
Cancelled: false,
|
|
Cancellable: false,
|
|
Submitted: strings.TrimSpace(task.RemoteTaskID) != "",
|
|
Message: message,
|
|
}
|
|
}
|
|
|
|
// taskCancelTerminalStatus reports whether status, ignoring surrounding
// whitespace, is one of the terminal statuses "succeeded", "failed" or
// "cancelled". Matching is case-sensitive.
func taskCancelTerminalStatus(status string) bool {
	normalized := strings.TrimSpace(status)
	return normalized == "succeeded" ||
		normalized == "failed" ||
		normalized == "cancelled"
}
|
|
|
|
func riverJobStateCancellable(state rivertype.JobState) bool {
|
|
switch state {
|
|
case rivertype.JobStateAvailable, rivertype.JobStateScheduled, rivertype.JobStateRetryable, rivertype.JobStatePending:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func taskAccessibleToUser(task store.GatewayTask, user *auth.User) bool {
|
|
if user == nil {
|
|
return false
|
|
}
|
|
if apiKeyID := strings.TrimSpace(user.APIKeyID); apiKeyID != "" {
|
|
if strings.TrimSpace(task.APIKeyID) != apiKeyID {
|
|
return false
|
|
}
|
|
if gatewayUserID := gatewayUserIDForAuth(user); gatewayUserID != "" {
|
|
return strings.TrimSpace(task.GatewayUserID) == gatewayUserID
|
|
}
|
|
if userID := strings.TrimSpace(user.ID); userID != "" {
|
|
return strings.TrimSpace(task.UserID) == userID
|
|
}
|
|
return true
|
|
}
|
|
if gatewayUserID := gatewayUserIDForAuth(user); gatewayUserID != "" {
|
|
return strings.TrimSpace(task.GatewayUserID) == gatewayUserID
|
|
}
|
|
userID := strings.TrimSpace(user.ID)
|
|
return userID != "" && strings.TrimSpace(task.UserID) == userID
|
|
}
|
|
|
|
func gatewayUserIDForAuth(user *auth.User) string {
|
|
if user == nil {
|
|
return ""
|
|
}
|
|
if user.GatewayUserID != "" {
|
|
return strings.TrimSpace(user.GatewayUserID)
|
|
}
|
|
if user.Source == "" || user.Source == "gateway" {
|
|
return strings.TrimSpace(user.ID)
|
|
}
|
|
return ""
|
|
}
|