Job cancellation: atomic interrupt and best-effort batch cancel.

What this patch does, file by file:

  * comfy_execution/jobs.py
      cancel_job()'s ``interrupt`` callback now takes the prompt id and
      returns a bool, like ``dequeue``.  The value cancel_job() returns
      reflects what the callbacks report, and a dequeue miss falls back to
      ``interrupt`` so a job that started executing in the meantime is still
      cancelled.
  * execution.py
      Adds PromptQueue.interrupt_if_running(), which checks the running set
      and signals the interrupt while holding the queue mutex.
  * server.py
      The single-cancel helper uses interrupt_if_running(); the batch
      endpoint no longer pre-validates ids and no longer returns 404 for
      unknown ids.
  * tests-unit/jobs_cancel_test/jobs_cancel_test.py
      Tests and the fake queue are updated to the new callback contract.

Review note: this file was restored from a copy whose line breaks had been
collapsed.  Hunk line counts were re-checked against every hunk header;
leading indentation was inferred from the Python structure, so if a hunk is
rejected on whitespace, retry with `git apply --ignore-whitespace`.

diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py
index 718006c9c..fa3ab0faf 100644
--- a/comfy_execution/jobs.py
+++ b/comfy_execution/jobs.py
@@ -445,7 +445,7 @@ def cancel_job(
     running: list,
     queued: list,
     history: dict,
-    interrupt: Callable[[], None],
+    interrupt: Callable[[str], bool],
     dequeue: Callable[[str], bool],
 ) -> str:
     """Cancel a single job by id, regardless of state.
@@ -457,21 +457,32 @@ def cancel_job(
     - an unknown id is a no-op (callers that need fail-fast behaviour should
       validate ids up front with ``classify_job_for_cancel``)
 
-    Returns the classification that was acted on (one of the CANCEL_* values),
-    so callers can log or report what happened.
+    Both ``interrupt`` and ``dequeue`` take the prompt id and return whether
+    they acted on a job that was *actually* in that state, so the value returned
+    here reflects what truly happened rather than the (possibly stale)
+    classification. This matters around the narrow TOCTOU windows where a job
+    changes state between the caller's snapshot and the action:
 
-    For pending jobs the returned value reflects the *actual* dequeue result:
-    if the job left the queue between the caller's snapshot and the dequeue
-    call (a narrow TOCTOU window), the dequeue returns False and this function
-    returns CANCEL_UNKNOWN rather than CANCEL_PENDING, so callers that map the
-    return to a ``cancelled`` boolean never report a cancel that did not happen.
+    - a job classified RUNNING may have finished before ``interrupt`` fires:
+      ``interrupt`` returns False and this returns CANCEL_UNKNOWN (no-op).
+    - a job classified PENDING may have started executing before ``dequeue``
+      fires: ``dequeue`` returns False, ``interrupt`` then catches the now-
+      running job and this returns CANCEL_RUNNING. If it had simply finished
+      instead, both return False and this returns CANCEL_UNKNOWN.
+
+    ``interrupt`` must be atomic — interrupt the job only if it is still the one
+    running — so a cancel can never land on an unrelated prompt that started in
+    the meantime (see ``execution.PromptQueue.interrupt_if_running``).
     """
     classification = classify_job_for_cancel(prompt_id, running, queued, history)
     if classification == CANCEL_RUNNING:
-        interrupt()
-    elif classification == CANCEL_PENDING:
-        if not dequeue(prompt_id):
-            # Job was no longer in the queue by the time we tried to remove it.
-            return CANCEL_UNKNOWN
+        return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN
+    if classification == CANCEL_PENDING:
+        if dequeue(prompt_id):
+            return CANCEL_PENDING
+        # Left the pending queue between classification and dequeue: if it
+        # started executing, interrupt the now-running job; otherwise it has
+        # already finished and the cancel is a genuine no-op.
+        return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN
     # CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops.
     return classification
diff --git a/execution.py b/execution.py
index 9e16e451d..c45317593 100644
--- a/execution.py
+++ b/execution.py
@@ -1308,6 +1308,25 @@ class PromptQueue:
             queued = copy.copy(self.queue)
             return (running, queued)
 
