fix(jobs): harden ids filter per review

- parse_ids_filter: one shared parser/validator for the ids query param, used
  by the /api/jobs handler AND its tests (no more hand-copied test wiring that
  can drift from the shipped handler and stay green while the handler regresses)
- present-but-empty ids (?ids=, ?ids=,,) is now a zero-match filter, not a
  silent 'return the entire job history'
- bounded history lookup when an ids filter is present: a batch poll costs
  O(requested ids), not O(total history)
- dedupe ids so the max-count cap bounds distinct values, not repeats
- .get('id') instead of j['id'] so a job missing its id degrades to no-match
  rather than a 500
This commit is contained in:
Matt Miller 2026-06-30 21:02:33 -07:00
parent 44fb02e510
commit b656217514
3 changed files with 165 additions and 67 deletions

View File

@ -31,8 +31,9 @@ class JobStatus:
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED] ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
# Maximum number of ids accepted by the `ids` filter on the jobs listing. # Maximum number of (distinct) ids accepted by the `ids` filter on the jobs
# Bounds the work a single batch-poll request can ask for. # listing. Caps request size; the bounded id-lookup in get_all_jobs then keeps
# a batch-poll request at O(requested ids), not O(total history).
MAX_JOB_IDS_FILTER = 100 MAX_JOB_IDS_FILTER = 100
@ -55,6 +56,56 @@ def validate_job_id(value) -> str:
return value return value
class JobIdsFilterError(ValueError):
    """Signals a malformed ``ids`` query-param value.

    The exception keeps the ``payload`` dict it was constructed with, so an
    HTTP caller can send that dict back unchanged as the body of a 400
    response. The exception message is the payload's ``"error"`` entry,
    falling back to ``"invalid ids"`` when that key is absent.
    """

    def __init__(self, payload: dict):
        # Derive the human-readable message from the payload so the two can
        # never disagree.
        message = payload.get("error", "invalid ids")
        super().__init__(message)
        self.payload = payload
