easyai-ai-gateway/apps/api/internal/runner/task_cancel.go
2026-06-09 19:09:20 +08:00

151 lines
4.8 KiB
Go

package runner
import (
"context"
"errors"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
"github.com/riverqueue/river/rivertype"
)
var (
	// ErrTaskAccessDenied is returned by (*Service).CancelTask when the
	// requesting user is not allowed to act on the task.
	ErrTaskAccessDenied = errors.New("task access denied")
)
// TaskCancelResult is the outcome of a cancellation request, serialized to
// callers as JSON.
type TaskCancelResult struct {
	// TaskID identifies the task the request targeted.
	TaskID string `json:"taskId"`
	// Cancelled is true only when this request actually cancelled the task.
	Cancelled bool `json:"cancelled"`
	// Cancellable reports whether the task was in a cancellable state.
	Cancellable bool `json:"cancellable"`
	// Submitted reports whether the task has already been handed to the
	// upstream provider (it carries a remote task ID).
	Submitted bool `json:"submitted"`
	// Message is a human-readable explanation of the outcome.
	Message string `json:"message"`
}
// CancelTask attempts to cancel the task identified by taskID on behalf of user.
//
// Cancellation is only possible while the task is still "queued" locally and
// has not been submitted upstream (no RemoteTaskID). If the task is backed by a
// River job, that job must still be in a waiting state and is deleted from the
// queue before the task row is marked cancelled in the store.
//
// Returns ErrTaskAccessDenied when user may not access the task, and any error
// from the store, the River client, or event emission. When the task simply
// cannot be cancelled, a result with Cancelled=false and an explanatory Message
// is returned with a nil error.
func (s *Service) CancelTask(ctx context.Context, taskID string, user *auth.User) (TaskCancelResult, error) {
	task, err := s.store.GetTask(ctx, taskID)
	if err != nil {
		return TaskCancelResult{}, err
	}
	if !taskAccessibleToUser(task, user) {
		return TaskCancelResult{}, ErrTaskAccessDenied
	}
	// Already finished (succeeded/failed/cancelled): nothing to cancel.
	if taskCancelTerminalStatus(task.Status) {
		return taskCancelUnavailable(task, "任务已结束,无法取消"), nil
	}
	// A remote task ID means the request already went upstream; it cannot be
	// withdrawn from here.
	if strings.TrimSpace(task.RemoteTaskID) != "" {
		return taskCancelUnavailable(task, "任务已提交上游,当前不可取消,请继续查询结果"), nil
	}
	// Only tasks still waiting locally can be cancelled.
	if strings.TrimSpace(task.Status) != "queued" {
		return taskCancelUnavailable(task, "任务已开始执行,当前阶段不可取消,请继续查询结果"), nil
	}
	if task.RiverJobID > 0 {
		if s.riverClient == nil {
			return taskCancelUnavailable(task, "任务取消队列未就绪,请继续查询结果"), nil
		}
		job, err := s.riverClient.JobGet(ctx, task.RiverJobID)
		// Checked before the generic error branch: a missing job is reported as
		// "not cancellable" rather than as a failure.
		if errors.Is(err, rivertype.ErrNotFound) {
			return taskCancelUnavailable(task, "任务已不在本地排队队列,可能已提交上游,当前不可取消,请继续查询结果"), nil
		}
		if err != nil {
			return TaskCancelResult{}, err
		}
		if job == nil || !riverJobStateCancellable(job.State) {
			return taskCancelUnavailable(task, "任务已不在可取消队列状态,请继续查询结果"), nil
		}
		// The job may have been picked up (or removed) between JobGet and
		// JobDelete; those races are reported as "not cancellable".
		if _, err := s.riverClient.JobDelete(ctx, task.RiverJobID); err != nil {
			if errors.Is(err, rivertype.ErrJobRunning) || errors.Is(err, rivertype.ErrNotFound) {
				return taskCancelUnavailable(task, "任务已被工作进程领取,当前不可取消,请继续查询结果"), nil
			}
			return TaskCancelResult{}, err
		}
	}
	// NOTE(review): if a River job was deleted above and the store then reports
	// !cancelled, the job is already gone while the task was not cancelled here —
	// confirm this cannot strand a task.
	cancelledTask, cancelled, err := s.store.CancelQueuedTask(ctx, task.ID, "任务已取消")
	if err != nil {
		return TaskCancelResult{}, err
	}
	if !cancelled {
		// Prefer the freshest task state for the response; fall back to the
		// originally loaded task if re-reading fails (the read error is dropped).
		latest, latestErr := s.store.GetTask(ctx, task.ID)
		if latestErr == nil {
			return taskCancelUnavailable(latest, "任务状态已变化,当前不可取消,请继续查询结果"), nil
		}
		return taskCancelUnavailable(task, "任务状态已变化,当前不可取消,请继续查询结果"), nil
	}
	// NOTE(review): the task is already cancelled in the store at this point, so
	// an emit failure surfaces as an error even though cancellation succeeded.
	// The trailing bool presumably marks simulation runs — confirm against emit.
	if err := s.emit(ctx, cancelledTask.ID, "task.cancelled", "cancelled", "cancelled", 1, "任务已取消", map[string]any{
		"taskId": cancelledTask.ID,
		"reason": "manual_cancel",
	}, cancelledTask.RunMode == "simulation"); err != nil {
		return TaskCancelResult{}, err
	}
	return TaskCancelResult{
		TaskID:      cancelledTask.ID,
		Cancelled:   true,
		Cancellable: true,
		Submitted:   false,
		Message:     "任务已取消",
	}, nil
}
// taskCancelUnavailable builds the "not cancelled" response for task with the
// given message. Cancelled and Cancellable are always false; Submitted reflects
// whether the task carries a non-blank remote (upstream) task ID.
func taskCancelUnavailable(task store.GatewayTask, message string) TaskCancelResult {
	hasRemote := strings.TrimSpace(task.RemoteTaskID) != ""

	var res TaskCancelResult // zero value: Cancelled and Cancellable are false
	res.TaskID = task.ID
	res.Submitted = hasRemote
	res.Message = message
	return res
}
// taskCancelTerminalStatus reports whether status (ignoring surrounding
// whitespace, compared case-sensitively) is a final state — "succeeded",
// "failed" or "cancelled" — after which cancellation is meaningless.
func taskCancelTerminalStatus(status string) bool {
	s := strings.TrimSpace(status)
	return s == "succeeded" || s == "failed" || s == "cancelled"
}
// riverJobStateCancellable reports whether a River job in the given state may be
// removed from the queue as part of a task cancellation. Only the available,
// scheduled, retryable and pending states qualify; every other state
// (e.g. running or already finalized) does not.
func riverJobStateCancellable(state rivertype.JobState) bool {
	return state == rivertype.JobStateAvailable ||
		state == rivertype.JobStateScheduled ||
		state == rivertype.JobStateRetryable ||
		state == rivertype.JobStatePending
}
// taskAccessibleToUser reports whether user may act on task.
//
// A nil user is always denied. When the caller authenticated with an API key,
// the task must belong to that same key. Identity is then checked against the
// gateway user ID (if one can be derived), otherwise against the plain user ID.
// If neither identity is available, API-key callers are allowed (the key match
// is sufficient) and all other callers are denied.
func taskAccessibleToUser(task store.GatewayTask, user *auth.User) bool {
	if user == nil {
		return false
	}

	keyID := strings.TrimSpace(user.APIKeyID)
	viaAPIKey := keyID != ""
	if viaAPIKey && strings.TrimSpace(task.APIKeyID) != keyID {
		return false
	}

	if gwID := gatewayUserIDForAuth(user); gwID != "" {
		return strings.TrimSpace(task.GatewayUserID) == gwID
	}
	if uid := strings.TrimSpace(user.ID); uid != "" {
		return strings.TrimSpace(task.UserID) == uid
	}

	// No user identity at all: only a matching API key grants access.
	return viaAPIKey
}
// gatewayUserIDForAuth derives the gateway user ID for an authenticated user.
//
// An explicit, non-blank GatewayUserID wins. Otherwise, users whose Source is
// empty or "gateway" are treated as native gateway users and their (trimmed) ID
// is returned. Any other source, or a nil user, yields "".
//
// The emptiness check is made on the trimmed GatewayUserID, matching how the
// other identity fields in this file are compared. Without the trim, a
// whitespace-only value would count as "set" and skip the fallback to user.ID.
func gatewayUserIDForAuth(user *auth.User) string {
	if user == nil {
		return ""
	}
	if gatewayUserID := strings.TrimSpace(user.GatewayUserID); gatewayUserID != "" {
		return gatewayUserID
	}
	if user.Source == "" || user.Source == "gateway" {
		return strings.TrimSpace(user.ID)
	}
	return ""
}