+    def interrupt_if_running(self, prompt_id):
+        """Interrupt the running prompt with this id, atomically.
+
+        Checks the live running set and signals the interrupt under the queue
+        mutex, so the worker cannot move the job to done (and start the next
+        prompt) in between. Returns True if a matching job was running and an
+        interrupt was signalled, False otherwise. The atomicity is what keeps a
+        cancel from landing on an unrelated prompt that started after a separate
+        is-running check: the global interrupt flag is reset at the start of
+        every prompt (execute_async), so a job that finishes before consuming
+        the flag cannot leak the interrupt onto its successor.
+        """
+        with self.mutex:
+            for item in self.currently_running.values():
+                if item[1] == prompt_id:
+                    nodes.interrupt_processing()
+                    return True
+            return False
+
     def get_tasks_remaining(self):
         with self.mutex:
             return len(self.queue) + len(self.currently_running)
diff --git a/server.py b/server.py
index 9591c9df7..361850f38 100644
--- a/server.py
+++ b/server.py
@@ -14,10 +14,8 @@ from comfy_execution.jobs import (
     get_all_jobs,
     validate_job_id,
     cancel_job,
-    classify_job_for_cancel,
     CANCEL_PENDING,
     CANCEL_RUNNING,
-    CANCEL_UNKNOWN,
 )
 import uuid
 import urllib
@@ -922,9 +920,12 @@ class PromptServer():
             running, queued = self.prompt_queue.get_current_queue()
             history = self.prompt_queue.get_history()
 
-            def interrupt():
-                logging.info(f"Cancelling running prompt {job_id}")
-                nodes.interrupt_processing()
+            def interrupt(prompt_id):
+                logging.info(f"Cancelling running prompt {prompt_id}")
+                # Atomic: only interrupts if the job is still the one running,
+                # so a cancel can't land on a prompt that started in the gap
+                # since the snapshot above. Returns whether it actually fired.
+                return self.prompt_queue.interrupt_if_running(prompt_id)
 
             def dequeue(prompt_id):
                 logging.info(f"Cancelling pending prompt {prompt_id}")
@@ -957,9 +958,13 @@ class PromptServer():
 
             Body: {"job_ids": ["", ...]}
 
-            Fail-fast: if any provided id is unknown (not running, pending, or
-            in history) the request returns 404 and no job is cancelled, so the
-            call has no partial side effects.
+            Best-effort and idempotent: every well-formed id is cancelled if it
+            is running or pending; ids that are already finished or unknown are
+            no-ops, not errors. A batch of all no-ops still returns 200 with
+            {"cancelled": false}. This matches the single-cancel endpoint and
+            means "cancel all" still cancels the in-progress jobs even if some
+            finished between the client's snapshot and the request. Malformed
+            ids are still rejected up front with 400 (see below).
             """
             try:
                 json_data = await request.json()
@@ -993,22 +998,9 @@ class PromptServer():
                     status=400,
                 )
 
-            # Validate every id exists before cancelling anything. A snapshot of
-            # the queue + history is taken once so the membership check is
-            # consistent for the whole batch.
-            running, queued = self.prompt_queue.get_current_queue()
-            history = self.prompt_queue.get_history()
-
-            unknown_ids = [
-                jid for jid in job_ids
-                if classify_job_for_cancel(jid, running, queued, history) == CANCEL_UNKNOWN
-            ]
-            if unknown_ids:
-                return web.json_response(
-                    {"error": "Job(s) not found", "unknown_ids": unknown_ids},
-                    status=404
-                )
-
+            # Best-effort: cancel each id that is still running/pending; an id
+            # that has finished or never existed is a no-op rather than a reason
+            # to fail the whole batch.
             cancelled = False
             for jid in job_ids:
                 if _cancel_job_by_id(jid):
diff --git a/tests-unit/jobs_cancel_test/jobs_cancel_test.py b/tests-unit/jobs_cancel_test/jobs_cancel_test.py
index 372e8f64c..f1d591b0d 100644
--- a/tests-unit/jobs_cancel_test/jobs_cancel_test.py
+++ b/tests-unit/jobs_cancel_test/jobs_cancel_test.py
@@ -8,7 +8,8 @@ Covers both layers:
 
 * the HTTP contract of ``POST /api/jobs/{job_id}/cancel`` and
   ``POST /api/jobs/cancel`` (status codes, single-cancel idempotency, and
-  batch fail-fast on an unknown id with no partial side effects).
+  best-effort batch cancellation that treats unknown/finished ids as no-ops
+  while still rejecting malformed ids with 400).
 
 The HTTP layer is exercised against a small aiohttp app whose handlers are a
 faithful copy of the wiring in ``server.py`` driven by a fake queue that
@@ -77,8 +78,13 @@ class FakePromptQueue:
                 return True
         return False
 
-    def interrupt_processing(self):
-        self.interrupt_count += 1
+    def interrupt_if_running(self, prompt_id):
+        # Mirrors execution.PromptQueue.interrupt_if_running: only signals an
+        # interrupt when the id is actually in the running set.
+        if any(item[1] == prompt_id for item in self._running):
+            self.interrupt_count += 1
+            return True
+        return False
 
 
 def build_app(queue):
@@ -91,8 +97,8 @@ def build_app(queue):
         running, pending = queue.get_current_queue()
         history = queue.get_history()
 
-        def interrupt():
-            queue.interrupt_processing()
+        def interrupt(prompt_id):
+            return queue.interrupt_if_running(prompt_id)
 
         def dequeue(prompt_id):
             return queue.delete_queue_item(lambda a: a[1] == prompt_id)
@@ -133,18 +139,6 @@ def build_app(queue):
                 status=400,
             )
 
-        running, pending = queue.get_current_queue()
-        history = queue.get_history()
-        unknown_ids = [
-            jid
-            for jid in job_ids
-            if classify_job_for_cancel(jid, running, pending, history) == CANCEL_UNKNOWN
-        ]
-        if unknown_ids:
-            return web.json_response(
-                {"error": "Job(s) not found", "unknown_ids": unknown_ids}, status=404
-            )
-
         cancelled = False
         for jid in job_ids:
             if _cancel_job_by_id(jid):
@@ -180,26 +174,27 @@ class TestClassifyJobForCancel:
 
 
 class TestCancelJobHelper:
+    """``interrupt`` and ``dequeue`` both take the id and return whether they
+    actually acted, so cancel_job's return reflects the real outcome."""
+
     def test_running_is_interrupted_not_dequeued(self):
