mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 16:29:25 +08:00
fix: resolve review feedback on cancel endpoints
- Guard cancel_job() against TOCTOU: when dequeue() returns False the pending job left the queue between snapshot and delete; return CANCEL_UNKNOWN so callers never report cancelled=True for a remove that did not happen. - Validate each job_ids element in the batch cancel endpoint before any queue access; unhashable or non-UUID values now return 400 instead of raising TypeError (500). - Update batch HTTP tests to use canonical UUID ids (required now that the endpoint validates id format) and add tests for the new guards.
This commit is contained in:
parent
7226d5890e
commit
dabe0d56a4
@ -459,11 +459,19 @@ def cancel_job(
|
|||||||
|
|
||||||
Returns the classification that was acted on (one of the CANCEL_* values),
|
Returns the classification that was acted on (one of the CANCEL_* values),
|
||||||
so callers can log or report what happened.
|
so callers can log or report what happened.
|
||||||
|
|
||||||
|
For pending jobs the returned value reflects the *actual* dequeue result:
|
||||||
|
if the job left the queue between the caller's snapshot and the dequeue
|
||||||
|
call (a narrow TOCTOU window), the dequeue returns False and this function
|
||||||
|
returns CANCEL_UNKNOWN rather than CANCEL_PENDING, so callers that map the
|
||||||
|
return to a ``cancelled`` boolean never report a cancel that did not happen.
|
||||||
"""
|
"""
|
||||||
classification = classify_job_for_cancel(prompt_id, running, queued, history)
|
classification = classify_job_for_cancel(prompt_id, running, queued, history)
|
||||||
if classification == CANCEL_RUNNING:
|
if classification == CANCEL_RUNNING:
|
||||||
interrupt()
|
interrupt()
|
||||||
elif classification == CANCEL_PENDING:
|
elif classification == CANCEL_PENDING:
|
||||||
dequeue(prompt_id)
|
if not dequeue(prompt_id):
|
||||||
|
# Job was no longer in the queue by the time we tried to remove it.
|
||||||
|
return CANCEL_UNKNOWN
|
||||||
# CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops.
|
# CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops.
|
||||||
return classification
|
return classification
|
||||||
|
|||||||
17
server.py
17
server.py
@ -976,6 +976,23 @@ class PromptServer():
|
|||||||
status=400
|
status=400
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate that every element is a well-formed job id before doing
|
||||||
|
# anything else. An unhashable element (e.g. a nested dict or list)
|
||||||
|
# would cause a TypeError when used as a history dict key; a
|
||||||
|
# non-string or non-UUID value is never a valid id. Reject early
|
||||||
|
# with 400 rather than letting the classify loop raise 500.
|
||||||
|
invalid_ids = []
|
||||||
|
for jid in job_ids:
|
||||||
|
try:
|
||||||
|
validate_job_id(jid)
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
invalid_ids.append(jid if isinstance(jid, str) else repr(jid))
|
||||||
|
if invalid_ids:
|
||||||
|
return web.json_response(
|
||||||
|
{"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
# Validate every id exists before cancelling anything. A snapshot of
|
# Validate every id exists before cancelling anything. A snapshot of
|
||||||
# the queue + history is taken once so the membership check is
|
# the queue + history is taken once so the membership check is
|
||||||
# consistent for the whole batch.
|
# consistent for the whole batch.
|
||||||
|
|||||||
@ -29,11 +29,19 @@ from comfy_execution.jobs import (
|
|||||||
CANCEL_UNKNOWN,
|
CANCEL_UNKNOWN,
|
||||||
cancel_job,
|
cancel_job,
|
||||||
classify_job_for_cancel,
|
classify_job_for_cancel,
|
||||||
|
validate_job_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Classifications for which a cancel was actually dispatched (vs a no-op).
|
# Classifications for which a cancel was actually dispatched (vs a no-op).
|
||||||
_CANCELLED = (CANCEL_RUNNING, CANCEL_PENDING)
|
_CANCELLED = (CANCEL_RUNNING, CANCEL_PENDING)
|
||||||
|
|
||||||
|
# Canonical UUID ids for HTTP-layer tests (the batch endpoint validates UUID format).
|
||||||
|
_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa"
|
||||||
|
_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb"
|
||||||
|
_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc"
|
||||||
|
_UUID_D = "dddddddd-dddd-4ddd-dddd-dddddddddddd"
|
||||||
|
_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff"
|
||||||
|
|
||||||
|
|
||||||
def make_queue_item(prompt_id, number=0):
|
def make_queue_item(prompt_id, number=0):
|
||||||
"""Build a queue tuple shaped like the real ones: index 1 is the id."""
|
"""Build a queue tuple shaped like the real ones: index 1 is the id."""
|
||||||
@ -113,6 +121,18 @@ def build_app(queue):
|
|||||||
if not isinstance(job_ids, list):
|
if not isinstance(job_ids, list):
|
||||||
return web.json_response({"error": "job_ids must be a list"}, status=400)
|
return web.json_response({"error": "job_ids must be a list"}, status=400)
|
||||||
|
|
||||||
|
invalid_ids = []
|
||||||
|
for jid in job_ids:
|
||||||
|
try:
|
||||||
|
validate_job_id(jid)
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
invalid_ids.append(jid if isinstance(jid, str) else repr(jid))
|
||||||
|
if invalid_ids:
|
||||||
|
return web.json_response(
|
||||||
|
{"error": "job_ids contains invalid id(s)", "invalid_ids": invalid_ids},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
running, pending = queue.get_current_queue()
|
running, pending = queue.get_current_queue()
|
||||||
history = queue.get_history()
|
history = queue.get_history()
|
||||||
unknown_ids = [
|
unknown_ids = [
|
||||||
@ -211,6 +231,23 @@ class TestCancelJobHelper:
|
|||||||
assert interrupts == []
|
assert interrupts == []
|
||||||
assert dequeues == []
|
assert dequeues == []
|
||||||
|
|
||||||
|
def test_pending_dequeue_miss_returns_unknown(self):
|
||||||
|
"""If dequeue returns False (job left queue between snapshot and delete),
|
||||||
|
cancel_job must return CANCEL_UNKNOWN so callers never report cancelled=True
|
||||||
|
for a cancel that did not actually happen (TOCTOU guard)."""
|
||||||
|
pending = [make_queue_item("b")]
|
||||||
|
interrupts = []
|
||||||
|
dequeues = []
|
||||||
|
# dequeue always returns False — simulates job already gone from queue
|
||||||
|
result = cancel_job(
|
||||||
|
"b", [], pending, {},
|
||||||
|
interrupt=lambda: interrupts.append(True),
|
||||||
|
dequeue=lambda pid: (dequeues.append(pid), False)[1],
|
||||||
|
)
|
||||||
|
assert result == CANCEL_UNKNOWN
|
||||||
|
assert dequeues == ["b"] # dequeue was attempted
|
||||||
|
assert interrupts == [] # interrupt was not called
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# HTTP contract tests: POST /api/jobs/{job_id}/cancel
|
# HTTP contract tests: POST /api/jobs/{job_id}/cancel
|
||||||
@ -277,12 +314,12 @@ class TestBatchCancelEndpoint:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_batch_happy_path(self, aiohttp_client):
|
async def test_batch_happy_path(self, aiohttp_client):
|
||||||
queue = FakePromptQueue(
|
queue = FakePromptQueue(
|
||||||
running=[make_queue_item("a")],
|
running=[make_queue_item(_UUID_A)],
|
||||||
pending=[make_queue_item("b", number=1)],
|
pending=[make_queue_item(_UUID_B, number=1)],
|
||||||
)
|
)
|
||||||
client = await aiohttp_client(build_app(queue))
|
client = await aiohttp_client(build_app(queue))
|
||||||
|
|
||||||
resp = await client.post("/api/jobs/cancel", json={"job_ids": ["a", "b"]})
|
resp = await client.post("/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_B]})
|
||||||
|
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
assert (await resp.json()) == {"cancelled": True}
|
assert (await resp.json()) == {"cancelled": True}
|
||||||
@ -294,18 +331,18 @@ class TestBatchCancelEndpoint:
|
|||||||
self, aiohttp_client
|
self, aiohttp_client
|
||||||
):
|
):
|
||||||
queue = FakePromptQueue(
|
queue = FakePromptQueue(
|
||||||
running=[make_queue_item("a")],
|
running=[make_queue_item(_UUID_A)],
|
||||||
pending=[make_queue_item("b", number=1)],
|
pending=[make_queue_item(_UUID_B, number=1)],
|
||||||
)
|
)
|
||||||
client = await aiohttp_client(build_app(queue))
|
client = await aiohttp_client(build_app(queue))
|
||||||
|
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
"/api/jobs/cancel", json={"job_ids": ["a", "missing", "b"]}
|
"/api/jobs/cancel", json={"job_ids": [_UUID_A, _UUID_MISSING, _UUID_B]}
|
||||||
)
|
)
|
||||||
|
|
||||||
assert resp.status == 404
|
assert resp.status == 404
|
||||||
body = await resp.json()
|
body = await resp.json()
|
||||||
assert body["unknown_ids"] == ["missing"]
|
assert body["unknown_ids"] == [_UUID_MISSING]
|
||||||
# Fail-fast: nothing was cancelled — no partial side effects.
|
# Fail-fast: nothing was cancelled — no partial side effects.
|
||||||
assert queue.interrupt_count == 0
|
assert queue.interrupt_count == 0
|
||||||
assert len(queue.get_current_queue()[1]) == 1
|
assert len(queue.get_current_queue()[1]) == 1
|
||||||
@ -313,13 +350,13 @@ class TestBatchCancelEndpoint:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client):
|
async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client):
|
||||||
history = {
|
history = {
|
||||||
"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}},
|
_UUID_C: {"prompt": make_queue_item(_UUID_C), "outputs": {}, "status": {}},
|
||||||
"d": {"prompt": make_queue_item("d"), "outputs": {}, "status": {}},
|
_UUID_D: {"prompt": make_queue_item(_UUID_D), "outputs": {}, "status": {}},
|
||||||
}
|
}
|
||||||
queue = FakePromptQueue(history=history)
|
queue = FakePromptQueue(history=history)
|
||||||
client = await aiohttp_client(build_app(queue))
|
client = await aiohttp_client(build_app(queue))
|
||||||
|
|
||||||
resp = await client.post("/api/jobs/cancel", json={"job_ids": ["c", "d"]})
|
resp = await client.post("/api/jobs/cancel", json={"job_ids": [_UUID_C, _UUID_D]})
|
||||||
|
|
||||||
# All known but terminal: 200 with cancelled=false, nothing dispatched.
|
# All known but terminal: 200 with cancelled=false, nothing dispatched.
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
@ -334,3 +371,37 @@ class TestBatchCancelEndpoint:
|
|||||||
resp = await client.post("/api/jobs/cancel", json={})
|
resp = await client.post("/api/jobs/cancel", json={})
|
||||||
|
|
||||||
assert resp.status == 400
|
assert resp.status == 400
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_unhashable_element_is_400_not_500(self, aiohttp_client):
|
||||||
|
"""An unhashable element such as a dict or list must yield 400, not 500.
|
||||||
|
|
||||||
|
Previously, passing e.g. {"job_ids": [{}]} would reach the classify
|
||||||
|
loop where ``prompt_id in history`` raises TypeError on an unhashable
|
||||||
|
type, resulting in an unhandled 500. The input-validation guard must
|
||||||
|
catch this before any queue or history access.
|
||||||
|
"""
|
||||||
|
queue = FakePromptQueue()
|
||||||
|
client = await aiohttp_client(build_app(queue))
|
||||||
|
|
||||||
|
resp = await client.post("/api/jobs/cancel", json={"job_ids": [{}]})
|
||||||
|
|
||||||
|
assert resp.status == 400
|
||||||
|
body = await resp.json()
|
||||||
|
assert "invalid_ids" in body
|
||||||
|
# No queue side effects.
|
||||||
|
assert queue.interrupt_count == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_non_uuid_string_element_is_400(self, aiohttp_client):
|
||||||
|
"""A string that is not a valid UUID must be rejected with 400."""
|
||||||
|
queue = FakePromptQueue()
|
||||||
|
client = await aiohttp_client(build_app(queue))
|
||||||
|
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/jobs/cancel", json={"job_ids": ["not-a-uuid"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == 400
|
||||||
|
body = await resp.json()
|
||||||
|
assert "invalid_ids" in body
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user