feat(api): 支持取消本地排队任务
This commit is contained in:
parent
a8fa8dd212
commit
e8df26da9b
@ -1452,6 +1452,7 @@ func matchedRateLimitRule(policy map[string]any, metric string) map[string]any {
|
||||
// @Failure 500 {object} ErrorEnvelope
|
||||
// @Router /api/workspace/tasks [get]
|
||||
// @Router /api/v1/tasks [get]
|
||||
// @Router /tasks [get]
|
||||
func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := auth.UserFromContext(r.Context())
|
||||
if !ok {
|
||||
@ -1560,6 +1561,7 @@ func boolValue(body map[string]any, key string) bool {
|
||||
// @Failure 500 {object} ErrorEnvelope
|
||||
// @Router /api/workspace/tasks/{taskID} [get]
|
||||
// @Router /api/v1/tasks/{taskID} [get]
|
||||
// @Router /tasks/{taskID} [get]
|
||||
func (s *Server) getTask(w http.ResponseWriter, r *http.Request) {
|
||||
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
|
||||
if err == nil {
|
||||
@ -1574,6 +1576,45 @@ func (s *Server) getTask(w http.ResponseWriter, r *http.Request) {
|
||||
writeError(w, http.StatusInternalServerError, "get task failed")
|
||||
}
|
||||
|
||||
// cancelTask godoc
|
||||
// @Summary 取消异步任务
|
||||
// @Description 仅取消仍在网关本地排队且尚未提交上游的任务;任务已运行、已提交上游或已结束时返回不可取消,客户端应继续查询结果。
|
||||
// @Tags tasks
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param taskID path string true "任务 ID"
|
||||
// @Success 200 {object} TaskCancelResponse
|
||||
// @Failure 401 {object} ErrorEnvelope
|
||||
// @Failure 403 {object} ErrorEnvelope
|
||||
// @Failure 404 {object} ErrorEnvelope
|
||||
// @Failure 500 {object} ErrorEnvelope
|
||||
// @Router /api/workspace/tasks/{taskID}/cancel [post]
|
||||
// @Router /api/v1/tasks/{taskID}/cancel [post]
|
||||
// @Router /v1/tasks/{taskID}/cancel [post]
|
||||
// @Router /tasks/{taskID}/cancel [post]
|
||||
func (s *Server) cancelTask(w http.ResponseWriter, r *http.Request) {
|
||||
user, ok := auth.UserFromContext(r.Context())
|
||||
if !ok {
|
||||
writeError(w, http.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
result, err := s.runner.CancelTask(r.Context(), r.PathValue("taskID"), user)
|
||||
if err == nil {
|
||||
writeJSON(w, http.StatusOK, result)
|
||||
return
|
||||
}
|
||||
if store.IsNotFound(err) {
|
||||
writeError(w, http.StatusNotFound, "task not found")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, runner.ErrTaskAccessDenied) {
|
||||
writeError(w, http.StatusForbidden, "task access denied")
|
||||
return
|
||||
}
|
||||
s.logger.Error("cancel task failed", "error", err)
|
||||
writeError(w, http.StatusInternalServerError, "cancel task failed")
|
||||
}
|
||||
|
||||
// taskParamPreprocessing godoc
|
||||
// @Summary 获取任务参数预处理日志
|
||||
// @Description 返回指定任务在执行前的参数改写、校验或模板处理日志。
|
||||
@ -1587,6 +1628,7 @@ func (s *Server) getTask(w http.ResponseWriter, r *http.Request) {
|
||||
// @Failure 500 {object} ErrorEnvelope
|
||||
// @Router /api/workspace/tasks/{taskID}/param-preprocessing [get]
|
||||
// @Router /api/v1/tasks/{taskID}/param-preprocessing [get]
|
||||
// @Router /tasks/{taskID}/param-preprocessing [get]
|
||||
func (s *Server) taskParamPreprocessing(w http.ResponseWriter, r *http.Request) {
|
||||
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
|
||||
if err != nil {
|
||||
@ -1620,6 +1662,7 @@ func (s *Server) taskParamPreprocessing(w http.ResponseWriter, r *http.Request)
|
||||
// @Failure 500 {object} ErrorEnvelope
|
||||
// @Router /api/workspace/tasks/{taskID}/events [get]
|
||||
// @Router /api/v1/tasks/{taskID}/events [get]
|
||||
// @Router /tasks/{taskID}/events [get]
|
||||
func (s *Server) taskEvents(w http.ResponseWriter, r *http.Request) {
|
||||
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
|
||||
if err != nil {
|
||||
|
||||
@ -146,6 +146,14 @@ type TaskAcceptedResponse struct {
|
||||
Next TaskNextLinks `json:"next"`
|
||||
}
|
||||
|
||||
type TaskCancelResponse struct {
|
||||
TaskID string `json:"taskId" example:"9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25"`
|
||||
Cancelled bool `json:"cancelled" example:"false"`
|
||||
Cancellable bool `json:"cancellable" example:"false"`
|
||||
Submitted bool `json:"submitted" example:"true"`
|
||||
Message string `json:"message" example:"任务已提交上游,当前不可取消,请继续查询结果"`
|
||||
}
|
||||
|
||||
type TaskNextLinks struct {
|
||||
Events string `json:"events" example:"/api/v1/tasks/9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25/events"`
|
||||
Detail string `json:"detail" example:"/api/v1/tasks/9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25"`
|
||||
|
||||
@ -93,6 +93,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
|
||||
mux.Handle("GET /api/workspace/wallet/transactions", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listWalletTransactions)))
|
||||
mux.Handle("GET /api/workspace/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks)))
|
||||
mux.Handle("GET /api/workspace/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask)))
|
||||
mux.Handle("POST /api/workspace/tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask)))
|
||||
mux.Handle("GET /api/workspace/tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing)))
|
||||
mux.Handle("GET /api/workspace/tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents)))
|
||||
mux.Handle("GET /api/admin/pricing/rules", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listPricingRules)))
|
||||
@ -144,8 +145,14 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
|
||||
mux.Handle("POST /api/v1/files/upload", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.uploadFile)))
|
||||
mux.Handle("GET /api/v1/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks)))
|
||||
mux.Handle("GET /api/v1/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask)))
|
||||
mux.Handle("POST /api/v1/tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask)))
|
||||
mux.Handle("GET /api/v1/tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing)))
|
||||
mux.Handle("GET /api/v1/tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents)))
|
||||
mux.Handle("GET /tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks)))
|
||||
mux.Handle("GET /tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask)))
|
||||
mux.Handle("POST /tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask)))
|
||||
mux.Handle("GET /tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing)))
|
||||
mux.Handle("GET /tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents)))
|
||||
mux.Handle("POST /chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", true)))
|
||||
mux.Handle("POST /v1/chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", true)))
|
||||
mux.Handle("POST /responses", server.auth.Require(auth.PermissionBasic, server.createTask("responses", true)))
|
||||
@ -165,6 +172,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
|
||||
mux.Handle("POST /speech/generations", server.auth.Require(auth.PermissionBasic, server.createTask("speech.generations", true)))
|
||||
mux.Handle("POST /v1/speech/generations", server.auth.Require(auth.PermissionBasic, server.createTask("speech.generations", true)))
|
||||
mux.Handle("POST /v1/files/upload", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.uploadFile)))
|
||||
mux.Handle("POST /v1/tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask)))
|
||||
|
||||
return server.recover(server.cors(mux))
|
||||
}
|
||||
|
||||
150
apps/api/internal/runner/task_cancel.go
Normal file
150
apps/api/internal/runner/task_cancel.go
Normal file
@ -0,0 +1,150 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
"github.com/riverqueue/river/rivertype"
|
||||
)
|
||||
|
||||
var ErrTaskAccessDenied = errors.New("task access denied")
|
||||
|
||||
type TaskCancelResult struct {
|
||||
TaskID string `json:"taskId"`
|
||||
Cancelled bool `json:"cancelled"`
|
||||
Cancellable bool `json:"cancellable"`
|
||||
Submitted bool `json:"submitted"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (s *Service) CancelTask(ctx context.Context, taskID string, user *auth.User) (TaskCancelResult, error) {
|
||||
task, err := s.store.GetTask(ctx, taskID)
|
||||
if err != nil {
|
||||
return TaskCancelResult{}, err
|
||||
}
|
||||
if !taskAccessibleToUser(task, user) {
|
||||
return TaskCancelResult{}, ErrTaskAccessDenied
|
||||
}
|
||||
if taskCancelTerminalStatus(task.Status) {
|
||||
return taskCancelUnavailable(task, "任务已结束,无法取消"), nil
|
||||
}
|
||||
if strings.TrimSpace(task.RemoteTaskID) != "" {
|
||||
return taskCancelUnavailable(task, "任务已提交上游,当前不可取消,请继续查询结果"), nil
|
||||
}
|
||||
if strings.TrimSpace(task.Status) != "queued" {
|
||||
return taskCancelUnavailable(task, "任务已开始执行,当前阶段不可取消,请继续查询结果"), nil
|
||||
}
|
||||
if task.RiverJobID > 0 {
|
||||
if s.riverClient == nil {
|
||||
return taskCancelUnavailable(task, "任务取消队列未就绪,请继续查询结果"), nil
|
||||
}
|
||||
job, err := s.riverClient.JobGet(ctx, task.RiverJobID)
|
||||
if errors.Is(err, rivertype.ErrNotFound) {
|
||||
return taskCancelUnavailable(task, "任务已不在本地排队队列,可能已提交上游,当前不可取消,请继续查询结果"), nil
|
||||
}
|
||||
if err != nil {
|
||||
return TaskCancelResult{}, err
|
||||
}
|
||||
if job == nil || !riverJobStateCancellable(job.State) {
|
||||
return taskCancelUnavailable(task, "任务已不在可取消队列状态,请继续查询结果"), nil
|
||||
}
|
||||
if _, err := s.riverClient.JobDelete(ctx, task.RiverJobID); err != nil {
|
||||
if errors.Is(err, rivertype.ErrJobRunning) || errors.Is(err, rivertype.ErrNotFound) {
|
||||
return taskCancelUnavailable(task, "任务已被工作进程领取,当前不可取消,请继续查询结果"), nil
|
||||
}
|
||||
return TaskCancelResult{}, err
|
||||
}
|
||||
}
|
||||
|
||||
cancelledTask, cancelled, err := s.store.CancelQueuedTask(ctx, task.ID, "任务已取消")
|
||||
if err != nil {
|
||||
return TaskCancelResult{}, err
|
||||
}
|
||||
if !cancelled {
|
||||
latest, latestErr := s.store.GetTask(ctx, task.ID)
|
||||
if latestErr == nil {
|
||||
return taskCancelUnavailable(latest, "任务状态已变化,当前不可取消,请继续查询结果"), nil
|
||||
}
|
||||
return taskCancelUnavailable(task, "任务状态已变化,当前不可取消,请继续查询结果"), nil
|
||||
}
|
||||
if err := s.emit(ctx, cancelledTask.ID, "task.cancelled", "cancelled", "cancelled", 1, "任务已取消", map[string]any{
|
||||
"taskId": cancelledTask.ID,
|
||||
"reason": "manual_cancel",
|
||||
}, cancelledTask.RunMode == "simulation"); err != nil {
|
||||
return TaskCancelResult{}, err
|
||||
}
|
||||
return TaskCancelResult{
|
||||
TaskID: cancelledTask.ID,
|
||||
Cancelled: true,
|
||||
Cancellable: true,
|
||||
Submitted: false,
|
||||
Message: "任务已取消",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func taskCancelUnavailable(task store.GatewayTask, message string) TaskCancelResult {
|
||||
return TaskCancelResult{
|
||||
TaskID: task.ID,
|
||||
Cancelled: false,
|
||||
Cancellable: false,
|
||||
Submitted: strings.TrimSpace(task.RemoteTaskID) != "",
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
func taskCancelTerminalStatus(status string) bool {
|
||||
switch strings.TrimSpace(status) {
|
||||
case "succeeded", "failed", "cancelled":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func riverJobStateCancellable(state rivertype.JobState) bool {
|
||||
switch state {
|
||||
case rivertype.JobStateAvailable, rivertype.JobStateScheduled, rivertype.JobStateRetryable, rivertype.JobStatePending:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func taskAccessibleToUser(task store.GatewayTask, user *auth.User) bool {
|
||||
if user == nil {
|
||||
return false
|
||||
}
|
||||
if apiKeyID := strings.TrimSpace(user.APIKeyID); apiKeyID != "" {
|
||||
if strings.TrimSpace(task.APIKeyID) != apiKeyID {
|
||||
return false
|
||||
}
|
||||
if gatewayUserID := gatewayUserIDForAuth(user); gatewayUserID != "" {
|
||||
return strings.TrimSpace(task.GatewayUserID) == gatewayUserID
|
||||
}
|
||||
if userID := strings.TrimSpace(user.ID); userID != "" {
|
||||
return strings.TrimSpace(task.UserID) == userID
|
||||
}
|
||||
return true
|
||||
}
|
||||
if gatewayUserID := gatewayUserIDForAuth(user); gatewayUserID != "" {
|
||||
return strings.TrimSpace(task.GatewayUserID) == gatewayUserID
|
||||
}
|
||||
userID := strings.TrimSpace(user.ID)
|
||||
return userID != "" && strings.TrimSpace(task.UserID) == userID
|
||||
}
|
||||
|
||||
func gatewayUserIDForAuth(user *auth.User) string {
|
||||
if user == nil {
|
||||
return ""
|
||||
}
|
||||
if user.GatewayUserID != "" {
|
||||
return strings.TrimSpace(user.GatewayUserID)
|
||||
}
|
||||
if user.Source == "" || user.Source == "gateway" {
|
||||
return strings.TrimSpace(user.ID)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
78
apps/api/internal/runner/task_cancel_test.go
Normal file
78
apps/api/internal/runner/task_cancel_test.go
Normal file
@ -0,0 +1,78 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||
"github.com/riverqueue/river/rivertype"
|
||||
)
|
||||
|
||||
func TestRiverJobStateCancellableOnlyAllowsLocalQueueStates(t *testing.T) {
|
||||
cancellableStates := []rivertype.JobState{
|
||||
rivertype.JobStateAvailable,
|
||||
rivertype.JobStateScheduled,
|
||||
rivertype.JobStateRetryable,
|
||||
rivertype.JobStatePending,
|
||||
}
|
||||
for _, state := range cancellableStates {
|
||||
if !riverJobStateCancellable(state) {
|
||||
t.Fatalf("expected %s to be cancellable", state)
|
||||
}
|
||||
}
|
||||
|
||||
nonCancellableStates := []rivertype.JobState{
|
||||
rivertype.JobStateRunning,
|
||||
rivertype.JobStateCancelled,
|
||||
rivertype.JobStateCompleted,
|
||||
rivertype.JobStateDiscarded,
|
||||
}
|
||||
for _, state := range nonCancellableStates {
|
||||
if riverJobStateCancellable(state) {
|
||||
t.Fatalf("expected %s to be non-cancellable", state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskAccessibleToUserHonorsAPIKeyAndGatewayUser(t *testing.T) {
|
||||
task := store.GatewayTask{
|
||||
UserID: "server-user",
|
||||
GatewayUserID: "gateway-user",
|
||||
APIKeyID: "api-key",
|
||||
}
|
||||
|
||||
if !taskAccessibleToUser(task, &auth.User{GatewayUserID: "gateway-user"}) {
|
||||
t.Fatal("gateway user should access own task")
|
||||
}
|
||||
if !taskAccessibleToUser(task, &auth.User{APIKeyID: "api-key", GatewayUserID: "gateway-user"}) {
|
||||
t.Fatal("api key owner should access own task")
|
||||
}
|
||||
if taskAccessibleToUser(task, &auth.User{APIKeyID: "api-key", GatewayUserID: "other-user"}) {
|
||||
t.Fatal("api key from another gateway user should not access task")
|
||||
}
|
||||
if taskAccessibleToUser(task, &auth.User{GatewayUserID: "other-user"}) {
|
||||
t.Fatal("another gateway user should not access task")
|
||||
}
|
||||
|
||||
serverMainTask := store.GatewayTask{
|
||||
UserID: "server-user",
|
||||
APIKeyID: "server-api-key",
|
||||
}
|
||||
if !taskAccessibleToUser(serverMainTask, &auth.User{ID: "server-user", Source: "server-main", APIKeyID: "server-api-key"}) {
|
||||
t.Fatal("server-main api key should fall back to user id when gateway user id is absent")
|
||||
}
|
||||
if taskAccessibleToUser(serverMainTask, &auth.User{ID: "other-user", Source: "server-main", APIKeyID: "server-api-key"}) {
|
||||
t.Fatal("server-main api key from another user should not access task")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskCancelUnavailableReportsSubmittedFromRemoteTaskID(t *testing.T) {
|
||||
result := taskCancelUnavailable(store.GatewayTask{ID: "task-1", RemoteTaskID: "remote-1"}, "不可取消")
|
||||
|
||||
if result.Cancelled || result.Cancellable {
|
||||
t.Fatalf("unavailable result should not mark cancellation as done: %+v", result)
|
||||
}
|
||||
if !result.Submitted {
|
||||
t.Fatalf("remote task id should report submitted: %+v", result)
|
||||
}
|
||||
}
|
||||
@ -256,6 +256,38 @@ WHERE id = $1::uuid`,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Store) CancelQueuedTask(ctx context.Context, taskID string, message string) (GatewayTask, bool, error) {
|
||||
message = strings.TrimSpace(message)
|
||||
if message == "" {
|
||||
message = "任务已取消"
|
||||
}
|
||||
tag, err := s.pool.Exec(ctx, `
|
||||
UPDATE gateway_tasks
|
||||
SET status = 'cancelled',
|
||||
error = NULLIF($2::text, ''),
|
||||
error_code = 'task_cancelled',
|
||||
error_message = NULLIF($2::text, ''),
|
||||
locked_by = NULL,
|
||||
locked_at = NULL,
|
||||
heartbeat_at = NULL,
|
||||
finished_at = now(),
|
||||
updated_at = now()
|
||||
WHERE id = $1::uuid
|
||||
AND status = 'queued'
|
||||
AND COALESCE(remote_task_id, '') = ''`, taskID, message)
|
||||
if err != nil {
|
||||
return GatewayTask{}, false, err
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return GatewayTask{}, false, nil
|
||||
}
|
||||
task, err := s.GetTask(ctx, taskID)
|
||||
if err != nil {
|
||||
return GatewayTask{}, true, err
|
||||
}
|
||||
return task, true, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListRecoverableAsyncTasks(ctx context.Context, limit int) ([]AsyncTaskQueueItem, error) {
|
||||
if limit <= 0 {
|
||||
limit = 500
|
||||
|
||||
Loading…
Reference in New Issue
Block a user