-        running = [make_queue_item("a")]
         interrupts = []
         dequeues = []
         result = cancel_job(
-            "a", running, [], {},
-            interrupt=lambda: interrupts.append(True),
+            "a", [make_queue_item("a")], [], {},
+            interrupt=lambda pid: interrupts.append(pid) or True,
             dequeue=lambda pid: dequeues.append(pid) or True,
         )
         assert result == CANCEL_RUNNING
-        assert interrupts == [True]
+        assert interrupts == ["a"]
         assert dequeues == []
 
     def test_pending_is_dequeued_not_interrupted(self):
-        pending = [make_queue_item("b")]
         interrupts = []
         dequeues = []
         result = cancel_job(
-            "b", [], pending, {},
-            interrupt=lambda: interrupts.append(True),
+            "b", [], [make_queue_item("b")], {},
+            interrupt=lambda pid: interrupts.append(pid) or True,
             dequeue=lambda pid: dequeues.append(pid) or True,
         )
         assert result == CANCEL_PENDING
@@ -212,7 +207,7 @@ class TestCancelJobHelper:
         dequeues = []
         result = cancel_job(
             "c", [], [], history,
-            interrupt=lambda: interrupts.append(True),
+            interrupt=lambda pid: interrupts.append(pid) or True,
             dequeue=lambda pid: dequeues.append(pid) or True,
         )
         assert result == CANCEL_TERMINAL
@@ -224,29 +219,57 @@ class TestCancelJobHelper:
         dequeues = []
         result = cancel_job(
             "z", [], [], {},
-            interrupt=lambda: interrupts.append(True),
+            interrupt=lambda pid: interrupts.append(pid) or True,
             dequeue=lambda pid: dequeues.append(pid) or True,
         )
         assert result == CANCEL_UNKNOWN
         assert interrupts == []
         assert dequeues == []
 
-    def test_pending_dequeue_miss_returns_unknown(self):
-        """If dequeue returns False (job left queue between snapshot and delete),
-        cancel_job must return CANCEL_UNKNOWN so callers never report cancelled=True
-        for a cancel that did not actually happen (TOCTOU guard)."""
-        pending = [make_queue_item("b")]
+    def test_running_but_finished_before_interrupt_returns_unknown(self):
+        """Classified RUNNING from a stale snapshot, but the job finished before
+        the atomic interrupt fired (interrupt returns False). cancel_job reports
+        UNKNOWN rather than claiming a cancel that did not happen — and the
+        atomic interrupt guarantees no unrelated job was hit."""
+        interrupts = []
+        result = cancel_job(
+            "a", [make_queue_item("a")], [], {},
+            interrupt=lambda pid: interrupts.append(pid) or False,
+            dequeue=lambda pid: True,
+        )
+        assert result == CANCEL_UNKNOWN
+        assert interrupts == ["a"]  # interrupt was attempted atomically
+
+    def test_pending_started_running_is_interrupted(self):
+        """Pending->running race: the job leaves the queue (dequeue False)
+        because it started executing. The atomic interrupt catches the now-
+        running job, so cancel_job interrupts it and reports CANCEL_RUNNING."""
         interrupts = []
         dequeues = []
