fix(jobs): harden ids filter per review

- parse_ids_filter: one shared parser/validator for the ids query param, used
  by the /api/jobs handler AND its tests (no more hand-copied wiring that can
  drift from — and silently outlive a regression in — the shipped handler)
- present-but-empty ids (?ids=, ?ids=,,) is now a zero-match filter, not a
  silent 'return the entire job history'
- bounded history lookup when an ids filter is present: a batch poll costs
  O(requested ids), not O(total history)
- dedupe ids so the max-count cap bounds distinct values, not repeats
- .get('id') instead of j['id'] so a job missing its id degrades to no-match
  rather than a 500
This commit is contained in:
Matt Miller 2026-06-30 21:02:33 -07:00
parent 44fb02e510
commit b656217514
3 changed files with 165 additions and 67 deletions

View File

@ -31,8 +31,9 @@ class JobStatus:
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
# Maximum number of ids accepted by the `ids` filter on the jobs listing.
# Bounds the work a single batch-poll request can ask for.
# Maximum number of (distinct) ids accepted by the `ids` filter on the jobs
# listing. Caps request size; the bounded id-lookup in get_all_jobs then keeps
# a batch-poll request at O(requested ids), not O(total history).
MAX_JOB_IDS_FILTER = 100
@ -55,6 +56,56 @@ def validate_job_id(value) -> str:
return value
class JobIdsFilterError(ValueError):
"""Raised when the ``ids`` query-param value is malformed.
Carries an HTTP-ready ``payload`` dict so the caller can return it verbatim
with a 400 without re-deriving the message.
"""
def __init__(self, payload: dict):
self.payload = payload
super().__init__(payload.get("error", "invalid ids"))
def parse_ids_filter(ids_param: Optional[str]) -> Optional[list[str]]:
"""Parse the ``ids`` query-param value into a filter list.
Single source of truth for ``ids`` parsing/validation, shared by the HTTP
handler and its tests so the two cannot drift.
Returns:
- ``None`` when the param is absent (``ids_param is None``) -> no filter.
- A de-duplicated list when present. An empty/blank value (``?ids=``,
``?ids=,,``) yields ``[]``, which ``get_all_jobs`` treats as a
zero-match filter -- NOT "return everything".
Raises:
JobIdsFilterError: more than ``MAX_JOB_IDS_FILTER`` distinct ids, or any
id not in canonical UUID form. ``.payload`` is a 400-ready dict.
"""
if ids_param is None:
return None
# De-dupe up front: a repeated id must not count toward the cap or be
# looked up twice. dict.fromkeys keeps first-seen order.
ids_filter = list(dict.fromkeys(i.strip() for i in ids_param.split(',') if i.strip()))
if len(ids_filter) > MAX_JOB_IDS_FILTER:
raise JobIdsFilterError(
{"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"}
)
invalid_ids = []
for jid in ids_filter:
try:
validate_job_id(jid)
except (ValueError, AttributeError):
invalid_ids.append(jid)
if invalid_ids:
raise JobIdsFilterError(
{"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids}
)
return ids_filter
# Media types that can be previewed in the frontend
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
@ -382,7 +433,8 @@ def get_all_jobs(
history: Dict of history items keyed by prompt_id
status_filter: List of statuses to include (from JobStatus.ALL)
workflow_id: Filter by workflow ID
ids: Restrict the result to these job ids (None/empty = no filter)
ids: Restrict the result to these job ids. None = no filter; a present
list (including empty) restricts to that set, so [] = zero matches
sort_by: Field to sort by ('created_at', 'execution_duration')
sort_order: 'asc' or 'desc'
limit: Maximum number of items to return
@ -396,6 +448,10 @@ def get_all_jobs(
if status_filter is None:
status_filter = JobStatus.ALL
# None => no id filter; a present list (including empty) restricts to that
# set (empty => zero matches).
id_set = set(ids) if ids is not None else None
if JobStatus.IN_PROGRESS in status_filter:
for item in running:
jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS))
@ -407,17 +463,29 @@ def get_all_jobs(
history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED}
requested_history_statuses = history_statuses & set(status_filter)
if requested_history_statuses:
for prompt_id, history_item in history.items():
job = normalize_history_item(prompt_id, history_item)
if job.get('status') in requested_history_statuses:
jobs.append(job)
if id_set is not None:
# Batch-poll fast path: history is keyed by id, so look up only the
# requested ids instead of normalizing the whole (unbounded) history.
for prompt_id in id_set:
history_item = history.get(prompt_id)
if history_item is None:
continue
job = normalize_history_item(prompt_id, history_item)
if job.get('status') in requested_history_statuses:
jobs.append(job)
else:
for prompt_id, history_item in history.items():
job = normalize_history_item(prompt_id, history_item)
if job.get('status') in requested_history_statuses:
jobs.append(job)
if workflow_id:
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]
if ids:
id_set = set(ids)
jobs = [j for j in jobs if j['id'] in id_set]
if id_set is not None:
# `.get('id')` (not `j['id']`): prune_dict can drop a None id, and a
# job missing its id should degrade to "no match", not raise KeyError.
jobs = [j for j in jobs if j.get('id') in id_set]
jobs = apply_sorting(jobs, sort_by, sort_order)

