fix: make job cancel atomic and best-effort

Addresses two cancel races/edges raised in review.

Targeted, atomic interrupt. cancel_job's interrupt callback now takes the
prompt id and returns whether it fired; the single-cancel route backs it
with the new PromptQueue.interrupt_if_running, which checks the running set
and signals the interrupt under the queue mutex. This closes the TOCTOU
where a pending job that starts executing between the snapshot and dequeue
(or a running job that finishes between the snapshot and interrupt) could be
missed or, worse, cause an unrelated prompt to be interrupted. The per-prompt
interrupt-flag reset in execute_async keeps a finished job from leaking the
interrupt onto its successor.

Best-effort batch cancel. POST /api/jobs/cancel no longer fails the whole
batch with 404 when one id is unknown/finished; such ids are treated as
no-ops, so "cancel all" still cancels the in-progress jobs even if some
finished between the client's snapshot and the request. Malformed ids are
still rejected with 400.
This commit is contained in:
Matt Miller 2026-06-19 16:18:39 -07:00
parent dee29e783e
commit 22b25fcd26
4 changed files with 150 additions and 82 deletions

View File

@ -445,7 +445,7 @@ def cancel_job(
running: list, running: list,
queued: list, queued: list,
history: dict, history: dict,
interrupt: Callable[[], None], interrupt: Callable[[str], bool],
dequeue: Callable[[str], bool], dequeue: Callable[[str], bool],
) -> str: ) -> str:
"""Cancel a single job by id, regardless of state. """Cancel a single job by id, regardless of state.
@ -457,21 +457,32 @@ def cancel_job(
- an unknown id is a no-op (callers that need fail-fast behaviour should - an unknown id is a no-op (callers that need fail-fast behaviour should
validate ids up front with ``classify_job_for_cancel``) validate ids up front with ``classify_job_for_cancel``)
Returns the classification that was acted on (one of the CANCEL_* values), Both ``interrupt`` and ``dequeue`` take the prompt id and return whether
so callers can log or report what happened. they acted on a job that was *actually* in that state, so the value returned
here reflects what truly happened rather than the (possibly stale)
classification. This matters around the narrow TOCTOU windows where a job
changes state between the caller's snapshot and the action:
For pending jobs the returned value reflects the *actual* dequeue result: - a job classified RUNNING may have finished before ``interrupt`` fires:
if the job left the queue between the caller's snapshot and the dequeue ``interrupt`` returns False and this returns CANCEL_UNKNOWN (no-op).
call (a narrow TOCTOU window), the dequeue returns False and this function - a job classified PENDING may have started executing before ``dequeue``
returns CANCEL_UNKNOWN rather than CANCEL_PENDING, so callers that map the fires: ``dequeue`` returns False, ``interrupt`` then catches the now-
return to a ``cancelled`` boolean never report a cancel that did not happen. running job and this returns CANCEL_RUNNING. If it had simply finished
instead, both return False and this returns CANCEL_UNKNOWN.
``interrupt`` must be atomic interrupt the job only if it is still the one
running so a cancel can never land on an unrelated prompt that started in
the meantime (see ``execution.PromptQueue.interrupt_if_running``).
""" """
classification = classify_job_for_cancel(prompt_id, running, queued, history) classification = classify_job_for_cancel(prompt_id, running, queued, history)
if classification == CANCEL_RUNNING: if classification == CANCEL_RUNNING:
interrupt() return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN
elif classification == CANCEL_PENDING: if classification == CANCEL_PENDING:
if not dequeue(prompt_id): if dequeue(prompt_id):
# Job was no longer in the queue by the time we tried to remove it. return CANCEL_PENDING
return CANCEL_UNKNOWN # Left the pending queue between classification and dequeue: if it
# started executing, interrupt the now-running job; otherwise it has
# already finished and the cancel is a genuine no-op.
return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN
# CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops. # CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops.
return classification return classification

View File

@ -1308,6 +1308,25 @@ class PromptQueue:
queued = copy.copy(self.queue) queued = copy.copy(self.queue)
return (running, queued) return (running, queued)
def interrupt_if_running(self, prompt_id):
"""Interrupt the running prompt with this id, atomically.
Checks the live running set and signals the interrupt under the queue
mutex, so the worker cannot move the job to done (and start the next
prompt) in between. Returns True if a matching job was running and an
interrupt was signalled, False otherwise. The atomicity is what keeps a
cancel from landing on an unrelated prompt that started after a separate
is-running check: the global interrupt flag is reset at the start of
every prompt (execute_async), so a job that finishes before consuming
the flag cannot leak the interrupt onto its successor.
"""
with self.mutex:
for item in self.currently_running.values():
if item[1] == prompt_id:
nodes.interrupt_processing()
return True
return False
def get_tasks_remaining(self): def get_tasks_remaining(self):
with self.mutex: with self.mutex:
return len(self.queue) + len(self.currently_running) return len(self.queue) + len(self.currently_running)

View File

@ -14,10 +14,8 @@ from comfy_execution.jobs import (
get_all_jobs, get_all_jobs,
validate_job_id, validate_job_id,
cancel_job, cancel_job,
classify_job_for_cancel,
CANCEL_PENDING, CANCEL_PENDING,
CANCEL_RUNNING, CANCEL_RUNNING,
CANCEL_UNKNOWN,
) )
import uuid import uuid
import urllib import urllib
@ -922,9 +920,12 @@ class PromptServer():
running, queued = self.prompt_queue.get_current_queue() running, queued = self.prompt_queue.get_current_queue()
history = self.prompt_queue.get_history() history = self.prompt_queue.get_history()
def interrupt(): def interrupt(prompt_id):
logging.info(f"Cancelling running prompt {job_id}") logging.info(f"Cancelling running prompt {prompt_id}")
nodes.interrupt_processing() # Atomic: only interrupts if the job is still the one running,
# so a cancel can't land on a prompt that started in the gap
# since the snapshot above. Returns whether it actually fired.
return self.prompt_queue.interrupt_if_running(prompt_id)
def dequeue(prompt_id): def dequeue(prompt_id):
logging.info(f"Cancelling pending prompt {prompt_id}") logging.info(f"Cancelling pending prompt {prompt_id}")
@ -957,9 +958,13 @@ class PromptServer():
Body: {"job_ids": ["<uuid>", ...]} Body: {"job_ids": ["<uuid>", ...]}
Fail-fast: if any provided id is unknown (not running, pending, or Best-effort and idempotent: every well-formed id is cancelled if it
in history) the request returns 404 and no job is cancelled, so the is running or pending; ids that are already finished or unknown are
call has no partial side effects. no-ops, not errors. A batch of all no-ops still returns 200 with
{"cancelled": false}. This matches the single-cancel endpoint and
means "cancel all" still cancels the in-progress jobs even if some
finished between the client's snapshot and the request. Malformed
ids are still rejected up front with 400 (see below).
""" """
try: try:
json_data = await request.json() json_data = await request.json()
@ -993,22 +998,9 @@ class PromptServer():
status=400, status=400,
) )
# Validate every id exists before cancelling anything. A snapshot of # Best-effort: cancel each id that is still running/pending; an id
# the queue + history is taken once so the membership check is # that has finished or never existed is a no-op rather than a reason
# consistent for the whole batch. # to fail the whole batch.
running, queued = self.prompt_queue.get_current_queue()
history = self.prompt_queue.get_history()
unknown_ids = [
jid for jid in job_ids
if classify_job_for_cancel(jid, running, queued, history) == CANCEL_UNKNOWN
]
if unknown_ids:
return web.json_response(
{"error": "Job(s) not found", "unknown_ids": unknown_ids},
status=404
)
cancelled = False cancelled = False
for jid in job_ids: for jid in job_ids:
if _cancel_job_by_id(jid): if _cancel_job_by_id(jid):

