diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 7d7362774..718006c9c 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -459,11 +459,19 @@ def cancel_job( Returns the classification that was acted on (one of the CANCEL_* values), so callers can log or report what happened. + + For pending jobs the returned value reflects the *actual* dequeue result: + if the job left the queue between the caller's snapshot and the dequeue + call (a narrow TOCTOU window), the dequeue returns False and this function + returns CANCEL_UNKNOWN rather than CANCEL_PENDING, so callers that map the + return to a ``cancelled`` boolean never report a cancel that did not happen. """ classification = classify_job_for_cancel(prompt_id, running, queued, history) if classification == CANCEL_RUNNING: interrupt() elif classification == CANCEL_PENDING: - dequeue(prompt_id) + if not dequeue(prompt_id): + # Job was no longer in the queue by the time we tried to remove it. + return CANCEL_UNKNOWN # CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops. return classification diff --git a/server.py b/server.py index f5adb25c4..9591c9df7 100644 --- a/server.py +++ b/server.py @@ -976,6 +976,23 @@ class PromptServer(): status=400 ) + # Validate that every element is a well-formed job id before doing + # anything else. An unhashable element (e.g. a nested dict or list) + # would cause a TypeError when used as a history dict key; a + # non-string or non-UUID value is never a valid id. Reject early + # with 400 rather than letting the classify loop raise 500. + invalid_ids = [] + for jid in job_ids: + try: + validate_job_id(jid) + except (ValueError, AttributeError): + invalid_ids.append(jid if isinstance(jid, str) else repr(jid)) + if invalid_ids: + return web.json_response( + {"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids}, + status=400, + ) + # Validate every id exists before cancelling anything. A snapshot of # the queue + history is taken once so the membership check is # consistent for the whole batch. diff --git a/tests-unit/jobs_cancel_test/jobs_cancel_test.py b/tests-unit/jobs_cancel_test/jobs_cancel_test.py index ee88cefdd..372e8f64c 100644 --- a/tests-unit/jobs_cancel_test/jobs_cancel_test.py +++ b/tests-unit/jobs_cancel_test/jobs_cancel_test.py @@ -29,11 +29,19 @@ from comfy_execution.jobs import ( CANCEL_UNKNOWN, cancel_job, classify_job_for_cancel, + validate_job_id, ) # Classifications for which a cancel was actually dispatched (vs a no-op). _CANCELLED = (CANCEL_RUNNING, CANCEL_PENDING) +# Canonical UUID ids for HTTP-layer tests (the batch endpoint validates UUID format). +_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa" +_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb" +_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc" +_UUID_D = "dddddddd-dddd-4ddd-dddd-dddddddddddd" +_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff" + def make_queue_item(prompt_id, number=0): """Build a queue tuple shaped like the real ones: index 1 is the id.""" @@ -113,6 +121,18 @@ def build_app(queue): if not isinstance(job_ids, list): return web.json_response({"error": "job_ids must be a list"}, status=400) + invalid_ids = [] + for jid in job_ids: + try: + validate_job_id(jid) + except (ValueError, AttributeError): + invalid_ids.append(jid if isinstance(jid, str) else repr(jid)) + if invalid_ids: + return web.json_response( + {"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids}, + status=400, + ) + running, pending = queue.get_current_queue() history = queue.get_history() unknown_ids = [ @@ -211,6 +231,23 @@ class TestCancelJobHelper: assert interrupts == [] assert dequeues == [] + def test_pending_dequeue_miss_returns_unknown(self): + """If dequeue returns False (job left queue between snapshot and delete), + cancel_job must return CANCEL_UNKNOWN so callers never report cancelled=True + for a cancel that did not actually happen (TOCTOU guard).""" + pending = [make_queue_item("b")] + interrupts = [] + dequeues = [] + # dequeue always returns False — simulates job already gone from queue + result = cancel_job( + "b", [], pending, {}, + interrupt=lambda: interrupts.append(True), + dequeue=lambda pid: (dequeues.append(pid), False)[1], + ) + assert result == CANCEL_UNKNOWN + assert dequeues == ["b"] # dequeue was attempted + assert interrupts == [] # interrupt was not called + # --------------------------------------------------------------------------- # HTTP contract tests: POST /api/jobs/{job_id}/cancel @@ -277,12 +314,12 @@ class TestBatchCancelEndpoint: @pytest.mark.asyncio async def test_batch_happy_path(self, aiohttp_client): queue = FakePromptQueue( - running=[make_queue_item("a")], - pending=[make_queue_item("b", number=1)], + running=[make_queue_item(_UUID_A)], + pending=[make_queue_item(_UUID_B, number=1)], ) client = await aiohttp_client(build_app(queue)) - resp = await client.post("/api/jobs/cancel", json={"job_ids": ["a", "b"]}) + resp = await client.post("/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_B]}) assert resp.status == 200 assert (await resp.json()) == {"cancelled": True} @@ -294,18 +331,18 @@ class TestBatchCancelEndpoint: self, aiohttp_client ): queue = FakePromptQueue( - running=[make_queue_item("a")], - pending=[make_queue_item("b", number=1)], + running=[make_queue_item(_UUID_A)], + pending=[make_queue_item(_UUID_B, number=1)], ) client = await aiohttp_client(build_app(queue)) resp = await client.post( - "/api/jobs/cancel", json={"job_ids": ["a", "missing", "b"]} + "/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_MISSING, _UUID_B]} ) assert resp.status == 404 body = await resp.json() - assert body["unknown_ids"] == ["missing"] + assert body["unknown_ids"] == [_UUID_MISSING] # Fail-fast: nothing was cancelled — no partial side effects. assert queue.interrupt_count == 0 assert len(queue.get_current_queue()[1]) == 1 @@ -313,13 +350,13 @@ class TestBatchCancelEndpoint: @pytest.mark.asyncio async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client): history = { - "c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}, - "d": {"prompt": make_queue_item("d"), "outputs": {}, "status": {}}, + _UUID_C: {"prompt": make_queue_item(_UUID_C), "outputs": {}, "status": {}}, + _UUID_D: {"prompt": make_queue_item(_UUID_D), "outputs": {}, "status": {}}, } queue = FakePromptQueue(history=history) client = await aiohttp_client(build_app(queue)) - resp = await client.post("/api/jobs/cancel", json={"job_ids": ["c", "d"]}) + resp = await client.post("/api/jobs/cancel", json={"job_ids": [_UUID_C, _UUID_D]}) # All known but terminal: 200 with cancelled=false, nothing dispatched. assert resp.status == 200 @@ -334,3 +371,37 @@ class TestBatchCancelEndpoint: resp = await client.post("/api/jobs/cancel", json={}) assert resp.status == 400 + + @pytest.mark.asyncio + async def test_batch_unhashable_element_is_400_not_500(self, aiohttp_client): + """An unhashable element such as a dict or list must yield 400, not 500. + + Previously, passing e.g. {"job_ids": [{}]} would reach the classify + loop where ``prompt_id in history`` raises TypeError on an unhashable + type, resulting in an unhandled 500. The input-validation guard must + catch this before any queue or history access. + """ + queue = FakePromptQueue() + client = await aiohttp_client(build_app(queue)) + + resp = await client.post("/api/jobs/cancel", json={"job_ids": [{}]}) + + assert resp.status == 400 + body = await resp.json() + assert "invalid_ids" in body + # No queue side effects. + assert queue.interrupt_count == 0 + + @pytest.mark.asyncio + async def test_batch_non_uuid_string_element_is_400(self, aiohttp_client): + """A string that is not a valid UUID must be rejected with 400.""" + queue = FakePromptQueue() + client = await aiohttp_client(build_app(queue)) + + resp = await client.post( + "/api/jobs/cancel", json={"job_ids": ["not-a-uuid"]} + ) + + assert resp.status == 400 + body = await resp.json() + assert "invalid_ids" in body