feat(api): 支持取消本地排队任务

This commit is contained in:
chengcheng 2026-06-09 19:09:20 +08:00
parent a8fa8dd212
commit e8df26da9b
6 changed files with 319 additions and 0 deletions

View File

@ -1452,6 +1452,7 @@ func matchedRateLimitRule(policy map[string]any, metric string) map[string]any {
// @Failure 500 {object} ErrorEnvelope
// @Router /api/workspace/tasks [get]
// @Router /api/v1/tasks [get]
// @Router /tasks [get]
func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) {
user, ok := auth.UserFromContext(r.Context())
if !ok {
@ -1560,6 +1561,7 @@ func boolValue(body map[string]any, key string) bool {
// @Failure 500 {object} ErrorEnvelope
// @Router /api/workspace/tasks/{taskID} [get]
// @Router /api/v1/tasks/{taskID} [get]
// @Router /tasks/{taskID} [get]
func (s *Server) getTask(w http.ResponseWriter, r *http.Request) {
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
if err == nil {
@ -1574,6 +1576,45 @@ func (s *Server) getTask(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusInternalServerError, "get task failed")
}
// cancelTask godoc
// @Summary 取消异步任务
// @Description 仅取消仍在网关本地排队且尚未提交上游的任务;任务已运行、已提交上游或已结束时返回不可取消,客户端应继续查询结果。
// @Tags tasks
// @Produce json
// @Security BearerAuth
// @Param taskID path string true "任务 ID"
// @Success 200 {object} TaskCancelResponse
// @Failure 401 {object} ErrorEnvelope
// @Failure 403 {object} ErrorEnvelope
// @Failure 404 {object} ErrorEnvelope
// @Failure 500 {object} ErrorEnvelope
// @Router /api/workspace/tasks/{taskID}/cancel [post]
// @Router /api/v1/tasks/{taskID}/cancel [post]
// @Router /v1/tasks/{taskID}/cancel [post]
// @Router /tasks/{taskID}/cancel [post]
func (s *Server) cancelTask(w http.ResponseWriter, r *http.Request) {
user, ok := auth.UserFromContext(r.Context())
if !ok {
writeError(w, http.StatusUnauthorized, "unauthorized")
return
}
result, err := s.runner.CancelTask(r.Context(), r.PathValue("taskID"), user)
if err == nil {
writeJSON(w, http.StatusOK, result)
return
}
if store.IsNotFound(err) {
writeError(w, http.StatusNotFound, "task not found")
return
}
if errors.Is(err, runner.ErrTaskAccessDenied) {
writeError(w, http.StatusForbidden, "task access denied")
return
}
s.logger.Error("cancel task failed", "error", err)
writeError(w, http.StatusInternalServerError, "cancel task failed")
}
// taskParamPreprocessing godoc
// @Summary 获取任务参数预处理日志
// @Description 返回指定任务在执行前的参数改写、校验或模板处理日志。
@ -1587,6 +1628,7 @@ func (s *Server) getTask(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} ErrorEnvelope
// @Router /api/workspace/tasks/{taskID}/param-preprocessing [get]
// @Router /api/v1/tasks/{taskID}/param-preprocessing [get]
// @Router /tasks/{taskID}/param-preprocessing [get]
func (s *Server) taskParamPreprocessing(w http.ResponseWriter, r *http.Request) {
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
if err != nil {
@ -1620,6 +1662,7 @@ func (s *Server) taskParamPreprocessing(w http.ResponseWriter, r *http.Request)
// @Failure 500 {object} ErrorEnvelope
// @Router /api/workspace/tasks/{taskID}/events [get]
// @Router /api/v1/tasks/{taskID}/events [get]
// @Router /tasks/{taskID}/events [get]
func (s *Server) taskEvents(w http.ResponseWriter, r *http.Request) {
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
if err != nil {

View File

@ -146,6 +146,14 @@ type TaskAcceptedResponse struct {
Next TaskNextLinks `json:"next"`
}
type TaskCancelResponse struct {
TaskID string `json:"taskId" example:"9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25"`
Cancelled bool `json:"cancelled" example:"false"`
Cancellable bool `json:"cancellable" example:"false"`
Submitted bool `json:"submitted" example:"true"`
Message string `json:"message" example:"任务已提交上游,当前不可取消,请继续查询结果"`
}
type TaskNextLinks struct {
Events string `json:"events" example:"/api/v1/tasks/9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25/events"`
Detail string `json:"detail" example:"/api/v1/tasks/9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25"`

View File

@ -93,6 +93,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
mux.Handle("GET /api/workspace/wallet/transactions", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listWalletTransactions)))
mux.Handle("GET /api/workspace/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks)))
mux.Handle("GET /api/workspace/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask)))
mux.Handle("POST /api/workspace/tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask)))
mux.Handle("GET /api/workspace/tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing)))
mux.Handle("GET /api/workspace/tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents)))
mux.Handle("GET /api/admin/pricing/rules", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listPricingRules)))
@ -144,8 +145,14 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
mux.Handle("POST /api/v1/files/upload", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.uploadFile)))
mux.Handle("GET /api/v1/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks)))
mux.Handle("GET /api/v1/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask)))
mux.Handle("POST /api/v1/tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask)))
mux.Handle("GET /api/v1/tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing)))
mux.Handle("GET /api/v1/tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents)))
mux.Handle("GET /tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks)))
mux.Handle("GET /tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask)))
mux.Handle("POST /tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask)))
mux.Handle("GET /tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing)))
mux.Handle("GET /tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents)))
mux.Handle("POST /chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", true)))
mux.Handle("POST /v1/chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", true)))
mux.Handle("POST /responses", server.auth.Require(auth.PermissionBasic, server.createTask("responses", true)))
@ -165,6 +172,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
mux.Handle("POST /speech/generations", server.auth.Require(auth.PermissionBasic, server.createTask("speech.generations", true)))
mux.Handle("POST /v1/speech/generations", server.auth.Require(auth.PermissionBasic, server.createTask("speech.generations", true)))
mux.Handle("POST /v1/files/upload", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.uploadFile)))
mux.Handle("POST /v1/tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask)))
return server.recover(server.cors(mux))
}

View File

@ -0,0 +1,150 @@
package runner
import (
"context"
"errors"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
"github.com/riverqueue/river/rivertype"
)
var ErrTaskAccessDenied = errors.New("task access denied")
type TaskCancelResult struct {
TaskID string `json:"taskId"`
Cancelled bool `json:"cancelled"`
Cancellable bool `json:"cancellable"`
Submitted bool `json:"submitted"`
Message string `json:"message"`
}
func (s *Service) CancelTask(ctx context.Context, taskID string, user *auth.User) (TaskCancelResult, error) {
task, err := s.store.GetTask(ctx, taskID)
if err != nil {
return TaskCancelResult{}, err
}
if !taskAccessibleToUser(task, user) {
return TaskCancelResult{}, ErrTaskAccessDenied
}
if taskCancelTerminalStatus(task.Status) {
return taskCancelUnavailable(task, "任务已结束,无法取消"), nil
}
if strings.TrimSpace(task.RemoteTaskID) != "" {
return taskCancelUnavailable(task, "任务已提交上游,当前不可取消,请继续查询结果"), nil
}
if strings.TrimSpace(task.Status) != "queued" {
return taskCancelUnavailable(task, "任务已开始执行,当前阶段不可取消,请继续查询结果"), nil
}
if task.RiverJobID > 0 {
if s.riverClient == nil {
return taskCancelUnavailable(task, "任务取消队列未就绪,请继续查询结果"), nil
}
job, err := s.riverClient.JobGet(ctx, task.RiverJobID)
if errors.Is(err, rivertype.ErrNotFound) {
return taskCancelUnavailable(task, "任务已不在本地排队队列,可能已提交上游,当前不可取消,请继续查询结果"), nil
}
if err != nil {
return TaskCancelResult{}, err
}
if job == nil || !riverJobStateCancellable(job.State) {
return taskCancelUnavailable(task, "任务已不在可取消队列状态,请继续查询结果"), nil
}
if _, err := s.riverClient.JobDelete(ctx, task.RiverJobID); err != nil {
if errors.Is(err, rivertype.ErrJobRunning) || errors.Is(err, rivertype.ErrNotFound) {
return taskCancelUnavailable(task, "任务已被工作进程领取,当前不可取消,请继续查询结果"), nil
}
return TaskCancelResult{}, err
}
}
cancelledTask, cancelled, err := s.store.CancelQueuedTask(ctx, task.ID, "任务已取消")
if err != nil {
return TaskCancelResult{}, err
}
if !cancelled {
latest, latestErr := s.store.GetTask(ctx, task.ID)
if latestErr == nil {
return taskCancelUnavailable(latest, "任务状态已变化,当前不可取消,请继续查询结果"), nil
}
return taskCancelUnavailable(task, "任务状态已变化,当前不可取消,请继续查询结果"), nil
}
if err := s.emit(ctx, cancelledTask.ID, "task.cancelled", "cancelled", "cancelled", 1, "任务已取消", map[string]any{
"taskId": cancelledTask.ID,
"reason": "manual_cancel",
}, cancelledTask.RunMode == "simulation"); err != nil {
return TaskCancelResult{}, err
}
return TaskCancelResult{
TaskID: cancelledTask.ID,
Cancelled: true,
Cancellable: true,
Submitted: false,
Message: "任务已取消",
}, nil
}
func taskCancelUnavailable(task store.GatewayTask, message string) TaskCancelResult {
return TaskCancelResult{
TaskID: task.ID,
Cancelled: false,
Cancellable: false,
Submitted: strings.TrimSpace(task.RemoteTaskID) != "",
Message: message,
}
}
func taskCancelTerminalStatus(status string) bool {
switch strings.TrimSpace(status) {
case "succeeded", "failed", "cancelled":
return true
default:
return false
}
}
func riverJobStateCancellable(state rivertype.JobState) bool {
switch state {
case rivertype.JobStateAvailable, rivertype.JobStateScheduled, rivertype.JobStateRetryable, rivertype.JobStatePending:
return true
default:
return false
}
}
func taskAccessibleToUser(task store.GatewayTask, user *auth.User) bool {
if user == nil {
return false
}
if apiKeyID := strings.TrimSpace(user.APIKeyID); apiKeyID != "" {
if strings.TrimSpace(task.APIKeyID) != apiKeyID {
return false
}
if gatewayUserID := gatewayUserIDForAuth(user); gatewayUserID != "" {
return strings.TrimSpace(task.GatewayUserID) == gatewayUserID
}
if userID := strings.TrimSpace(user.ID); userID != "" {
return strings.TrimSpace(task.UserID) == userID
}
return true
}
if gatewayUserID := gatewayUserIDForAuth(user); gatewayUserID != "" {
return strings.TrimSpace(task.GatewayUserID) == gatewayUserID
}
userID := strings.TrimSpace(user.ID)
return userID != "" && strings.TrimSpace(task.UserID) == userID
}
func gatewayUserIDForAuth(user *auth.User) string {
if user == nil {
return ""
}
if user.GatewayUserID != "" {
return strings.TrimSpace(user.GatewayUserID)
}
if user.Source == "" || user.Source == "gateway" {
return strings.TrimSpace(user.ID)
}
return ""
}

View File

@ -0,0 +1,78 @@
package runner
import (
"testing"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
"github.com/riverqueue/river/rivertype"
)
func TestRiverJobStateCancellableOnlyAllowsLocalQueueStates(t *testing.T) {
cancellableStates := []rivertype.JobState{
rivertype.JobStateAvailable,
rivertype.JobStateScheduled,
rivertype.JobStateRetryable,
rivertype.JobStatePending,
}
for _, state := range cancellableStates {
if !riverJobStateCancellable(state) {
t.Fatalf("expected %s to be cancellable", state)
}
}
nonCancellableStates := []rivertype.JobState{
rivertype.JobStateRunning,
rivertype.JobStateCancelled,
rivertype.JobStateCompleted,
rivertype.JobStateDiscarded,
}
for _, state := range nonCancellableStates {
if riverJobStateCancellable(state) {
t.Fatalf("expected %s to be non-cancellable", state)
}
}
}
func TestTaskAccessibleToUserHonorsAPIKeyAndGatewayUser(t *testing.T) {
task := store.GatewayTask{
UserID: "server-user",
GatewayUserID: "gateway-user",
APIKeyID: "api-key",
}
if !taskAccessibleToUser(task, &auth.User{GatewayUserID: "gateway-user"}) {
t.Fatal("gateway user should access own task")
}
if !taskAccessibleToUser(task, &auth.User{APIKeyID: "api-key", GatewayUserID: "gateway-user"}) {
t.Fatal("api key owner should access own task")
}
if taskAccessibleToUser(task, &auth.User{APIKeyID: "api-key", GatewayUserID: "other-user"}) {
t.Fatal("api key from another gateway user should not access task")
}
if taskAccessibleToUser(task, &auth.User{GatewayUserID: "other-user"}) {
t.Fatal("another gateway user should not access task")
}
serverMainTask := store.GatewayTask{
UserID: "server-user",
APIKeyID: "server-api-key",
}
if !taskAccessibleToUser(serverMainTask, &auth.User{ID: "server-user", Source: "server-main", APIKeyID: "server-api-key"}) {
t.Fatal("server-main api key should fall back to user id when gateway user id is absent")
}
if taskAccessibleToUser(serverMainTask, &auth.User{ID: "other-user", Source: "server-main", APIKeyID: "server-api-key"}) {
t.Fatal("server-main api key from another user should not access task")
}
}
func TestTaskCancelUnavailableReportsSubmittedFromRemoteTaskID(t *testing.T) {
result := taskCancelUnavailable(store.GatewayTask{ID: "task-1", RemoteTaskID: "remote-1"}, "不可取消")
if result.Cancelled || result.Cancellable {
t.Fatalf("unavailable result should not mark cancellation as done: %+v", result)
}
if !result.Submitted {
t.Fatalf("remote task id should report submitted: %+v", result)
}
}

View File

@ -256,6 +256,38 @@ WHERE id = $1::uuid`,
})
}
func (s *Store) CancelQueuedTask(ctx context.Context, taskID string, message string) (GatewayTask, bool, error) {
message = strings.TrimSpace(message)
if message == "" {
message = "任务已取消"
}
tag, err := s.pool.Exec(ctx, `
UPDATE gateway_tasks
SET status = 'cancelled',
error = NULLIF($2::text, ''),
error_code = 'task_cancelled',
error_message = NULLIF($2::text, ''),
locked_by = NULL,
locked_at = NULL,
heartbeat_at = NULL,
finished_at = now(),
updated_at = now()
WHERE id = $1::uuid
AND status = 'queued'
AND COALESCE(remote_task_id, '') = ''`, taskID, message)
if err != nil {
return GatewayTask{}, false, err
}
if tag.RowsAffected() == 0 {
return GatewayTask{}, false, nil
}
task, err := s.GetTask(ctx, taskID)
if err != nil {
return GatewayTask{}, true, err
}
return task, true, nil
}
func (s *Store) ListRecoverableAsyncTasks(ctx context.Context, limit int) ([]AsyncTaskQueueItem, error) {
if limit <= 0 {
limit = 500