feat(api): 支持取消本地排队任务
This commit is contained in:
parent
a8fa8dd212
commit
e8df26da9b
@ -1452,6 +1452,7 @@ func matchedRateLimitRule(policy map[string]any, metric string) map[string]any {
|
|||||||
// @Failure 500 {object} ErrorEnvelope
|
// @Failure 500 {object} ErrorEnvelope
|
||||||
// @Router /api/workspace/tasks [get]
|
// @Router /api/workspace/tasks [get]
|
||||||
// @Router /api/v1/tasks [get]
|
// @Router /api/v1/tasks [get]
|
||||||
|
// @Router /tasks [get]
|
||||||
func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) listTasks(w http.ResponseWriter, r *http.Request) {
|
||||||
user, ok := auth.UserFromContext(r.Context())
|
user, ok := auth.UserFromContext(r.Context())
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -1560,6 +1561,7 @@ func boolValue(body map[string]any, key string) bool {
|
|||||||
// @Failure 500 {object} ErrorEnvelope
|
// @Failure 500 {object} ErrorEnvelope
|
||||||
// @Router /api/workspace/tasks/{taskID} [get]
|
// @Router /api/workspace/tasks/{taskID} [get]
|
||||||
// @Router /api/v1/tasks/{taskID} [get]
|
// @Router /api/v1/tasks/{taskID} [get]
|
||||||
|
// @Router /tasks/{taskID} [get]
|
||||||
func (s *Server) getTask(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) getTask(w http.ResponseWriter, r *http.Request) {
|
||||||
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
|
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -1574,6 +1576,45 @@ func (s *Server) getTask(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeError(w, http.StatusInternalServerError, "get task failed")
|
writeError(w, http.StatusInternalServerError, "get task failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cancelTask godoc
|
||||||
|
// @Summary 取消异步任务
|
||||||
|
// @Description 仅取消仍在网关本地排队且尚未提交上游的任务;任务已运行、已提交上游或已结束时返回不可取消,客户端应继续查询结果。
|
||||||
|
// @Tags tasks
|
||||||
|
// @Produce json
|
||||||
|
// @Security BearerAuth
|
||||||
|
// @Param taskID path string true "任务 ID"
|
||||||
|
// @Success 200 {object} TaskCancelResponse
|
||||||
|
// @Failure 401 {object} ErrorEnvelope
|
||||||
|
// @Failure 403 {object} ErrorEnvelope
|
||||||
|
// @Failure 404 {object} ErrorEnvelope
|
||||||
|
// @Failure 500 {object} ErrorEnvelope
|
||||||
|
// @Router /api/workspace/tasks/{taskID}/cancel [post]
|
||||||
|
// @Router /api/v1/tasks/{taskID}/cancel [post]
|
||||||
|
// @Router /v1/tasks/{taskID}/cancel [post]
|
||||||
|
// @Router /tasks/{taskID}/cancel [post]
|
||||||
|
func (s *Server) cancelTask(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user, ok := auth.UserFromContext(r.Context())
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusUnauthorized, "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result, err := s.runner.CancelTask(r.Context(), r.PathValue("taskID"), user)
|
||||||
|
if err == nil {
|
||||||
|
writeJSON(w, http.StatusOK, result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if store.IsNotFound(err) {
|
||||||
|
writeError(w, http.StatusNotFound, "task not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errors.Is(err, runner.ErrTaskAccessDenied) {
|
||||||
|
writeError(w, http.StatusForbidden, "task access denied")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.logger.Error("cancel task failed", "error", err)
|
||||||
|
writeError(w, http.StatusInternalServerError, "cancel task failed")
|
||||||
|
}
|
||||||
|
|
||||||
// taskParamPreprocessing godoc
|
// taskParamPreprocessing godoc
|
||||||
// @Summary 获取任务参数预处理日志
|
// @Summary 获取任务参数预处理日志
|
||||||
// @Description 返回指定任务在执行前的参数改写、校验或模板处理日志。
|
// @Description 返回指定任务在执行前的参数改写、校验或模板处理日志。
|
||||||
@ -1587,6 +1628,7 @@ func (s *Server) getTask(w http.ResponseWriter, r *http.Request) {
|
|||||||
// @Failure 500 {object} ErrorEnvelope
|
// @Failure 500 {object} ErrorEnvelope
|
||||||
// @Router /api/workspace/tasks/{taskID}/param-preprocessing [get]
|
// @Router /api/workspace/tasks/{taskID}/param-preprocessing [get]
|
||||||
// @Router /api/v1/tasks/{taskID}/param-preprocessing [get]
|
// @Router /api/v1/tasks/{taskID}/param-preprocessing [get]
|
||||||
|
// @Router /tasks/{taskID}/param-preprocessing [get]
|
||||||
func (s *Server) taskParamPreprocessing(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) taskParamPreprocessing(w http.ResponseWriter, r *http.Request) {
|
||||||
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
|
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1620,6 +1662,7 @@ func (s *Server) taskParamPreprocessing(w http.ResponseWriter, r *http.Request)
|
|||||||
// @Failure 500 {object} ErrorEnvelope
|
// @Failure 500 {object} ErrorEnvelope
|
||||||
// @Router /api/workspace/tasks/{taskID}/events [get]
|
// @Router /api/workspace/tasks/{taskID}/events [get]
|
||||||
// @Router /api/v1/tasks/{taskID}/events [get]
|
// @Router /api/v1/tasks/{taskID}/events [get]
|
||||||
|
// @Router /tasks/{taskID}/events [get]
|
||||||
func (s *Server) taskEvents(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) taskEvents(w http.ResponseWriter, r *http.Request) {
|
||||||
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
|
task, err := s.store.GetTask(r.Context(), r.PathValue("taskID"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -146,6 +146,14 @@ type TaskAcceptedResponse struct {
|
|||||||
Next TaskNextLinks `json:"next"`
|
Next TaskNextLinks `json:"next"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TaskCancelResponse struct {
|
||||||
|
TaskID string `json:"taskId" example:"9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25"`
|
||||||
|
Cancelled bool `json:"cancelled" example:"false"`
|
||||||
|
Cancellable bool `json:"cancellable" example:"false"`
|
||||||
|
Submitted bool `json:"submitted" example:"true"`
|
||||||
|
Message string `json:"message" example:"任务已提交上游,当前不可取消,请继续查询结果"`
|
||||||
|
}
|
||||||
|
|
||||||
type TaskNextLinks struct {
|
type TaskNextLinks struct {
|
||||||
Events string `json:"events" example:"/api/v1/tasks/9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25/events"`
|
Events string `json:"events" example:"/api/v1/tasks/9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25/events"`
|
||||||
Detail string `json:"detail" example:"/api/v1/tasks/9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25"`
|
Detail string `json:"detail" example:"/api/v1/tasks/9f4d8f3d-5f5f-4bb7-a4be-344a9f930e25"`
|
||||||
|
|||||||
@ -93,6 +93,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
|
|||||||
mux.Handle("GET /api/workspace/wallet/transactions", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listWalletTransactions)))
|
mux.Handle("GET /api/workspace/wallet/transactions", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listWalletTransactions)))
|
||||||
mux.Handle("GET /api/workspace/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks)))
|
mux.Handle("GET /api/workspace/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks)))
|
||||||
mux.Handle("GET /api/workspace/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask)))
|
mux.Handle("GET /api/workspace/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask)))
|
||||||
|
mux.Handle("POST /api/workspace/tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask)))
|
||||||
mux.Handle("GET /api/workspace/tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing)))
|
mux.Handle("GET /api/workspace/tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing)))
|
||||||
mux.Handle("GET /api/workspace/tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents)))
|
mux.Handle("GET /api/workspace/tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents)))
|
||||||
mux.Handle("GET /api/admin/pricing/rules", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listPricingRules)))
|
mux.Handle("GET /api/admin/pricing/rules", server.requireAdmin(auth.PermissionPower, http.HandlerFunc(server.listPricingRules)))
|
||||||
@ -144,8 +145,14 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
|
|||||||
mux.Handle("POST /api/v1/files/upload", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.uploadFile)))
|
mux.Handle("POST /api/v1/files/upload", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.uploadFile)))
|
||||||
mux.Handle("GET /api/v1/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks)))
|
mux.Handle("GET /api/v1/tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks)))
|
||||||
mux.Handle("GET /api/v1/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask)))
|
mux.Handle("GET /api/v1/tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask)))
|
||||||
|
mux.Handle("POST /api/v1/tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask)))
|
||||||
mux.Handle("GET /api/v1/tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing)))
|
mux.Handle("GET /api/v1/tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing)))
|
||||||
mux.Handle("GET /api/v1/tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents)))
|
mux.Handle("GET /api/v1/tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents)))
|
||||||
|
mux.Handle("GET /tasks", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.listTasks)))
|
||||||
|
mux.Handle("GET /tasks/{taskID}", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.getTask)))
|
||||||
|
mux.Handle("POST /tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask)))
|
||||||
|
mux.Handle("GET /tasks/{taskID}/param-preprocessing", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskParamPreprocessing)))
|
||||||
|
mux.Handle("GET /tasks/{taskID}/events", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.taskEvents)))
|
||||||
mux.Handle("POST /chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", true)))
|
mux.Handle("POST /chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", true)))
|
||||||
mux.Handle("POST /v1/chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", true)))
|
mux.Handle("POST /v1/chat/completions", server.auth.Require(auth.PermissionBasic, server.createTask("chat.completions", true)))
|
||||||
mux.Handle("POST /responses", server.auth.Require(auth.PermissionBasic, server.createTask("responses", true)))
|
mux.Handle("POST /responses", server.auth.Require(auth.PermissionBasic, server.createTask("responses", true)))
|
||||||
@ -165,6 +172,7 @@ func NewServerWithContext(ctx context.Context, cfg config.Config, db *store.Stor
|
|||||||
mux.Handle("POST /speech/generations", server.auth.Require(auth.PermissionBasic, server.createTask("speech.generations", true)))
|
mux.Handle("POST /speech/generations", server.auth.Require(auth.PermissionBasic, server.createTask("speech.generations", true)))
|
||||||
mux.Handle("POST /v1/speech/generations", server.auth.Require(auth.PermissionBasic, server.createTask("speech.generations", true)))
|
mux.Handle("POST /v1/speech/generations", server.auth.Require(auth.PermissionBasic, server.createTask("speech.generations", true)))
|
||||||
mux.Handle("POST /v1/files/upload", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.uploadFile)))
|
mux.Handle("POST /v1/files/upload", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.uploadFile)))
|
||||||
|
mux.Handle("POST /v1/tasks/{taskID}/cancel", server.auth.Require(auth.PermissionBasic, http.HandlerFunc(server.cancelTask)))
|
||||||
|
|
||||||
return server.recover(server.cors(mux))
|
return server.recover(server.cors(mux))
|
||||||
}
|
}
|
||||||
|
|||||||
150
apps/api/internal/runner/task_cancel.go
Normal file
150
apps/api/internal/runner/task_cancel.go
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
"github.com/riverqueue/river/rivertype"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrTaskAccessDenied = errors.New("task access denied")
|
||||||
|
|
||||||
|
type TaskCancelResult struct {
|
||||||
|
TaskID string `json:"taskId"`
|
||||||
|
Cancelled bool `json:"cancelled"`
|
||||||
|
Cancellable bool `json:"cancellable"`
|
||||||
|
Submitted bool `json:"submitted"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) CancelTask(ctx context.Context, taskID string, user *auth.User) (TaskCancelResult, error) {
|
||||||
|
task, err := s.store.GetTask(ctx, taskID)
|
||||||
|
if err != nil {
|
||||||
|
return TaskCancelResult{}, err
|
||||||
|
}
|
||||||
|
if !taskAccessibleToUser(task, user) {
|
||||||
|
return TaskCancelResult{}, ErrTaskAccessDenied
|
||||||
|
}
|
||||||
|
if taskCancelTerminalStatus(task.Status) {
|
||||||
|
return taskCancelUnavailable(task, "任务已结束,无法取消"), nil
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(task.RemoteTaskID) != "" {
|
||||||
|
return taskCancelUnavailable(task, "任务已提交上游,当前不可取消,请继续查询结果"), nil
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(task.Status) != "queued" {
|
||||||
|
return taskCancelUnavailable(task, "任务已开始执行,当前阶段不可取消,请继续查询结果"), nil
|
||||||
|
}
|
||||||
|
if task.RiverJobID > 0 {
|
||||||
|
if s.riverClient == nil {
|
||||||
|
return taskCancelUnavailable(task, "任务取消队列未就绪,请继续查询结果"), nil
|
||||||
|
}
|
||||||
|
job, err := s.riverClient.JobGet(ctx, task.RiverJobID)
|
||||||
|
if errors.Is(err, rivertype.ErrNotFound) {
|
||||||
|
return taskCancelUnavailable(task, "任务已不在本地排队队列,可能已提交上游,当前不可取消,请继续查询结果"), nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return TaskCancelResult{}, err
|
||||||
|
}
|
||||||
|
if job == nil || !riverJobStateCancellable(job.State) {
|
||||||
|
return taskCancelUnavailable(task, "任务已不在可取消队列状态,请继续查询结果"), nil
|
||||||
|
}
|
||||||
|
if _, err := s.riverClient.JobDelete(ctx, task.RiverJobID); err != nil {
|
||||||
|
if errors.Is(err, rivertype.ErrJobRunning) || errors.Is(err, rivertype.ErrNotFound) {
|
||||||
|
return taskCancelUnavailable(task, "任务已被工作进程领取,当前不可取消,请继续查询结果"), nil
|
||||||
|
}
|
||||||
|
return TaskCancelResult{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cancelledTask, cancelled, err := s.store.CancelQueuedTask(ctx, task.ID, "任务已取消")
|
||||||
|
if err != nil {
|
||||||
|
return TaskCancelResult{}, err
|
||||||
|
}
|
||||||
|
if !cancelled {
|
||||||
|
latest, latestErr := s.store.GetTask(ctx, task.ID)
|
||||||
|
if latestErr == nil {
|
||||||
|
return taskCancelUnavailable(latest, "任务状态已变化,当前不可取消,请继续查询结果"), nil
|
||||||
|
}
|
||||||
|
return taskCancelUnavailable(task, "任务状态已变化,当前不可取消,请继续查询结果"), nil
|
||||||
|
}
|
||||||
|
if err := s.emit(ctx, cancelledTask.ID, "task.cancelled", "cancelled", "cancelled", 1, "任务已取消", map[string]any{
|
||||||
|
"taskId": cancelledTask.ID,
|
||||||
|
"reason": "manual_cancel",
|
||||||
|
}, cancelledTask.RunMode == "simulation"); err != nil {
|
||||||
|
return TaskCancelResult{}, err
|
||||||
|
}
|
||||||
|
return TaskCancelResult{
|
||||||
|
TaskID: cancelledTask.ID,
|
||||||
|
Cancelled: true,
|
||||||
|
Cancellable: true,
|
||||||
|
Submitted: false,
|
||||||
|
Message: "任务已取消",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func taskCancelUnavailable(task store.GatewayTask, message string) TaskCancelResult {
|
||||||
|
return TaskCancelResult{
|
||||||
|
TaskID: task.ID,
|
||||||
|
Cancelled: false,
|
||||||
|
Cancellable: false,
|
||||||
|
Submitted: strings.TrimSpace(task.RemoteTaskID) != "",
|
||||||
|
Message: message,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func taskCancelTerminalStatus(status string) bool {
|
||||||
|
switch strings.TrimSpace(status) {
|
||||||
|
case "succeeded", "failed", "cancelled":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func riverJobStateCancellable(state rivertype.JobState) bool {
|
||||||
|
switch state {
|
||||||
|
case rivertype.JobStateAvailable, rivertype.JobStateScheduled, rivertype.JobStateRetryable, rivertype.JobStatePending:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func taskAccessibleToUser(task store.GatewayTask, user *auth.User) bool {
|
||||||
|
if user == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if apiKeyID := strings.TrimSpace(user.APIKeyID); apiKeyID != "" {
|
||||||
|
if strings.TrimSpace(task.APIKeyID) != apiKeyID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if gatewayUserID := gatewayUserIDForAuth(user); gatewayUserID != "" {
|
||||||
|
return strings.TrimSpace(task.GatewayUserID) == gatewayUserID
|
||||||
|
}
|
||||||
|
if userID := strings.TrimSpace(user.ID); userID != "" {
|
||||||
|
return strings.TrimSpace(task.UserID) == userID
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if gatewayUserID := gatewayUserIDForAuth(user); gatewayUserID != "" {
|
||||||
|
return strings.TrimSpace(task.GatewayUserID) == gatewayUserID
|
||||||
|
}
|
||||||
|
userID := strings.TrimSpace(user.ID)
|
||||||
|
return userID != "" && strings.TrimSpace(task.UserID) == userID
|
||||||
|
}
|
||||||
|
|
||||||
|
func gatewayUserIDForAuth(user *auth.User) string {
|
||||||
|
if user == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if user.GatewayUserID != "" {
|
||||||
|
return strings.TrimSpace(user.GatewayUserID)
|
||||||
|
}
|
||||||
|
if user.Source == "" || user.Source == "gateway" {
|
||||||
|
return strings.TrimSpace(user.ID)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
78
apps/api/internal/runner/task_cancel_test.go
Normal file
78
apps/api/internal/runner/task_cancel_test.go
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
package runner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
|
||||||
|
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
|
||||||
|
"github.com/riverqueue/river/rivertype"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRiverJobStateCancellableOnlyAllowsLocalQueueStates(t *testing.T) {
|
||||||
|
cancellableStates := []rivertype.JobState{
|
||||||
|
rivertype.JobStateAvailable,
|
||||||
|
rivertype.JobStateScheduled,
|
||||||
|
rivertype.JobStateRetryable,
|
||||||
|
rivertype.JobStatePending,
|
||||||
|
}
|
||||||
|
for _, state := range cancellableStates {
|
||||||
|
if !riverJobStateCancellable(state) {
|
||||||
|
t.Fatalf("expected %s to be cancellable", state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nonCancellableStates := []rivertype.JobState{
|
||||||
|
rivertype.JobStateRunning,
|
||||||
|
rivertype.JobStateCancelled,
|
||||||
|
rivertype.JobStateCompleted,
|
||||||
|
rivertype.JobStateDiscarded,
|
||||||
|
}
|
||||||
|
for _, state := range nonCancellableStates {
|
||||||
|
if riverJobStateCancellable(state) {
|
||||||
|
t.Fatalf("expected %s to be non-cancellable", state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskAccessibleToUserHonorsAPIKeyAndGatewayUser(t *testing.T) {
|
||||||
|
task := store.GatewayTask{
|
||||||
|
UserID: "server-user",
|
||||||
|
GatewayUserID: "gateway-user",
|
||||||
|
APIKeyID: "api-key",
|
||||||
|
}
|
||||||
|
|
||||||
|
if !taskAccessibleToUser(task, &auth.User{GatewayUserID: "gateway-user"}) {
|
||||||
|
t.Fatal("gateway user should access own task")
|
||||||
|
}
|
||||||
|
if !taskAccessibleToUser(task, &auth.User{APIKeyID: "api-key", GatewayUserID: "gateway-user"}) {
|
||||||
|
t.Fatal("api key owner should access own task")
|
||||||
|
}
|
||||||
|
if taskAccessibleToUser(task, &auth.User{APIKeyID: "api-key", GatewayUserID: "other-user"}) {
|
||||||
|
t.Fatal("api key from another gateway user should not access task")
|
||||||
|
}
|
||||||
|
if taskAccessibleToUser(task, &auth.User{GatewayUserID: "other-user"}) {
|
||||||
|
t.Fatal("another gateway user should not access task")
|
||||||
|
}
|
||||||
|
|
||||||
|
serverMainTask := store.GatewayTask{
|
||||||
|
UserID: "server-user",
|
||||||
|
APIKeyID: "server-api-key",
|
||||||
|
}
|
||||||
|
if !taskAccessibleToUser(serverMainTask, &auth.User{ID: "server-user", Source: "server-main", APIKeyID: "server-api-key"}) {
|
||||||
|
t.Fatal("server-main api key should fall back to user id when gateway user id is absent")
|
||||||
|
}
|
||||||
|
if taskAccessibleToUser(serverMainTask, &auth.User{ID: "other-user", Source: "server-main", APIKeyID: "server-api-key"}) {
|
||||||
|
t.Fatal("server-main api key from another user should not access task")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskCancelUnavailableReportsSubmittedFromRemoteTaskID(t *testing.T) {
|
||||||
|
result := taskCancelUnavailable(store.GatewayTask{ID: "task-1", RemoteTaskID: "remote-1"}, "不可取消")
|
||||||
|
|
||||||
|
if result.Cancelled || result.Cancellable {
|
||||||
|
t.Fatalf("unavailable result should not mark cancellation as done: %+v", result)
|
||||||
|
}
|
||||||
|
if !result.Submitted {
|
||||||
|
t.Fatalf("remote task id should report submitted: %+v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -256,6 +256,38 @@ WHERE id = $1::uuid`,
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) CancelQueuedTask(ctx context.Context, taskID string, message string) (GatewayTask, bool, error) {
|
||||||
|
message = strings.TrimSpace(message)
|
||||||
|
if message == "" {
|
||||||
|
message = "任务已取消"
|
||||||
|
}
|
||||||
|
tag, err := s.pool.Exec(ctx, `
|
||||||
|
UPDATE gateway_tasks
|
||||||
|
SET status = 'cancelled',
|
||||||
|
error = NULLIF($2::text, ''),
|
||||||
|
error_code = 'task_cancelled',
|
||||||
|
error_message = NULLIF($2::text, ''),
|
||||||
|
locked_by = NULL,
|
||||||
|
locked_at = NULL,
|
||||||
|
heartbeat_at = NULL,
|
||||||
|
finished_at = now(),
|
||||||
|
updated_at = now()
|
||||||
|
WHERE id = $1::uuid
|
||||||
|
AND status = 'queued'
|
||||||
|
AND COALESCE(remote_task_id, '') = ''`, taskID, message)
|
||||||
|
if err != nil {
|
||||||
|
return GatewayTask{}, false, err
|
||||||
|
}
|
||||||
|
if tag.RowsAffected() == 0 {
|
||||||
|
return GatewayTask{}, false, nil
|
||||||
|
}
|
||||||
|
task, err := s.GetTask(ctx, taskID)
|
||||||
|
if err != nil {
|
||||||
|
return GatewayTask{}, true, err
|
||||||
|
}
|
||||||
|
return task, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) ListRecoverableAsyncTasks(ctx context.Context, limit int) ([]AsyncTaskQueueItem, error) {
|
func (s *Store) ListRecoverableAsyncTasks(ctx context.Context, limit int) ([]AsyncTaskQueueItem, error) {
|
||||||
if limit <= 0 {
|
if limit <= 0 {
|
||||||
limit = 500
|
limit = 500
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user