def parse_ids_filter(ids_param: Optional[str]) -> Optional[list[str]]:
    """Turn the raw ``ids`` query-param string into a job-id filter list.

    The value is split on commas, each piece is whitespace-trimmed, blank
    pieces are discarded, and repeats are collapsed (first occurrence wins,
    order preserved).

    Args:
        ids_param: The raw query-param value, or ``None`` when the param was
            not supplied at all.

    Returns:
        ``None`` when ``ids_param`` is ``None`` (meaning: apply no id filter).
        Otherwise the de-duplicated list of ids. A present-but-blank value
        such as ``""`` or ``",,"`` gives ``[]`` -- an explicit "match nothing"
        filter, distinct from ``None``.

    Raises:
        JobIdsFilterError: when more than ``MAX_JOB_IDS_FILTER`` distinct ids
            remain after de-duplication, or when ``validate_job_id`` rejects
            any id (by raising ``ValueError`` or ``AttributeError``). The
            exception's ``payload`` is the dict to return with a 400.
    """
    if ids_param is None:
        return None

    # Trim every comma-separated piece, then drop the blanks. dict.fromkeys
    # collapses repeats while keeping first-seen order, so a repeated id
    # neither counts toward the cap nor gets validated twice.
    trimmed = (piece.strip() for piece in ids_param.split(','))
    ids_filter = list(dict.fromkeys(piece for piece in trimmed if piece))

    # Cap check comes before per-id validation so an oversized request is
    # rejected without validating every entry.
    if len(ids_filter) > MAX_JOB_IDS_FILTER:
        raise JobIdsFilterError(
            {"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"}
        )

    def _passes_validation(jid: str) -> bool:
        """Return True when validate_job_id accepts ``jid`` without raising."""
        try:
            validate_job_id(jid)
        except (ValueError, AttributeError):
            return False
        return True

    # Collect every offender (not just the first) so the caller can report
    # them all in a single response.
    invalid_ids = [jid for jid in ids_filter if not _passes_validation(jid)]
    if invalid_ids:
        raise JobIdsFilterError(
            {"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids}
        )

    return ids_filter
# Media types that can be previewed in the frontend # Media types that can be previewed in the frontend
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'}) PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
@ -382,7 +433,8 @@ def get_all_jobs(
history: Dict of history items keyed by prompt_id history: Dict of history items keyed by prompt_id
status_filter: List of statuses to include (from JobStatus.ALL) status_filter: List of statuses to include (from JobStatus.ALL)
workflow_id: Filter by workflow ID workflow_id: Filter by workflow ID
ids: Restrict the result to these job ids (None/empty = no filter) ids: Restrict the result to these job ids. None = no filter; a present
list (including empty) restricts to that set, so [] = zero matches
sort_by: Field to sort by ('created_at', 'execution_duration') sort_by: Field to sort by ('created_at', 'execution_duration')
sort_order: 'asc' or 'desc' sort_order: 'asc' or 'desc'
limit: Maximum number of items to return limit: Maximum number of items to return
@ -396,6 +448,10 @@ def get_all_jobs(
if status_filter is None: if status_filter is None:
status_filter = JobStatus.ALL status_filter = JobStatus.ALL
# None => no id filter; a present list (including empty) restricts to that
# set (empty => zero matches).
id_set = set(ids) if ids is not None else None
if JobStatus.IN_PROGRESS in status_filter: if JobStatus.IN_PROGRESS in status_filter:
for item in running: for item in running:
jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS)) jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS))
@ -407,17 +463,29 @@ def get_all_jobs(
history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED} history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED}
requested_history_statuses = history_statuses & set(status_filter) requested_history_statuses = history_statuses & set(status_filter)
if requested_history_statuses: if requested_history_statuses:
for prompt_id, history_item in history.items(): if id_set is not None:
job = normalize_history_item(prompt_id, history_item) # Batch-poll fast path: history is keyed by id, so look up only the
if job.get('status') in requested_history_statuses: # requested ids instead of normalizing the whole (unbounded) history.
jobs.append(job) for prompt_id in id_set:
history_item = history.get(prompt_id)
if history_item is None:
continue
job = normalize_history_item(prompt_id, history_item)
if job.get('status') in requested_history_statuses:
jobs.append(job)
else:
for prompt_id, history_item in history.items():
job = normalize_history_item(prompt_id, history_item)
if job.get('status') in requested_history_statuses:
jobs.append(job)
if workflow_id: if workflow_id:
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id] jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]
if ids: if id_set is not None:
id_set = set(ids) # `.get('id')` (not `j['id']`): prune_dict can drop a None id, and a
jobs = [j for j in jobs if j['id'] in id_set] # job missing its id should degrade to "no match", not raise KeyError.
jobs = [j for j in jobs if j.get('id') in id_set]
jobs = apply_sorting(jobs, sort_by, sort_order) jobs = apply_sorting(jobs, sort_by, sort_order)

View File

