package runner import ( "context" "errors" "strings" "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" "github.com/riverqueue/river/rivertype" ) var ErrTaskAccessDenied = errors.New("task access denied") type TaskCancelResult struct { TaskID string `json:"taskId"` Cancelled bool `json:"cancelled"` Cancellable bool `json:"cancellable"` Submitted bool `json:"submitted"` Message string `json:"message"` } func (s *Service) CancelTask(ctx context.Context, taskID string, user *auth.User) (TaskCancelResult, error) { task, err := s.store.GetTask(ctx, taskID) if err != nil { return TaskCancelResult{}, err } if !taskAccessibleToUser(task, user) { return TaskCancelResult{}, ErrTaskAccessDenied } if taskCancelTerminalStatus(task.Status) { return taskCancelUnavailable(task, "任务已结束,无法取消"), nil } if strings.TrimSpace(task.RemoteTaskID) != "" { return taskCancelUnavailable(task, "任务已提交上游,当前不可取消,请继续查询结果"), nil } if strings.TrimSpace(task.Status) != "queued" { return taskCancelUnavailable(task, "任务已开始执行,当前阶段不可取消,请继续查询结果"), nil } if task.RiverJobID > 0 { if s.riverClient == nil { return taskCancelUnavailable(task, "任务取消队列未就绪,请继续查询结果"), nil } job, err := s.riverClient.JobGet(ctx, task.RiverJobID) if errors.Is(err, rivertype.ErrNotFound) { return taskCancelUnavailable(task, "任务已不在本地排队队列,可能已提交上游,当前不可取消,请继续查询结果"), nil } if err != nil { return TaskCancelResult{}, err } if job == nil || !riverJobStateCancellable(job.State) { return taskCancelUnavailable(task, "任务已不在可取消队列状态,请继续查询结果"), nil } if _, err := s.riverClient.JobDelete(ctx, task.RiverJobID); err != nil { if errors.Is(err, rivertype.ErrJobRunning) || errors.Is(err, rivertype.ErrNotFound) { return taskCancelUnavailable(task, "任务已被工作进程领取,当前不可取消,请继续查询结果"), nil } return TaskCancelResult{}, err } } cancelledTask, cancelled, err := s.store.CancelQueuedTask(ctx, task.ID, "任务已取消") if err != nil { return TaskCancelResult{}, err } if !cancelled { latest, latestErr := s.store.GetTask(ctx, task.ID) if latestErr == nil { return taskCancelUnavailable(latest, "任务状态已变化,当前不可取消,请继续查询结果"), nil } return taskCancelUnavailable(task, "任务状态已变化,当前不可取消,请继续查询结果"), nil } if err := s.emit(ctx, cancelledTask.ID, "task.cancelled", "cancelled", "cancelled", 1, "任务已取消", map[string]any{ "taskId": cancelledTask.ID, "reason": "manual_cancel", }, cancelledTask.RunMode == "simulation"); err != nil { return TaskCancelResult{}, err } return TaskCancelResult{ TaskID: cancelledTask.ID, Cancelled: true, Cancellable: true, Submitted: false, Message: "任务已取消", }, nil } func taskCancelUnavailable(task store.GatewayTask, message string) TaskCancelResult { return TaskCancelResult{ TaskID: task.ID, Cancelled: false, Cancellable: false, Submitted: strings.TrimSpace(task.RemoteTaskID) != "", Message: message, } } func taskCancelTerminalStatus(status string) bool { switch strings.TrimSpace(status) { case "succeeded", "failed", "cancelled": return true default: return false } } func riverJobStateCancellable(state rivertype.JobState) bool { switch state { case rivertype.JobStateAvailable, rivertype.JobStateScheduled, rivertype.JobStateRetryable, rivertype.JobStatePending: return true default: return false } } func taskAccessibleToUser(task store.GatewayTask, user *auth.User) bool { if user == nil { return false } if apiKeyID := strings.TrimSpace(user.APIKeyID); apiKeyID != "" { if strings.TrimSpace(task.APIKeyID) != apiKeyID { return false } if gatewayUserID := gatewayUserIDForAuth(user); gatewayUserID != "" { return strings.TrimSpace(task.GatewayUserID) == gatewayUserID } if userID := strings.TrimSpace(user.ID); userID != "" { return strings.TrimSpace(task.UserID) == userID } return true } if gatewayUserID := gatewayUserIDForAuth(user); gatewayUserID != "" { return strings.TrimSpace(task.GatewayUserID) == gatewayUserID } userID := strings.TrimSpace(user.ID) return userID != "" && strings.TrimSpace(task.UserID) == userID } func gatewayUserIDForAuth(user *auth.User) string { if user == nil { return "" } if user.GatewayUserID != "" { return strings.TrimSpace(user.GatewayUserID) } if user.Source == "" || user.Source == "gateway" { return strings.TrimSpace(user.ID) } return "" }