From b65621751412ef166792af651e0ed5ab75edecaf Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Tue, 30 Jun 2026 21:02:33 -0700 Subject: [PATCH] fix(jobs): harden ids filter per review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - parse_ids_filter: one shared parser/validator for the ids query param, used by the /api/jobs handler AND its tests (no more hand-copied wiring that can drift from — and silently outlive a regression in — the shipped handler) - present-but-empty ids (?ids=, ?ids=,,) is now a zero-match filter, not a silent 'return the entire job history' - bounded history lookup when an ids filter is present: a batch poll costs O(requested ids), not O(total history) - dedupe ids so the max-count cap bounds distinct values, not repeats - .get('id') instead of j['id'] so a job missing its id degrades to no-match rather than a 500 --- comfy_execution/jobs.py | 88 ++++++++++++++-- server.py | 33 ++---- tests-unit/jobs_list_test/jobs_list_test.py | 111 ++++++++++++++------ 3 files changed, 165 insertions(+), 67 deletions(-) diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index f0f9b32ec..508413b79 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -31,8 +31,9 @@ class JobStatus: ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED] -# Maximum number of ids accepted by the `ids` filter on the jobs listing. -# Bounds the work a single batch-poll request can ask for. +# Maximum number of (distinct) ids accepted by the `ids` filter on the jobs +# listing. Caps request size; the bounded id-lookup in get_all_jobs then keeps +# a batch-poll request at O(requested ids), not O(total history). MAX_JOB_IDS_FILTER = 100 @@ -55,6 +56,56 @@ def validate_job_id(value) -> str: return value +class JobIdsFilterError(ValueError): + """Raised when the ``ids`` query-param value is malformed. + + Carries an HTTP-ready ``payload`` dict so the caller can return it verbatim + with a 400 without re-deriving the message. + """ + + def __init__(self, payload: dict): + self.payload = payload + super().__init__(payload.get("error", "invalid ids")) + + +def parse_ids_filter(ids_param: Optional[str]) -> Optional[list[str]]: + """Parse the ``ids`` query-param value into a filter list. + + Single source of truth for ``ids`` parsing/validation, shared by the HTTP + handler and its tests so the two cannot drift. + + Returns: + - ``None`` when the param is absent (``ids_param is None``) -> no filter. + - A de-duplicated list when present. An empty/blank value (``?ids=``, + ``?ids=,,``) yields ``[]``, which ``get_all_jobs`` treats as a + zero-match filter -- NOT "return everything". + + Raises: + JobIdsFilterError: more than ``MAX_JOB_IDS_FILTER`` distinct ids, or any + id not in canonical UUID form. ``.payload`` is a 400-ready dict. + """ + if ids_param is None: + return None + # De-dupe up front: a repeated id must not count toward the cap or be + # looked up twice. dict.fromkeys keeps first-seen order. + ids_filter = list(dict.fromkeys(i.strip() for i in ids_param.split(',') if i.strip())) + if len(ids_filter) > MAX_JOB_IDS_FILTER: + raise JobIdsFilterError( + {"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"} + ) + invalid_ids = [] + for jid in ids_filter: + try: + validate_job_id(jid) + except (ValueError, AttributeError): + invalid_ids.append(jid) + if invalid_ids: + raise JobIdsFilterError( + {"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids} + ) + return ids_filter + + # Media types that can be previewed in the frontend PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'}) @@ -382,7 +433,8 @@ def get_all_jobs( history: Dict of history items keyed by prompt_id status_filter: List of statuses to include (from JobStatus.ALL) workflow_id: Filter by workflow ID - ids: Restrict the result to these job ids (None/empty = no filter) + ids: Restrict the result to these job ids. None = no filter; a present + list (including empty) restricts to that set, so [] = zero matches sort_by: Field to sort by ('created_at', 'execution_duration') sort_order: 'asc' or 'desc' limit: Maximum number of items to return @@ -396,6 +448,10 @@ def get_all_jobs( if status_filter is None: status_filter = JobStatus.ALL + # None => no id filter; a present list (including empty) restricts to that + # set (empty => zero matches). + id_set = set(ids) if ids is not None else None + if JobStatus.IN_PROGRESS in status_filter: for item in running: jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS)) @@ -407,17 +463,29 @@ def get_all_jobs( history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED} requested_history_statuses = history_statuses & set(status_filter) if requested_history_statuses: - for prompt_id, history_item in history.items(): - job = normalize_history_item(prompt_id, history_item) - if job.get('status') in requested_history_statuses: - jobs.append(job) + if id_set is not None: + # Batch-poll fast path: history is keyed by id, so look up only the + # requested ids instead of normalizing the whole (unbounded) history. + for prompt_id in id_set: + history_item = history.get(prompt_id) + if history_item is None: + continue + job = normalize_history_item(prompt_id, history_item) + if job.get('status') in requested_history_statuses: + jobs.append(job) + else: + for prompt_id, history_item in history.items(): + job = normalize_history_item(prompt_id, history_item) + if job.get('status') in requested_history_statuses: + jobs.append(job) if workflow_id: jobs = [j for j in jobs if j.get('workflow_id') == workflow_id] - if ids: - id_set = set(ids) - jobs = [j for j in jobs if j['id'] in id_set] + if id_set is not None: + # `.get('id')` (not `j['id']`): prune_dict can drop a None id, and a + # job missing its id should degrade to "no match", not raise KeyError. + jobs = [j for j in jobs if j.get('id') in id_set] jobs = apply_sorting(jobs, sort_by, sort_order) diff --git a/server.py b/server.py index cb7916d6b..88fb288c8 100644 --- a/server.py +++ b/server.py @@ -16,7 +16,8 @@ from comfy_execution.jobs import ( cancel_job, CANCEL_PENDING, CANCEL_RUNNING, - MAX_JOB_IDS_FILTER, + parse_ids_filter, + JobIdsFilterError, ) import uuid import urllib @@ -817,28 +818,14 @@ class PromptServer(): ) # Optional batch filter: narrow the result to a known set of job ids - # (e.g. polling a submitted batch in one request). Absent/empty means - # no filter. Cap the count before validating the full list so an - # oversized request fails fast, then reject any malformed id with 400. - ids_filter = None - if ids_param: - ids_filter = [i.strip() for i in ids_param.split(',') if i.strip()] - if len(ids_filter) > MAX_JOB_IDS_FILTER: - return web.json_response( - {"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"}, - status=400 - ) - invalid_ids = [] - for jid in ids_filter: - try: - validate_job_id(jid) - except (ValueError, AttributeError): - invalid_ids.append(jid) - if invalid_ids: - return web.json_response( - {"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids}, - status=400 - ) + # (e.g. polling a submitted batch in one request). Parsing/validation + # lives in parse_ids_filter so this handler and its tests share one + # implementation. Absent => no filter; present-but-empty (`?ids=`, + # `?ids=,,`) => zero matches, not "everything". + try: + ids_filter = parse_ids_filter(ids_param) + except JobIdsFilterError as e: + return web.json_response(e.payload, status=400) if sort_by not in {'created_at', 'execution_duration'}: return web.json_response( diff --git a/tests-unit/jobs_list_test/jobs_list_test.py b/tests-unit/jobs_list_test/jobs_list_test.py index bfeb4d727..0107284e0 100644 --- a/tests-unit/jobs_list_test/jobs_list_test.py +++ b/tests-unit/jobs_list_test/jobs_list_test.py @@ -10,10 +10,11 @@ Covers both layers: (a valid set narrows the response, an oversized set or a malformed id is rejected with 400). -As in ``jobs_cancel_test``, the HTTP layer is exercised against a small -aiohttp app whose handler is a faithful copy of the ``ids``-parsing wiring in -``server.py``, driven by a fake queue. This keeps the test free of the heavy -ComfyUI runtime (torch, nodes, ...) while still testing the real contract. +The HTTP layer is exercised against a small aiohttp app whose handler calls the +SAME ``parse_ids_filter`` that ``server.py`` uses (no hand-copied wiring, so it +cannot drift), driven by a fake queue. This keeps the test free of the heavy +ComfyUI runtime (torch, nodes, ...) while still testing the real parsing +contract. """ import pytest @@ -21,9 +22,10 @@ from aiohttp import web from comfy_execution.jobs import ( JobStatus, + JobIdsFilterError, MAX_JOB_IDS_FILTER, get_all_jobs, - validate_job_id, + parse_ids_filter, ) # Canonical UUID ids (the endpoint validates UUID format). @@ -76,14 +78,18 @@ def test_ids_filter_absent_returns_all(): assert total == 3 -def test_ids_filter_empty_list_returns_all(): - """An empty list behaves like no filter (matches how status/workflow_id behave).""" +def test_ids_filter_empty_list_returns_none(): + """A present-but-empty ids list is a zero-match filter, not "no filter". + + ``None`` means "no id filter"; ``[]`` means "restrict to nothing". + """ running = [make_queue_item(_UUID_A)] queued = [make_queue_item(_UUID_B)] - jobs, _ = get_all_jobs(running, queued, {}, ids=[]) + jobs, total = get_all_jobs(running, queued, {}, ids=[]) - assert {j["id"] for j in jobs} == {_UUID_A, _UUID_B} + assert jobs == [] + assert total == 0 def test_ids_filter_unknown_id_silently_absent(): @@ -113,6 +119,43 @@ def test_ids_filter_composes_with_status(): assert total == 1 +# --------------------------------------------------------------------------- +# parse_ids_filter -- the shared parsing/validation (server.py + these tests) +# --------------------------------------------------------------------------- + + +def test_parse_ids_absent_is_none(): + assert parse_ids_filter(None) is None + + +def test_parse_ids_present_but_empty_is_empty_list(): + # `?ids=` and `?ids=,,` parse to [] -> zero-match filter, not None. + assert parse_ids_filter("") == [] + assert parse_ids_filter(",,") == [] + + +def test_parse_ids_dedupes_preserving_order(): + assert parse_ids_filter(f"{_UUID_A},{_UUID_B},{_UUID_A}") == [_UUID_A, _UUID_B] + + +def test_parse_ids_cap_counts_distinct_not_duplicates(): + # A small distinct set repeated far past the cap is still under it. + repeated = ",".join([_UUID_A, _UUID_B] * MAX_JOB_IDS_FILTER) + assert parse_ids_filter(repeated) == [_UUID_A, _UUID_B] + # But more than MAX distinct ids is rejected. + distinct = ",".join( + f"{i:08d}-0000-4000-8000-000000000000" for i in range(MAX_JOB_IDS_FILTER + 1) + ) + with pytest.raises(JobIdsFilterError): + parse_ids_filter(distinct) + + +def test_parse_ids_invalid_raises_with_payload(): + with pytest.raises(JobIdsFilterError) as exc: + parse_ids_filter(f"{_UUID_A},not-a-uuid") + assert "not-a-uuid" in exc.value.payload["invalid_ids"] + + # --------------------------------------------------------------------------- # HTTP contract for the ids query parameter # --------------------------------------------------------------------------- @@ -134,32 +177,17 @@ class FakePromptQueue: def make_app(prompt_queue): - """Build an aiohttp app whose handler mirrors server.py's get_jobs ids wiring.""" + """Build an aiohttp app whose handler calls the REAL parse_ids_filter. + + No hand-copied parsing wiring, so this test cannot stay green while the + shipped parsing in server.py regresses -- both go through parse_ids_filter. + """ async def get_jobs(request): - query = request.rel_url.query - - ids_param = query.get('ids') - - ids_filter = None - if ids_param: - ids_filter = [i.strip() for i in ids_param.split(',') if i.strip()] - if len(ids_filter) > MAX_JOB_IDS_FILTER: - return web.json_response( - {"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"}, - status=400 - ) - invalid_ids = [] - for jid in ids_filter: - try: - validate_job_id(jid) - except (ValueError, AttributeError): - invalid_ids.append(jid) - if invalid_ids: - return web.json_response( - {"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids}, - status=400 - ) + try: + ids_filter = parse_ids_filter(request.rel_url.query.get('ids')) + except JobIdsFilterError as e: + return web.json_response(e.payload, status=400) running, queued = prompt_queue.get_current_queue_volatile() history = prompt_queue.get_history() @@ -209,7 +237,11 @@ async def test_http_ids_unknown_id_is_not_an_error(aiohttp_client, queue): async def test_http_ids_over_limit_returns_400(aiohttp_client, queue): client = await aiohttp_client(make_app(queue)) - too_many = ",".join([_UUID_A] * (MAX_JOB_IDS_FILTER + 1)) + # Distinct ids past the cap. (Repeats of one id are de-duped and would NOT + # trip the cap -- see test_parse_ids_cap_counts_distinct_not_duplicates.) + too_many = ",".join( + f"{i:08d}-0000-4000-8000-000000000000" for i in range(MAX_JOB_IDS_FILTER + 1) + ) resp = await client.get(f"/api/jobs?ids={too_many}") assert resp.status == 400 @@ -232,3 +264,14 @@ async def test_http_ids_absent_returns_all(aiohttp_client, queue): assert resp.status == 200 body = await resp.json() assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_B, _UUID_C} + + +@pytest.mark.asyncio +async def test_http_ids_present_but_empty_returns_none(aiohttp_client, queue): + """`?ids=` (present but empty) is a zero-match filter, not "return all".""" + client = await aiohttp_client(make_app(queue)) + + resp = await client.get("/api/jobs?ids=") + assert resp.status == 200 + body = await resp.json() + assert body["jobs"] == []