fix: resolve review feedback on cancel endpoints
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

- Guard cancel_job() against TOCTOU: when dequeue() returns False, the
  pending job left the queue between snapshot and delete; return
  CANCEL_UNKNOWN so callers never report cancelled=True for a removal
  that did not happen.
- Validate each job_ids element in the batch cancel endpoint before
  any queue access; unhashable or non-UUID values now return 400
  instead of raising TypeError (500).
- Update batch HTTP tests to use canonical UUID ids (required now that
  the endpoint validates id format) and add tests for the new guards.
This commit is contained in:
Matt Miller 2026-06-18 11:04:09 -07:00
parent 7226d5890e
commit dabe0d56a4
3 changed files with 107 additions and 11 deletions

View File

@ -459,11 +459,19 @@ def cancel_job(
Returns the classification that was acted on (one of the CANCEL_* values), Returns the classification that was acted on (one of the CANCEL_* values),
so callers can log or report what happened. so callers can log or report what happened.
For pending jobs the returned value reflects the *actual* dequeue result:
if the job left the queue between the caller's snapshot and the dequeue
call (a narrow TOCTOU window), the dequeue returns False and this function
returns CANCEL_UNKNOWN rather than CANCEL_PENDING, so callers that map the
return to a ``cancelled`` boolean never report a cancel that did not happen.
""" """
classification = classify_job_for_cancel(prompt_id, running, queued, history) classification = classify_job_for_cancel(prompt_id, running, queued, history)
if classification == CANCEL_RUNNING: if classification == CANCEL_RUNNING:
interrupt() interrupt()
elif classification == CANCEL_PENDING: elif classification == CANCEL_PENDING:
dequeue(prompt_id) if not dequeue(prompt_id):
# Job was no longer in the queue by the time we tried to remove it.
return CANCEL_UNKNOWN
# CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops. # CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops.
return classification return classification

View File

@ -976,6 +976,23 @@ class PromptServer():
status=400 status=400
) )
# Validate that every element is a well-formed job id before doing
# anything else. An unhashable element (e.g. a nested dict or list)
# would cause a TypeError when used as a history dict key; a
# non-string or non-UUID value is never a valid id. Reject early
# with 400 rather than letting the classify loop raise 500.
invalid_ids = []
for jid in job_ids:
try:
validate_job_id(jid)
except (ValueError, AttributeError):
invalid_ids.append(jid if isinstance(jid, str) else repr(jid))
if invalid_ids:
return web.json_response(
{"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids},
status=400,
)
# Validate every id exists before cancelling anything. A snapshot of # Validate every id exists before cancelling anything. A snapshot of
# the queue + history is taken once so the membership check is # the queue + history is taken once so the membership check is
# consistent for the whole batch. # consistent for the whole batch.

View File

@ -29,11 +29,19 @@ from comfy_execution.jobs import (
CANCEL_UNKNOWN, CANCEL_UNKNOWN,
cancel_job, cancel_job,
classify_job_for_cancel, classify_job_for_cancel,
validate_job_id,
) )
# Classifications for which a cancel was actually dispatched (vs a no-op). # Classifications for which a cancel was actually dispatched (vs a no-op).
_CANCELLED = (CANCEL_RUNNING, CANCEL_PENDING) _CANCELLED = (CANCEL_RUNNING, CANCEL_PENDING)
# Canonical UUID ids for HTTP-layer tests (the batch endpoint validates UUID format).
_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa"
_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb"
_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc"
_UUID_D = "dddddddd-dddd-4ddd-dddd-dddddddddddd"
_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff"
def make_queue_item(prompt_id, number=0): def make_queue_item(prompt_id, number=0):
"""Build a queue tuple shaped like the real ones: index 1 is the id.""" """Build a queue tuple shaped like the real ones: index 1 is the id."""
@ -113,6 +121,18 @@ def build_app(queue):
if not isinstance(job_ids, list): if not isinstance(job_ids, list):
return web.json_response({"error": "job_ids must be a list"}, status=400) return web.json_response({"error": "job_ids must be a list"}, status=400)
invalid_ids = []
for jid in job_ids:
try:
validate_job_id(jid)
except (ValueError, AttributeError):
invalid_ids.append(jid if isinstance(jid, str) else repr(jid))
if invalid_ids:
return web.json_response(
{"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids},
status=400,
)
running, pending = queue.get_current_queue() running, pending = queue.get_current_queue()
history = queue.get_history() history = queue.get_history()
unknown_ids = [ unknown_ids = [
@ -211,6 +231,23 @@ class TestCancelJobHelper:
assert interrupts == [] assert interrupts == []
assert dequeues == [] assert dequeues == []
def test_pending_dequeue_miss_returns_unknown(self):
    """A pending job whose dequeue reports a miss is reported as unknown.

    TOCTOU guard: when the dequeue callback returns False (the job left the
    queue between the snapshot and the delete), cancel_job must return
    CANCEL_UNKNOWN so that callers never report cancelled=True for a cancel
    that did not actually happen.
    """
    interrupt_calls = []
    dequeue_calls = []

    def record_interrupt():
        interrupt_calls.append(True)

    def missing_dequeue(pid):
        # Always reports a miss, simulating a job that is already gone
        # from the queue by the time removal is attempted.
        dequeue_calls.append(pid)
        return False

    outcome = cancel_job(
        "b",
        [],
        [make_queue_item("b")],
        {},
        interrupt=record_interrupt,
        dequeue=missing_dequeue,
    )

    assert outcome == CANCEL_UNKNOWN
    # The removal was attempted exactly once for the requested id...
    assert dequeue_calls == ["b"]
    # ...and the running-job interrupt path was never taken.
    assert interrupt_calls == []
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# HTTP contract tests: POST /api/jobs/{job_id}/cancel # HTTP contract tests: POST /api/jobs/{job_id}/cancel
@ -277,12 +314,12 @@ class TestBatchCancelEndpoint:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_batch_happy_path(self, aiohttp_client): async def test_batch_happy_path(self, aiohttp_client):
queue = FakePromptQueue( queue = FakePromptQueue(
running=[make_queue_item("a")], running=[make_queue_item(_UUID_A)],
pending=[make_queue_item("b", number=1)], pending=[make_queue_item(_UUID_B, number=1)],
) )
client = await aiohttp_client(build_app(queue)) client = await aiohttp_client(build_app(queue))
resp = await client.post("/api/jobs/cancel", json={"job_ids": ["a", "b"]}) resp = await client.post("/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_B]})
assert resp.status == 200 assert resp.status == 200
assert (await resp.json()) == {"cancelled": True} assert (await resp.json()) == {"cancelled": True}
@ -294,18 +331,18 @@ class TestBatchCancelEndpoint:
self, aiohttp_client self, aiohttp_client
): ):
queue = FakePromptQueue( queue = FakePromptQueue(
running=[make_queue_item("a")], running=[make_queue_item(_UUID_A)],
pending=[make_queue_item("b", number=1)], pending=[make_queue_item(_UUID_B, number=1)],
) )
client = await aiohttp_client(build_app(queue)) client = await aiohttp_client(build_app(queue))
resp = await client.post( resp = await client.post(
"/api/jobs/cancel", json={"job_ids": ["a", "missing", "b"]} "/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_MISSING, _UUID_B]}
) )
assert resp.status == 404 assert resp.status == 404
body = await resp.json() body = await resp.json()
assert body["unknown_ids"] == ["missing"] assert body["unknown_ids"] == [_UUID_MISSING]
# Fail-fast: nothing was cancelled — no partial side effects. # Fail-fast: nothing was cancelled — no partial side effects.
assert queue.interrupt_count == 0 assert queue.interrupt_count == 0
assert len(queue.get_current_queue()[1]) == 1 assert len(queue.get_current_queue()[1]) == 1
@ -313,13 +350,13 @@ class TestBatchCancelEndpoint:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client): async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client):
history = { history = {
"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}, _UUID_C: {"prompt": make_queue_item(_UUID_C), "outputs": {}, "status": {}},
"d": {"prompt": make_queue_item("d"), "outputs": {}, "status": {}}, _UUID_D: {"prompt": make_queue_item(_UUID_D), "outputs": {}, "status": {}},
} }
queue = FakePromptQueue(history=history) queue = FakePromptQueue(history=history)
client = await aiohttp_client(build_app(queue)) client = await aiohttp_client(build_app(queue))
resp = await client.post("/api/jobs/cancel", json={"job_ids": ["c", "d"]}) resp = await client.post("/api/jobs/cancel", json={"job_ids": [_UUID_C, _UUID_D]})
# All known but terminal: 200 with cancelled=false, nothing dispatched. # All known but terminal: 200 with cancelled=false, nothing dispatched.
assert resp.status == 200 assert resp.status == 200
@ -334,3 +371,37 @@ class TestBatchCancelEndpoint:
resp = await client.post("/api/jobs/cancel", json={}) resp = await client.post("/api/jobs/cancel", json={})
assert resp.status == 400 assert resp.status == 400
@pytest.mark.asyncio
async def test_batch_unhashable_element_is_400_not_500(self, aiohttp_client):
    """An unhashable job_ids element (dict or list) is rejected with 400.

    Before the validation guard existed, a payload such as
    {"job_ids": [{}]} reached the classify loop, where
    ``prompt_id in history`` raises TypeError for an unhashable value and
    surfaced as an unhandled 500. The guard has to reject the request
    before any queue or history access takes place.
    """
    fake_queue = FakePromptQueue()
    http = await aiohttp_client(build_app(fake_queue))

    payload = {"job_ids": [{}]}
    response = await http.post("/api/jobs/cancel", json=payload)

    assert response.status == 400
    reply = await response.json()
    assert "invalid_ids" in reply
    # The request was rejected without touching the queue.
    assert fake_queue.interrupt_count == 0
@pytest.mark.asyncio
async def test_batch_non_uuid_string_element_is_400(self, aiohttp_client):
    """A string element that is not a valid UUID is rejected with 400."""
    fake_queue = FakePromptQueue()
    http = await aiohttp_client(build_app(fake_queue))

    payload = {"job_ids": ["not-a-uuid"]}
    response = await http.post("/api/jobs/cancel", json=payload)

    assert response.status == 400
    reply = await response.json()
    # The error body names the offending ids.
    assert "invalid_ids" in reply