View File

@ -8,7 +8,8 @@ Covers both layers:
* the HTTP contract of ``POST /api/jobs/{job_id}/cancel`` and * the HTTP contract of ``POST /api/jobs/{job_id}/cancel`` and
``POST /api/jobs/cancel`` (status codes, single-cancel idempotency, and ``POST /api/jobs/cancel`` (status codes, single-cancel idempotency, and
batch fail-fast on an unknown id with no partial side effects). best-effort batch cancellation that treats unknown/finished ids as no-ops
while still rejecting malformed ids with 400).
The HTTP layer is exercised against a small aiohttp app whose handlers are a The HTTP layer is exercised against a small aiohttp app whose handlers are a
faithful copy of the wiring in ``server.py`` driven by a fake queue that faithful copy of the wiring in ``server.py`` driven by a fake queue that
@ -77,8 +78,13 @@ class FakePromptQueue:
return True return True
return False return False
def interrupt_processing(self): def interrupt_if_running(self, prompt_id):
self.interrupt_count += 1 # Mirrors execution.PromptQueue.interrupt_if_running: only signals an
# interrupt when the id is actually in the running set.
if any(item[1] == prompt_id for item in self._running):
self.interrupt_count += 1
return True
return False
def build_app(queue): def build_app(queue):
@ -91,8 +97,8 @@ def build_app(queue):
running, pending = queue.get_current_queue() running, pending = queue.get_current_queue()
history = queue.get_history() history = queue.get_history()
def interrupt(): def interrupt(prompt_id):
queue.interrupt_processing() return queue.interrupt_if_running(prompt_id)
def dequeue(prompt_id): def dequeue(prompt_id):
return queue.delete_queue_item(lambda a: a[1] == prompt_id) return queue.delete_queue_item(lambda a: a[1] == prompt_id)
@ -133,18 +139,6 @@ def build_app(queue):
status=400, status=400,
) )
running, pending = queue.get_current_queue()
history = queue.get_history()
unknown_ids = [
jid
for jid in job_ids
if classify_job_for_cancel(jid, running, pending, history) == CANCEL_UNKNOWN
]
if unknown_ids:
return web.json_response(
{"error": "Job(s) not found", "unknown_ids": unknown_ids}, status=404
)
cancelled = False cancelled = False
for jid in job_ids: for jid in job_ids:
if _cancel_job_by_id(jid): if _cancel_job_by_id(jid):
@ -180,26 +174,27 @@ class TestClassifyJobForCancel:
class TestCancelJobHelper: class TestCancelJobHelper:
"""``interrupt`` and ``dequeue`` both take the id and return whether they
actually acted, so cancel_job's return reflects the real outcome."""
def test_running_is_interrupted_not_dequeued(self): def test_running_is_interrupted_not_dequeued(self):
running = [make_queue_item("a")]
interrupts = [] interrupts = []
dequeues = [] dequeues = []
result = cancel_job( result = cancel_job(
"a", running, [], {}, "a", [make_queue_item("a")], [], {},
interrupt=lambda: interrupts.append(True), interrupt=lambda pid: interrupts.append(pid) or True,
dequeue=lambda pid: dequeues.append(pid) or True, dequeue=lambda pid: dequeues.append(pid) or True,
) )
assert result == CANCEL_RUNNING assert result == CANCEL_RUNNING
assert interrupts == [True] assert interrupts == ["a"]
assert dequeues == [] assert dequeues == []
def test_pending_is_dequeued_not_interrupted(self): def test_pending_is_dequeued_not_interrupted(self):
pending = [make_queue_item("b")]
interrupts = [] interrupts = []
dequeues = [] dequeues = []
result = cancel_job( result = cancel_job(
"b", [], pending, {}, "b", [], [make_queue_item("b")], {},
interrupt=lambda: interrupts.append(True), interrupt=lambda pid: interrupts.append(pid) or True,
dequeue=lambda pid: dequeues.append(pid) or True, dequeue=lambda pid: dequeues.append(pid) or True,
) )
assert result == CANCEL_PENDING assert result == CANCEL_PENDING
@ -212,7 +207,7 @@ class TestCancelJobHelper:
dequeues = [] dequeues = []
result = cancel_job( result = cancel_job(
"c", [], [], history, "c", [], [], history,
interrupt=lambda: interrupts.append(True), interrupt=lambda pid: interrupts.append(pid) or True,
dequeue=lambda pid: dequeues.append(pid) or True, dequeue=lambda pid: dequeues.append(pid) or True,
) )
assert result == CANCEL_TERMINAL assert result == CANCEL_TERMINAL
@ -224,29 +219,57 @@ class TestCancelJobHelper:
dequeues = [] dequeues = []
result = cancel_job( result = cancel_job(
"z", [], [], {}, "z", [], [], {},
interrupt=lambda: interrupts.append(True), interrupt=lambda pid: interrupts.append(pid) or True,
dequeue=lambda pid: dequeues.append(pid) or True, dequeue=lambda pid: dequeues.append(pid) or True,
) )
assert result == CANCEL_UNKNOWN assert result == CANCEL_UNKNOWN
assert interrupts == [] assert interrupts == []
assert dequeues == [] assert dequeues == []
def test_pending_dequeue_miss_returns_unknown(self): def test_running_but_finished_before_interrupt_returns_unknown(self):
"""If dequeue returns False (job left queue between snapshot and delete), """Classified RUNNING from a stale snapshot, but the job finished before
cancel_job must return CANCEL_UNKNOWN so callers never report cancelled=True the atomic interrupt fired (interrupt returns False). cancel_job reports
for a cancel that did not actually happen (TOCTOU guard).""" UNKNOWN rather than claiming a cancel that did not happen and the
pending = [make_queue_item("b")] atomic interrupt guarantees no unrelated job was hit."""
interrupts = []
result = cancel_job(
"a", [make_queue_item("a")], [], {},
interrupt=lambda pid: interrupts.append(pid) or False,
dequeue=lambda pid: True,
)
assert result == CANCEL_UNKNOWN
assert interrupts == ["a"] # interrupt was attempted atomically
def test_pending_started_running_is_interrupted(self):
"""Pending->running race: the job leaves the queue (dequeue False)
because it started executing. The atomic interrupt catches the now-
running job, so cancel_job interrupts it and reports CANCEL_RUNNING."""
interrupts = [] interrupts = []
dequeues = [] dequeues = []
# dequeue always returns False — simulates job already gone from queue
result = cancel_job( result = cancel_job(
"b", [], pending, {}, "b", [], [make_queue_item("b")], {},
interrupt=lambda: interrupts.append(True), interrupt=lambda pid: interrupts.append(pid) or True,
dequeue=lambda pid: (dequeues.append(pid), False)[1],
)
assert result == CANCEL_RUNNING
assert dequeues == ["b"] # dequeue attempted first
assert interrupts == ["b"] # then the now-running job was interrupted
def test_pending_dequeue_miss_not_running_returns_unknown(self):
"""Dequeue miss where the job is not running anymore (it finished): the
atomic interrupt finds nothing to interrupt and returns False, so
cancel_job is a no-op reporting UNKNOWN never reporting a cancel that
did not happen, and never interrupting a bystander."""
interrupts = []
dequeues = []
result = cancel_job(
"b", [], [make_queue_item("b")], {},
interrupt=lambda pid: interrupts.append(pid) or False,
dequeue=lambda pid: (dequeues.append(pid), False)[1], dequeue=lambda pid: (dequeues.append(pid), False)[1],
) )
assert result == CANCEL_UNKNOWN assert result == CANCEL_UNKNOWN
assert dequeues == ["b"] # dequeue was attempted assert dequeues == ["b"]
assert interrupts == [] # interrupt was not called assert interrupts == ["b"] # interrupt attempted, found nothing running
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -304,6 +327,30 @@ class TestSingleCancelEndpoint:
assert (await resp.json()) == {"cancelled": False} assert (await resp.json()) == {"cancelled": False}
assert queue.interrupt_count == 0 assert queue.interrupt_count == 0
@pytest.mark.asyncio
async def test_cancel_pending_that_started_running_interrupts(self, aiohttp_client):
"""Pending->running race end to end: the job is pending at snapshot time
but starts executing by the time we dequeue (delete misses). The live
re-check sees it running and interrupts it, so the cancel is not dropped
and the caller still gets cancelled=True."""
class RacingQueue(FakePromptQueue):
def delete_queue_item(self, function):
# The worker picked the job up just before we removed it: it
# leaves the pending queue (delete misses) and is now running.
self._running = list(self._pending)
self._pending = []
return False
queue = RacingQueue(pending=[make_queue_item("b")])
client = await aiohttp_client(build_app(queue))
resp = await client.post("/api/jobs/b/cancel")
assert resp.status == 200
assert (await resp.json()) == {"cancelled": True}
assert queue.interrupt_count == 1
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# HTTP contract tests: POST /api/jobs/cancel (batch) # HTTP contract tests: POST /api/jobs/cancel (batch)
@ -327,9 +374,10 @@ class TestBatchCancelEndpoint:
assert queue.get_current_queue()[1] == [] # pending job dequeued assert queue.get_current_queue()[1] == [] # pending job dequeued
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_batch_fail_fast_404_on_unknown_id_no_side_effects( async def test_batch_best_effort_skips_unknown_id(self, aiohttp_client):
self, aiohttp_client """An unknown id in the batch is a no-op, not a reason to abort: the
): running and pending jobs are still cancelled (200, cancelled=true). This
is the "cancel all as a job finishes" case from review."""
queue = FakePromptQueue( queue = FakePromptQueue(
running=[make_queue_item(_UUID_A)], running=[make_queue_item(_UUID_A)],
pending=[make_queue_item(_UUID_B, number=1)], pending=[make_queue_item(_UUID_B, number=1)],
@ -340,12 +388,10 @@ class TestBatchCancelEndpoint:
"/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_MISSING, _UUID_B]} "/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_MISSING, _UUID_B]}
) )
assert resp.status == 404 assert resp.status == 200
body = await resp.json() assert (await resp.json()) == {"cancelled": True}
assert body["unknown_ids"] == [_UUID_MISSING] assert queue.interrupt_count == 1 # running job interrupted
# Fail-fast: nothing was cancelled — no partial side effects. assert queue.get_current_queue()[1] == [] # pending job dequeued
assert queue.interrupt_count == 0
assert len(queue.get_current_queue()[1]) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client): async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client):