@ -16,7 +16,8 @@ from comfy_execution.jobs import (
cancel_job, cancel_job,
CANCEL_PENDING, CANCEL_PENDING,
CANCEL_RUNNING, CANCEL_RUNNING,
MAX_JOB_IDS_FILTER, parse_ids_filter,
JobIdsFilterError,
) )
import uuid import uuid
import urllib import urllib
@ -817,28 +818,14 @@ class PromptServer():
) )
# Optional batch filter: narrow the result to a known set of job ids # Optional batch filter: narrow the result to a known set of job ids
# (e.g. polling a submitted batch in one request). Absent/empty means # (e.g. polling a submitted batch in one request). Parsing/validation
# no filter. Cap the count before validating the full list so an # lives in parse_ids_filter so this handler and its tests share one
# oversized request fails fast, then reject any malformed id with 400. # implementation. Absent => no filter; present-but-empty (`?ids=`,
ids_filter = None # `?ids=,,`) => zero matches, not "everything".
if ids_param: try:
ids_filter = [i.strip() for i in ids_param.split(',') if i.strip()] ids_filter = parse_ids_filter(ids_param)
if len(ids_filter) > MAX_JOB_IDS_FILTER: except JobIdsFilterError as e:
return web.json_response( return web.json_response(e.payload, status=400)
{"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"},
status=400
)
invalid_ids = []
for jid in ids_filter:
try:
validate_job_id(jid)
except (ValueError, AttributeError):
invalid_ids.append(jid)
if invalid_ids:
return web.json_response(
{"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids},
status=400
)
if sort_by not in {'created_at', 'execution_duration'}: if sort_by not in {'created_at', 'execution_duration'}:
return web.json_response( return web.json_response(

View File

@ -10,10 +10,11 @@ Covers both layers:
(a valid set narrows the response, an oversized set or a malformed id is (a valid set narrows the response, an oversized set or a malformed id is
rejected with 400). rejected with 400).
As in ``jobs_cancel_test``, the HTTP layer is exercised against a small The HTTP layer is exercised against a small aiohttp app whose handler calls the
aiohttp app whose handler is a faithful copy of the ``ids``-parsing wiring in SAME ``parse_ids_filter`` that ``server.py`` uses (no hand-copied wiring, so it
``server.py``, driven by a fake queue. This keeps the test free of the heavy cannot drift), driven by a fake queue. This keeps the test free of the heavy
ComfyUI runtime (torch, nodes, ...) while still testing the real contract. ComfyUI runtime (torch, nodes, ...) while still testing the real parsing
contract.
""" """
import pytest import pytest
@ -21,9 +22,10 @@ from aiohttp import web
from comfy_execution.jobs import ( from comfy_execution.jobs import (
JobStatus, JobStatus,
JobIdsFilterError,
MAX_JOB_IDS_FILTER, MAX_JOB_IDS_FILTER,
get_all_jobs, get_all_jobs,
validate_job_id, parse_ids_filter,
) )
# Canonical UUID ids (the endpoint validates UUID format). # Canonical UUID ids (the endpoint validates UUID format).
@ -76,14 +78,18 @@ def test_ids_filter_absent_returns_all():
assert total == 3 assert total == 3
def test_ids_filter_empty_list_returns_all(): def test_ids_filter_empty_list_returns_none():
"""An empty list behaves like no filter (matches how status/workflow_id behave).""" """A present-but-empty ids list is a zero-match filter, not "no filter".
``None`` means "no id filter"; ``[]`` means "restrict to nothing".
"""
running = [make_queue_item(_UUID_A)] running = [make_queue_item(_UUID_A)]
queued = [make_queue_item(_UUID_B)] queued = [make_queue_item(_UUID_B)]
jobs, _ = get_all_jobs(running, queued, {}, ids=[]) jobs, total = get_all_jobs(running, queued, {}, ids=[])
assert {j["id"] for j in jobs} == {_UUID_A, _UUID_B} assert jobs == []
assert total == 0
def test_ids_filter_unknown_id_silently_absent(): def test_ids_filter_unknown_id_silently_absent():
@ -113,6 +119,43 @@ def test_ids_filter_composes_with_status():
assert total == 1 assert total == 1
# ---------------------------------------------------------------------------
# parse_ids_filter -- the shared parsing/validation (server.py + these tests)
# ---------------------------------------------------------------------------
def test_parse_ids_absent_is_none():
    """An absent ``ids`` param (``None``) parses to ``None``, i.e. no filter."""
    parsed = parse_ids_filter(None)
    assert parsed is None
def test_parse_ids_present_but_empty_is_empty_list():
    """A present-but-blank value parses to ``[]`` (zero-match), never ``None``."""
    # "" corresponds to `?ids=`, ",," to `?ids=,,`.
    for raw in ("", ",,"):
        assert parse_ids_filter(raw) == []
def test_parse_ids_dedupes_preserving_order():
    """A repeated id is collapsed and first-seen order is kept."""
    raw = f"{_UUID_A},{_UUID_B},{_UUID_A}"
    expected = [_UUID_A, _UUID_B]
    assert parse_ids_filter(raw) == expected
def test_parse_ids_cap_counts_distinct_not_duplicates():
    """The max-count cap applies to distinct ids, not to raw entries."""
    # Two distinct ids repeated well past the cap still de-dupe to two.
    pair = [_UUID_A, _UUID_B]
    repeated = ",".join(pair * MAX_JOB_IDS_FILTER)
    assert parse_ids_filter(repeated) == pair

    # One more distinct id than the cap allows must be rejected.
    over_cap = [
        f"{n:08d}-0000-4000-8000-000000000000"
        for n in range(MAX_JOB_IDS_FILTER + 1)
    ]
    distinct = ",".join(over_cap)
    with pytest.raises(JobIdsFilterError):
        parse_ids_filter(distinct)
def test_parse_ids_invalid_raises_with_payload():
    """A malformed id raises, and the payload names the offending id."""
    bad_id = "not-a-uuid"
    raw = f"{_UUID_A},{bad_id}"
    with pytest.raises(JobIdsFilterError) as exc:
        parse_ids_filter(raw)
    assert bad_id in exc.value.payload["invalid_ids"]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# HTTP contract for the ids query parameter # HTTP contract for the ids query parameter
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -134,32 +177,17 @@ class FakePromptQueue:
def make_app(prompt_queue): def make_app(prompt_queue):
"""Build an aiohttp app whose handler mirrors server.py's get_jobs ids wiring.""" """Build an aiohttp app whose handler calls the REAL parse_ids_filter.
No hand-copied parsing wiring, so this test cannot stay green while the
shipped parsing in server.py regresses -- both go through parse_ids_filter.
"""
async def get_jobs(request): async def get_jobs(request):
query = request.rel_url.query try:
ids_filter = parse_ids_filter(request.rel_url.query.get('ids'))
ids_param = query.get('ids') except JobIdsFilterError as e:
return web.json_response(e.payload, status=400)
ids_filter = None
if ids_param:
ids_filter = [i.strip() for i in ids_param.split(',') if i.strip()]
if len(ids_filter) > MAX_JOB_IDS_FILTER:
return web.json_response(
{"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"},
status=400
)
invalid_ids = []
for jid in ids_filter:
try:
validate_job_id(jid)
except (ValueError, AttributeError):
invalid_ids.append(jid)
if invalid_ids:
return web.json_response(
{"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids},
status=400
)
running, queued = prompt_queue.get_current_queue_volatile() running, queued = prompt_queue.get_current_queue_volatile()
history = prompt_queue.get_history() history = prompt_queue.get_history()
@ -209,7 +237,11 @@ async def test_http_ids_unknown_id_is_not_an_error(aiohttp_client, queue):
async def test_http_ids_over_limit_returns_400(aiohttp_client, queue): async def test_http_ids_over_limit_returns_400(aiohttp_client, queue):
client = await aiohttp_client(make_app(queue)) client = await aiohttp_client(make_app(queue))
too_many = ",".join([_UUID_A] * (MAX_JOB_IDS_FILTER + 1)) # Distinct ids past the cap. (Repeats of one id are de-duped and would NOT
# trip the cap -- see test_parse_ids_cap_counts_distinct_not_duplicates.)
too_many = ",".join(
f"{i:08d}-0000-4000-8000-000000000000" for i in range(MAX_JOB_IDS_FILTER + 1)
)
resp = await client.get(f"/api/jobs?ids={too_many}") resp = await client.get(f"/api/jobs?ids={too_many}")
assert resp.status == 400 assert resp.status == 400
@ -232,3 +264,14 @@ async def test_http_ids_absent_returns_all(aiohttp_client, queue):
assert resp.status == 200 assert resp.status == 200
body = await resp.json() body = await resp.json()
assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_B, _UUID_C} assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_B, _UUID_C}
@pytest.mark.asyncio
async def test_http_ids_present_but_empty_returns_none(aiohttp_client, queue):
    """`?ids=` (present but empty) is a zero-match filter, not "return all"."""
    app = make_app(queue)
    client = await aiohttp_client(app)

    resp = await client.get("/api/jobs?ids=")

    # The request itself is valid (200); it just matches no jobs.
    assert resp.status == 200
    payload = await resp.json()
    assert payload["jobs"] == []