-        # dequeue always returns False — simulates job already gone from queue
         result = cancel_job(
-            "b", [], pending, {},
-            interrupt=lambda: interrupts.append(True),
+            "b", [], [make_queue_item("b")], {},
+            interrupt=lambda pid: interrupts.append(pid) or True,
+            dequeue=lambda pid: (dequeues.append(pid), False)[1],
+        )
+        assert result == CANCEL_RUNNING
+        assert dequeues == ["b"]  # dequeue attempted first
+        assert interrupts == ["b"]  # then the now-running job was interrupted
+
+    def test_pending_dequeue_miss_not_running_returns_unknown(self):
+        """Dequeue miss where the job is not running anymore (it finished): the
+        atomic interrupt finds nothing to interrupt and returns False, so
+        cancel_job is a no-op reporting UNKNOWN — never reporting a cancel that
+        did not happen, and never interrupting a bystander."""
+        interrupts = []
+        dequeues = []
+        result = cancel_job(
+            "b", [], [make_queue_item("b")], {},
+            interrupt=lambda pid: interrupts.append(pid) or False,
             dequeue=lambda pid: (dequeues.append(pid), False)[1],
         )
         assert result == CANCEL_UNKNOWN
-        assert dequeues == ["b"]  # dequeue was attempted
-        assert interrupts == []  # interrupt was not called
+        assert dequeues == ["b"]
+        assert interrupts == ["b"]  # interrupt attempted, found nothing running
 
 
 # ---------------------------------------------------------------------------
@@ -304,6 +327,30 @@ class TestSingleCancelEndpoint:
         assert (await resp.json()) == {"cancelled": False}
         assert queue.interrupt_count == 0
 
+    @pytest.mark.asyncio
+    async def test_cancel_pending_that_started_running_interrupts(self, aiohttp_client):
+        """Pending->running race end to end: the job is pending at snapshot time
+        but starts executing by the time we dequeue (delete misses). The live
+        re-check sees it running and interrupts it, so the cancel is not dropped
+        and the caller still gets cancelled=True."""
+
+        class RacingQueue(FakePromptQueue):
+            def delete_queue_item(self, function):
+                # The worker picked the job up just before we removed it: it
+                # leaves the pending queue (delete misses) and is now running.
+                self._running = list(self._pending)
+                self._pending = []
+                return False
+
+        queue = RacingQueue(pending=[make_queue_item("b")])
+        client = await aiohttp_client(build_app(queue))
+
+        resp = await client.post("/api/jobs/b/cancel")
+
+        assert resp.status == 200
+        assert (await resp.json()) == {"cancelled": True}
+        assert queue.interrupt_count == 1
+
 
 # ---------------------------------------------------------------------------
 # HTTP contract tests: POST /api/jobs/cancel (batch)
@@ -327,9 +374,10 @@ class TestBatchCancelEndpoint:
         assert queue.get_current_queue()[1] == []  # pending job dequeued
 
     @pytest.mark.asyncio
-    async def test_batch_fail_fast_404_on_unknown_id_no_side_effects(
-        self, aiohttp_client
-    ):
+    async def test_batch_best_effort_skips_unknown_id(self, aiohttp_client):
+        """An unknown id in the batch is a no-op, not a reason to abort: the
+        running and pending jobs are still cancelled (200, cancelled=true). This
+        is the "cancel all as a job finishes" case from review."""
         queue = FakePromptQueue(
             running=[make_queue_item(_UUID_A)],
             pending=[make_queue_item(_UUID_B, number=1)],
@@ -340,12 +388,10 @@ class TestBatchCancelEndpoint:
             "/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_MISSING, _UUID_B]}
         )
 
-        assert resp.status == 404
-        body = await resp.json()
-        assert body["unknown_ids"] == [_UUID_MISSING]
-        # Fail-fast: nothing was cancelled — no partial side effects.
-        assert queue.interrupt_count == 0
-        assert len(queue.get_current_queue()[1]) == 1
+        assert resp.status == 200
+        assert (await resp.json()) == {"cancelled": True}
+        assert queue.interrupt_count == 1  # running job interrupted
+        assert queue.get_current_queue()[1] == []  # pending job dequeued
 
     @pytest.mark.asyncio
     async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client):