View File

@ -16,7 +16,8 @@ from comfy_execution.jobs import (
cancel_job,
CANCEL_PENDING,
CANCEL_RUNNING,
MAX_JOB_IDS_FILTER,
parse_ids_filter,
JobIdsFilterError,
)
import uuid
import urllib
@ -817,28 +818,14 @@ class PromptServer():
)
# Optional batch filter: narrow the result to a known set of job ids
# (e.g. polling a submitted batch in one request). Absent/empty means
# no filter. Cap the count before validating the full list so an
# oversized request fails fast, then reject any malformed id with 400.
ids_filter = None
if ids_param:
ids_filter = [i.strip() for i in ids_param.split(',') if i.strip()]
if len(ids_filter) > MAX_JOB_IDS_FILTER:
return web.json_response(
{"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"},
status=400
)
invalid_ids = []
for jid in ids_filter:
try:
validate_job_id(jid)
except (ValueError, AttributeError):
invalid_ids.append(jid)
if invalid_ids:
return web.json_response(
{"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids},
status=400
)
# (e.g. polling a submitted batch in one request). Parsing/validation
# lives in parse_ids_filter so this handler and its tests share one
# implementation. Absent => no filter; present-but-empty (`?ids=`,
# `?ids=,,`) => zero matches, not "everything".
try:
ids_filter = parse_ids_filter(ids_param)
except JobIdsFilterError as e:
return web.json_response(e.payload, status=400)
if sort_by not in {'created_at', 'execution_duration'}:
return web.json_response(

View File

@ -10,10 +10,11 @@ Covers both layers:
(a valid set narrows the response, an oversized set or a malformed id is
rejected with 400).
As in ``jobs_cancel_test``, the HTTP layer is exercised against a small
aiohttp app whose handler is a faithful copy of the ``ids``-parsing wiring in
``server.py``, driven by a fake queue. This keeps the test free of the heavy
ComfyUI runtime (torch, nodes, ...) while still testing the real contract.
The HTTP layer is exercised against a small aiohttp app whose handler calls the
SAME ``parse_ids_filter`` that ``server.py`` uses (no hand-copied wiring, so it
cannot drift), driven by a fake queue. This keeps the test free of the heavy
ComfyUI runtime (torch, nodes, ...) while still testing the real parsing
contract.
"""
import pytest
@ -21,9 +22,10 @@ from aiohttp import web
from comfy_execution.jobs import (
JobStatus,
JobIdsFilterError,
MAX_JOB_IDS_FILTER,
get_all_jobs,
validate_job_id,
parse_ids_filter,
)
# Canonical UUID ids (the endpoint validates UUID format).
@ -76,14 +78,18 @@ def test_ids_filter_absent_returns_all():
assert total == 3
def test_ids_filter_empty_list_returns_all():
"""An empty list behaves like no filter (matches how status/workflow_id behave)."""
def test_ids_filter_empty_list_returns_none():
"""A present-but-empty ids list is a zero-match filter, not "no filter".
``None`` means "no id filter"; ``[]`` means "restrict to nothing".
"""
running = [make_queue_item(_UUID_A)]
queued = [make_queue_item(_UUID_B)]
jobs, _ = get_all_jobs(running, queued, {}, ids=[])
jobs, total = get_all_jobs(running, queued, {}, ids=[])
assert {j["id"] for j in jobs} == {_UUID_A, _UUID_B}
assert jobs == []
assert total == 0
def test_ids_filter_unknown_id_silently_absent():
@ -113,6 +119,43 @@ def test_ids_filter_composes_with_status():
assert total == 1
# ---------------------------------------------------------------------------
# parse_ids_filter -- the shared parsing/validation (server.py + these tests)
# ---------------------------------------------------------------------------
def test_parse_ids_absent_is_none():
assert parse_ids_filter(None) is None
def test_parse_ids_present_but_empty_is_empty_list():
# `?ids=` and `?ids=,,` parse to [] -> zero-match filter, not None.
assert parse_ids_filter("") == []
assert parse_ids_filter(",,") == []
def test_parse_ids_dedupes_preserving_order():
assert parse_ids_filter(f"{_UUID_A},{_UUID_B},{_UUID_A}") == [_UUID_A, _UUID_B]
def test_parse_ids_cap_counts_distinct_not_duplicates():
# A small distinct set repeated far past the cap is still under it.
repeated = ",".join([_UUID_A, _UUID_B] * MAX_JOB_IDS_FILTER)
assert parse_ids_filter(repeated) == [_UUID_A, _UUID_B]
# But more than MAX distinct ids is rejected.
distinct = ",".join(
f"{i:08d}-0000-4000-8000-000000000000" for i in range(MAX_JOB_IDS_FILTER + 1)
)
with pytest.raises(JobIdsFilterError):
parse_ids_filter(distinct)
def test_parse_ids_invalid_raises_with_payload():
with pytest.raises(JobIdsFilterError) as exc:
parse_ids_filter(f"{_UUID_A},not-a-uuid")
assert "not-a-uuid" in exc.value.payload["invalid_ids"]
# ---------------------------------------------------------------------------
# HTTP contract for the ids query parameter
# ---------------------------------------------------------------------------
@ -134,32 +177,17 @@ class FakePromptQueue:
def make_app(prompt_queue):
"""Build an aiohttp app whose handler mirrors server.py's get_jobs ids wiring."""
"""Build an aiohttp app whose handler calls the REAL parse_ids_filter.
No hand-copied parsing wiring, so this test cannot stay green while the
shipped parsing in server.py regresses -- both go through parse_ids_filter.
"""
async def get_jobs(request):
query = request.rel_url.query
ids_param = query.get('ids')
ids_filter = None
if ids_param:
ids_filter = [i.strip() for i in ids_param.split(',') if i.strip()]
if len(ids_filter) > MAX_JOB_IDS_FILTER:
return web.json_response(
{"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"},
status=400
)
invalid_ids = []
for jid in ids_filter:
try:
validate_job_id(jid)
except (ValueError, AttributeError):
invalid_ids.append(jid)
if invalid_ids:
return web.json_response(
{"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids},
status=400
)
try:
ids_filter = parse_ids_filter(request.rel_url.query.get('ids'))
except JobIdsFilterError as e:
return web.json_response(e.payload, status=400)
running, queued = prompt_queue.get_current_queue_volatile()
history = prompt_queue.get_history()
@ -209,7 +237,11 @@ async def test_http_ids_unknown_id_is_not_an_error(aiohttp_client, queue):
async def test_http_ids_over_limit_returns_400(aiohttp_client, queue):
client = await aiohttp_client(make_app(queue))
too_many = ",".join([_UUID_A] * (MAX_JOB_IDS_FILTER + 1))
# Distinct ids past the cap. (Repeats of one id are de-duped and would NOT
# trip the cap -- see test_parse_ids_cap_counts_distinct_not_duplicates.)
too_many = ",".join(
f"{i:08d}-0000-4000-8000-000000000000" for i in range(MAX_JOB_IDS_FILTER + 1)
)
resp = await client.get(f"/api/jobs?ids={too_many}")
assert resp.status == 400
@ -232,3 +264,14 @@ async def test_http_ids_absent_returns_all(aiohttp_client, queue):
assert resp.status == 200
body = await resp.json()
assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_B, _UUID_C}
@pytest.mark.asyncio
async def test_http_ids_present_but_empty_returns_none(aiohttp_client, queue):
"""`?ids=` (present but empty) is a zero-match filter, not "return all"."""
client = await aiohttp_client(make_app(queue))
resp = await client.get("/api/jobs?ids=")
assert resp.status == 200
body = await resp.json()
assert body["jobs"] == []