From e8df26da9baafa6d28e1bcdd60d3b6db88c6dc57 Mon Sep 17 00:00:00 2001 From: chengcheng Date: Tue, 9 Jun 2026 19:09:20 +0800 Subject: [PATCH] =?UTF-8?q?feat(api):=20=E6=94=AF=E6=8C=81=E5=8F=96?= =?UTF-8?q?=E6=B6=88=E6=9C=AC=E5=9C=B0=E6=8E=92=E9=98=9F=E4=BB=BB=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/api/internal/httpapi/handlers.go | 43 ++++++ apps/api/internal/httpapi/openapi_models.go | 8 + apps/api/internal/httpapi/server.go | 8 + apps/api/internal/runner/task_cancel.go | 150 +++++++++++++++++++ apps/api/internal/runner/task_cancel_test.go | 78 ++++++++++ apps/api/internal/store/tasks_runtime.go | 32 ++++ 6 files changed, 319 insertions(+) create mode 100644 apps/api/internal/runner/task_cancel.go create mode 100644 apps/api/internal/runner/task_cancel_test.go diff --git a/apps/api/internal/httpapi/handlers.go b/apps/api/internal/httpapi/handlers.go index 7a844d0..c59d87c 100644 --- a/apps/api/internal/httpapi/handlers.go +++ b/apps/api/internal/httpapi/handlers.go @@ -1452,6 +1452,7 @@ func matchedRateLimitRule(policy map[string]any, metric string) map[string]any { // @Failure 500 {object} ErrorEnvelope // @Router /api/workspace/tasks [get] // @Router /api/v1/tasks [get] +// @Router /tasks [get] func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) { user, ok := auth.UserFromContext(r.Context()) if !ok { @@ -1560,6 +1561,7 @@ func boolValue(body map[string]any, key string) bool { // @Failure 500 {object} ErrorEnvelope // @Router /api/workspace/tasks/{taskID} [get] // @Router /api/v1/tasks/{taskID} [get] +// @Router /tasks/{taskID} [get] func (s *Server) getTask(w http.ResponseWriter, r *http.Request) { task, err := s.store.GetTask(r.Context(), r.PathValue("taskID")) if err == nil { @@ -1574,6 +1576,45 @@ func (s *Server) getTask(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusInternalServerError, "get task failed") } +// cancelTask godoc +// @Summary 取消异步任务 +// @Description 仅取消仍在网关本地排队且尚未提交上游的任务;任务已运行、已提交上游或已结束时返回不可取消,客户端应继续查询结果。 +// @Tags tasks +// @Produce json +// @Security BearerAuth +// @Param taskID path string true "任务 ID" +// @Success 200 {object} TaskCancelResponse +// @Failure 401 {object} ErrorEnvelope +// @Failure 403 {object} ErrorEnvelope +// @Failure 404 {object} ErrorEnvelope +// @Failure 500 {object} ErrorEnvelope +// @Router /api/workspace/tasks/{taskID}/cancel [post] +// @Router /api/v1/tasks/{taskID}/cancel [post] +// @Router /v1/tasks/{taskID}/cancel [post] +// @Router /tasks/{taskID}/cancel [post] +func (s *Server) cancelTask(w http.ResponseWriter, r *http.Request) { + user, ok := auth.UserFromContext(r.Context()) + if !ok { + writeError(w, http.StatusUnauthorized, "unauthorized") + return + } + result, err := s.runner.CancelTask(r.Context(), r.PathValue("taskID"), user) + if err == nil { + writeJSON(w, http.StatusOK, result) + return + } + if store.IsNotFound(err) { + writeError(w, http.StatusNotFound, "task not found") + return + } + if errors.Is(err, runner.ErrTaskAccessDenied) { + writeError(w, http.StatusForbidden, "task access denied") + return + } + s.logger.Error("cancel task failed", "error", err) + writeError(w, http.StatusInternalServerError, "cancel task failed") +} + // taskParamPreprocessing godoc // @Summary 获取任务参数预处理日志 // @Description 返回指定任务在执行前的参数改写、校验或模板处理日志。 @@ -1587,6 +1628,7 @@ func (s *Server) getTask(w http.ResponseWriter, r *http.Request) { // @Failure 500 {object} ErrorEnvelope // @Router /api/workspace/tasks/{taskID}/param-preprocessing [get] // @Router /api/v1/tasks/{taskID}/param-preprocessing [get] +// @Router /tasks/{taskID}/param-preprocessing [get] func (s *Server) taskParamPreprocessing(w http.ResponseWriter, r *http.Request) { task, err := s.store.GetTask(r.Context(), r.PathValue("taskID")) if err != nil { @@ -1620,6 +1662,7 @@ func (s *Server) taskParamPreprocessing(w http.ResponseWriter, r *http.Request) // @Failure 500 {object} ErrorEnvelope // @Router /api/workspace/tasks/{taskID}/events [get] // @Router /api/v1/tasks/{taskID}/events [get] +// @Router /tasks/{taskID}/events [get] func (s *Server) taskEvents(w http.ResponseWriter, r *http.Request) { task, err := s.store.GetTask(r.Context(), r.PathValue("taskID")) if err != nil { diff --git a/apps/api/internal/httpapi/openapi_models.go b/apps/api/internal/httpapi/openapi_models.go index f82ac91..d07ad79 100644 --- a/apps/api/internal/httpapi/openapi_models.go +++ b/apps/api/internal/httpapi/openapi_models.go @@ -146,6 +146,14 @@ type TaskAcceptedResponse struct { Next TaskNextLinks `json:"next"` } +type TaskCancelResponse struct { + TaskID string `json:"taskId" example:"9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25"` + Cancelled bool `json:"cancelled" example:"false"` + Cancellable bool `json:"cancellable" example:"false"` + Submitted bool `json:"submitted" example:"true"` + Message string `json:"message" example:"任务已提交上游,当前不可取消,请继续查询结果"` +} + type TaskNextLinks struct { Events string `json:"events" example:"/api/v1/tasks/9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25/events"` Detail string `json:"detail" example:"/api/v1/tasks/9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25"` diff --git a/apps/api/internal/httpapi/server.go b/apps/api/internal/httpapi/server.go index 8c9f74b..7cb11fe 100644 --- a/apps/api/internal/httpapi/server.go +++ b/apps/api/internal/httpapi/server.go @@ -93,6 +93,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor mux.Handle("GET /api/workspace/wallet/transactions", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listWalletTransactions))) mux.Handle("GET /api/workspace/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks))) mux.Handle("GET /api/workspace/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask))) + mux.Handle("POST /api/workspace/tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask))) mux.Handle("GET /api/workspace/tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing))) mux.Handle("GET /api/workspace/tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents))) mux.Handle("GET /api/admin/pricing/rules", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listPricingRules))) @@ -144,8 +145,14 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor mux.Handle("POST /api/v1/files/upload", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.uploadFile))) mux.Handle("GET /api/v1/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks))) mux.Handle("GET /api/v1/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask))) + mux.Handle("POST /api/v1/tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask))) mux.Handle("GET /api/v1/tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing))) mux.Handle("GET /api/v1/tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents))) + mux.Handle("GET /tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks))) + mux.Handle("GET /tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask))) + mux.Handle("POST /tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask))) + mux.Handle("GET /tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing))) + mux.Handle("GET /tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents))) mux.Handle("POST /chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", true))) mux.Handle("POST /v1/chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", true))) mux.Handle("POST /responses", server.auth.Require(auth.PermissionBasic, server.createTask("responses", true))) @@ -165,6 +172,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor mux.Handle("POST /speech/generations", server.auth.Require(auth.PermissionBasic, server.createTask("speech.generations", true))) mux.Handle("POST /v1/speech/generations", server.auth.Require(auth.PermissionBasic, server.createTask("speech.generations", true))) mux.Handle("POST /v1/files/upload", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.uploadFile))) + mux.Handle("POST /v1/tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask))) return server.recover(server.cors(mux)) } diff --git a/apps/api/internal/runner/task_cancel.go b/apps/api/internal/runner/task_cancel.go new file mode 100644 index 0000000..e35400d --- /dev/null +++ b/apps/api/internal/runner/task_cancel.go @@ -0,0 +1,150 @@ +package runner + +import ( + "context" + "errors" + "strings" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" + "github.com/riverqueue/river/rivertype" +) + +var ErrTaskAccessDenied = errors.New("task access denied") + +type TaskCancelResult struct { + TaskID string `json:"taskId"` + Cancelled bool `json:"cancelled"` + Cancellable bool `json:"cancellable"` + Submitted bool `json:"submitted"` + Message string `json:"message"` +} + +func (s *Service) CancelTask(ctx context.Context, taskID string, user *auth.User) (TaskCancelResult, error) { + task, err := s.store.GetTask(ctx, taskID) + if err != nil { + return TaskCancelResult{}, err + } + if !taskAccessibleToUser(task, user) { + return TaskCancelResult{}, ErrTaskAccessDenied + } + if taskCancelTerminalStatus(task.Status) { + return taskCancelUnavailable(task, "任务已结束,无法取消"), nil + } + if strings.TrimSpace(task.RemoteTaskID) != "" { + return taskCancelUnavailable(task, "任务已提交上游,当前不可取消,请继续查询结果"), nil + } + if strings.TrimSpace(task.Status) != "queued" { + return taskCancelUnavailable(task, "任务已开始执行,当前阶段不可取消,请继续查询结果"), nil + } + if task.RiverJobID > 0 { + if s.riverClient == nil { + return taskCancelUnavailable(task, "任务取消队列未就绪,请继续查询结果"), nil + } + job, err := s.riverClient.JobGet(ctx, task.RiverJobID) + if errors.Is(err, rivertype.ErrNotFound) { + return taskCancelUnavailable(task, "任务已不在本地排队队列,可能已提交上游,当前不可取消,请继续查询结果"), nil + } + if err != nil { + return TaskCancelResult{}, err + } + if job == nil || !riverJobStateCancellable(job.State) { + return taskCancelUnavailable(task, "任务已不在可取消队列状态,请继续查询结果"), nil + } + if _, err := s.riverClient.JobDelete(ctx, task.RiverJobID); err != nil { + if errors.Is(err, rivertype.ErrJobRunning) || errors.Is(err, rivertype.ErrNotFound) { + return taskCancelUnavailable(task, "任务已被工作进程领取,当前不可取消,请继续查询结果"), nil + } + return TaskCancelResult{}, err + } + } + + cancelledTask, cancelled, err := s.store.CancelQueuedTask(ctx, task.ID, "任务已取消") + if err != nil { + return TaskCancelResult{}, err + } + if !cancelled { + latest, latestErr := s.store.GetTask(ctx, task.ID) + if latestErr == nil { + return taskCancelUnavailable(latest, "任务状态已变化,当前不可取消,请继续查询结果"), nil + } + return taskCancelUnavailable(task, "任务状态已变化,当前不可取消,请继续查询结果"), nil + } + if err := s.emit(ctx, cancelledTask.ID, "task.cancelled", "cancelled", "cancelled", 1, "任务已取消", map[string]any{ + "taskId": cancelledTask.ID, + "reason": "manual_cancel", + }, cancelledTask.RunMode == "simulation"); err != nil { + return TaskCancelResult{}, err + } + return TaskCancelResult{ + TaskID: cancelledTask.ID, + Cancelled: true, + Cancellable: true, + Submitted: false, + Message: "任务已取消", + }, nil +} + +func taskCancelUnavailable(task store.GatewayTask, message string) TaskCancelResult { + return TaskCancelResult{ + TaskID: task.ID, + Cancelled: false, + Cancellable: false, + Submitted: strings.TrimSpace(task.RemoteTaskID) != "", + Message: message, + } +} + +func taskCancelTerminalStatus(status string) bool { + switch strings.TrimSpace(status) { + case "succeeded", "failed", "cancelled": + return true + default: + return false + } +} + +func riverJobStateCancellable(state rivertype.JobState) bool { + switch state { + case rivertype.JobStateAvailable, rivertype.JobStateScheduled, rivertype.JobStateRetryable, rivertype.JobStatePending: + return true + default: + return false + } +} + +func taskAccessibleToUser(task store.GatewayTask, user *auth.User) bool { + if user == nil { + return false + } + if apiKeyID := strings.TrimSpace(user.APIKeyID); apiKeyID != "" { + if strings.TrimSpace(task.APIKeyID) != apiKeyID { + return false + } + if gatewayUserID := gatewayUserIDForAuth(user); gatewayUserID != "" { + return strings.TrimSpace(task.GatewayUserID) == gatewayUserID + } + if userID := strings.TrimSpace(user.ID); userID != "" { + return strings.TrimSpace(task.UserID) == userID + } + return true + } + if gatewayUserID := gatewayUserIDForAuth(user); gatewayUserID != "" { + return strings.TrimSpace(task.GatewayUserID) == gatewayUserID + } + userID := strings.TrimSpace(user.ID) + return userID != "" && strings.TrimSpace(task.UserID) == userID +} + +func gatewayUserIDForAuth(user *auth.User) string { + if user == nil { + return "" + } + if user.GatewayUserID != "" { + return strings.TrimSpace(user.GatewayUserID) + } + if user.Source == "" || user.Source == "gateway" { + return strings.TrimSpace(user.ID) + } + return "" +} diff --git a/apps/api/internal/runner/task_cancel_test.go b/apps/api/internal/runner/task_cancel_test.go new file mode 100644 index 0000000..15b90aa --- /dev/null +++ b/apps/api/internal/runner/task_cancel_test.go @@ -0,0 +1,78 @@ +package runner + +import ( + "testing" + + "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" + "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" + "github.com/riverqueue/river/rivertype" +) + +func TestRiverJobStateCancellableOnlyAllowsLocalQueueStates(t *testing.T) { + cancellableStates := []rivertype.JobState{ + rivertype.JobStateAvailable, + rivertype.JobStateScheduled, + rivertype.JobStateRetryable, + rivertype.JobStatePending, + } + for _, state := range cancellableStates { + if !riverJobStateCancellable(state) { + t.Fatalf("expected %s to be cancellable", state) + } + } + + nonCancellableStates := []rivertype.JobState{ + rivertype.JobStateRunning, + rivertype.JobStateCancelled, + rivertype.JobStateCompleted, + rivertype.JobStateDiscarded, + } + for _, state := range nonCancellableStates { + if riverJobStateCancellable(state) { + t.Fatalf("expected %s to be non-cancellable", state) + } + } +} + +func TestTaskAccessibleToUserHonorsAPIKeyAndGatewayUser(t *testing.T) { + task := store.GatewayTask{ + UserID: "server-user", + GatewayUserID: "gateway-user", + APIKeyID: "api-key", + } + + if !taskAccessibleToUser(task, &auth.User{GatewayUserID: "gateway-user"}) { + t.Fatal("gateway user should access own task") + } + if !taskAccessibleToUser(task, &auth.User{APIKeyID: "api-key", GatewayUserID: "gateway-user"}) { + t.Fatal("api key owner should access own task") + } + if taskAccessibleToUser(task, &auth.User{APIKeyID: "api-key", GatewayUserID: "other-user"}) { + t.Fatal("api key from another gateway user should not access task") + } + if taskAccessibleToUser(task, &auth.User{GatewayUserID: "other-user"}) { + t.Fatal("another gateway user should not access task") + } + + serverMainTask := store.GatewayTask{ + UserID: "server-user", + APIKeyID: "server-api-key", + } + if !taskAccessibleToUser(serverMainTask, &auth.User{ID: "server-user", Source: "server-main", APIKeyID: "server-api-key"}) { + t.Fatal("server-main api key should fall back to user id when gateway user id is absent") + } + if taskAccessibleToUser(serverMainTask, &auth.User{ID: "other-user", Source: "server-main", APIKeyID: "server-api-key"}) { + t.Fatal("server-main api key from another user should not access task") + } +} + +func TestTaskCancelUnavailableReportsSubmittedFromRemoteTaskID(t *testing.T) { + result := taskCancelUnavailable(store.GatewayTask{ID: "task-1", RemoteTaskID: "remote-1"}, "不可取消") + + if result.Cancelled || result.Cancellable { + t.Fatalf("unavailable result should not mark cancellation as done: %+v", result) + } + if !result.Submitted { + t.Fatalf("remote task id should report submitted: %+v", result) + } +} diff --git a/apps/api/internal/store/tasks_runtime.go b/apps/api/internal/store/tasks_runtime.go index 423d271..f4009a3 100644 --- a/apps/api/internal/store/tasks_runtime.go +++ b/apps/api/internal/store/tasks_runtime.go @@ -256,6 +256,38 @@ WHERE id = $1::uuid`, }) } +func (s *Store) CancelQueuedTask(ctx context.Context, taskID string, message string) (GatewayTask, bool, error) { + message = strings.TrimSpace(message) + if message == "" { + message = "任务已取消" + } + tag, err := s.pool.Exec(ctx, ` +UPDATE gateway_tasks +SET status = 'cancelled', + error = NULLIF($2::text, ''), + error_code = 'task_cancelled', + error_message = NULLIF($2::text, ''), + locked_by = NULL, + locked_at = NULL, + heartbeat_at = NULL, + finished_at = now(), + updated_at = now() +WHERE id = $1::uuid + AND status = 'queued' + AND COALESCE(remote_task_id, '') = ''`, taskID, message) + if err != nil { + return GatewayTask{}, false, err + } + if tag.RowsAffected() == 0 { + return GatewayTask{}, false, nil + } + task, err := s.GetTask(ctx, taskID) + if err != nil { + return GatewayTask{}, true, err + } + return task, true, nil +} + func (s *Store) ListRecoverableAsyncTasks(ctx context.Context, limit int) ([]AsyncTaskQueueItem, error) { if limit <= 0 { limit = 500