mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
fix(jobs): harden ids filter per review
- parse_ids_filter: one shared parser/validator for the ids query param, used
by the /api/jobs handler AND its tests (no more hand-copied wiring that can
drift from — and silently outlive a regression in — the shipped handler)
- present-but-empty ids (?ids=, ?ids=,,) is now a zero-match filter, not a
silent 'return the entire job history'
- bounded history lookup when an ids filter is present: a batch poll costs
O(requested ids), not O(total history)
- dedupe ids so the max-count cap bounds distinct values, not repeats
- .get('id') instead of j['id'] so a job missing its id degrades to no-match
rather than a 500
This commit is contained in:
parent
44fb02e510
commit
b656217514
@ -31,8 +31,9 @@ class JobStatus:
|
||||
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
|
||||
|
||||
|
||||
# Maximum number of ids accepted by the `ids` filter on the jobs listing.
|
||||
# Bounds the work a single batch-poll request can ask for.
|
||||
# Maximum number of (distinct) ids accepted by the `ids` filter on the jobs
|
||||
# listing. Caps request size; the bounded id-lookup in get_all_jobs then keeps
|
||||
# a batch-poll request at O(requested ids), not O(total history).
|
||||
MAX_JOB_IDS_FILTER = 100
|
||||
|
||||
|
||||
@ -55,6 +56,56 @@ def validate_job_id(value) -> str:
|
||||
return value
|
||||
|
||||
|
||||
class JobIdsFilterError(ValueError):
|
||||
"""Raised when the ``ids`` query-param value is malformed.
|
||||
|
||||
Carries an HTTP-ready ``payload`` dict so the caller can return it verbatim
|
||||
with a 400 without re-deriving the message.
|
||||
"""
|
||||
|
||||
def __init__(self, payload: dict):
|
||||
self.payload = payload
|
||||
super().__init__(payload.get("error", "invalid ids"))
|
||||
|
||||
|
||||
def parse_ids_filter(ids_param: Optional[str]) -> Optional[list[str]]:
|
||||
"""Parse the ``ids`` query-param value into a filter list.
|
||||
|
||||
Single source of truth for ``ids`` parsing/validation, shared by the HTTP
|
||||
handler and its tests so the two cannot drift.
|
||||
|
||||
Returns:
|
||||
- ``None`` when the param is absent (``ids_param is None``) -> no filter.
|
||||
- A de-duplicated list when present. An empty/blank value (``?ids=``,
|
||||
``?ids=,,``) yields ``[]``, which ``get_all_jobs`` treats as a
|
||||
zero-match filter -- NOT "return everything".
|
||||
|
||||
Raises:
|
||||
JobIdsFilterError: more than ``MAX_JOB_IDS_FILTER`` distinct ids, or any
|
||||
id not in canonical UUID form. ``.payload`` is a 400-ready dict.
|
||||
"""
|
||||
if ids_param is None:
|
||||
return None
|
||||
# De-dupe up front: a repeated id must not count toward the cap or be
|
||||
# looked up twice. dict.fromkeys keeps first-seen order.
|
||||
ids_filter = list(dict.fromkeys(i.strip() for i in ids_param.split(',') if i.strip()))
|
||||
if len(ids_filter) > MAX_JOB_IDS_FILTER:
|
||||
raise JobIdsFilterError(
|
||||
{"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"}
|
||||
)
|
||||
invalid_ids = []
|
||||
for jid in ids_filter:
|
||||
try:
|
||||
validate_job_id(jid)
|
||||
except (ValueError, AttributeError):
|
||||
invalid_ids.append(jid)
|
||||
if invalid_ids:
|
||||
raise JobIdsFilterError(
|
||||
{"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids}
|
||||
)
|
||||
return ids_filter
|
||||
|
||||
|
||||
# Media types that can be previewed in the frontend
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
|
||||
|
||||
@ -382,7 +433,8 @@ def get_all_jobs(
|
||||
history: Dict of history items keyed by prompt_id
|
||||
status_filter: List of statuses to include (from JobStatus.ALL)
|
||||
workflow_id: Filter by workflow ID
|
||||
ids: Restrict the result to these job ids (None/empty = no filter)
|
||||
ids: Restrict the result to these job ids. None = no filter; a present
|
||||
list (including empty) restricts to that set, so [] = zero matches
|
||||
sort_by: Field to sort by ('created_at', 'execution_duration')
|
||||
sort_order: 'asc' or 'desc'
|
||||
limit: Maximum number of items to return
|
||||
@ -396,6 +448,10 @@ def get_all_jobs(
|
||||
if status_filter is None:
|
||||
status_filter = JobStatus.ALL
|
||||
|
||||
# None => no id filter; a present list (including empty) restricts to that
|
||||
# set (empty => zero matches).
|
||||
id_set = set(ids) if ids is not None else None
|
||||
|
||||
if JobStatus.IN_PROGRESS in status_filter:
|
||||
for item in running:
|
||||
jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS))
|
||||
@ -407,17 +463,29 @@ def get_all_jobs(
|
||||
history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED}
|
||||
requested_history_statuses = history_statuses & set(status_filter)
|
||||
if requested_history_statuses:
|
||||
for prompt_id, history_item in history.items():
|
||||
job = normalize_history_item(prompt_id, history_item)
|
||||
if job.get('status') in requested_history_statuses:
|
||||
jobs.append(job)
|
||||
if id_set is not None:
|
||||
# Batch-poll fast path: history is keyed by id, so look up only the
|
||||
# requested ids instead of normalizing the whole (unbounded) history.
|
||||
for prompt_id in id_set:
|
||||
history_item = history.get(prompt_id)
|
||||
if history_item is None:
|
||||
continue
|
||||
job = normalize_history_item(prompt_id, history_item)
|
||||
if job.get('status') in requested_history_statuses:
|
||||
jobs.append(job)
|
||||
else:
|
||||
for prompt_id, history_item in history.items():
|
||||
job = normalize_history_item(prompt_id, history_item)
|
||||
if job.get('status') in requested_history_statuses:
|
||||
jobs.append(job)
|
||||
|
||||
if workflow_id:
|
||||
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]
|
||||
|
||||
if ids:
|
||||
id_set = set(ids)
|
||||
jobs = [j for j in jobs if j['id'] in id_set]
|
||||
if id_set is not None:
|
||||
# `.get('id')` (not `j['id']`): prune_dict can drop a None id, and a
|
||||
# job missing its id should degrade to "no match", not raise KeyError.
|
||||
jobs = [j for j in jobs if j.get('id') in id_set]
|
||||
|
||||
jobs = apply_sorting(jobs, sort_by, sort_order)
|
||||
|
||||
|
||||
33
server.py
33
server.py
@ -16,7 +16,8 @@ from comfy_execution.jobs import (
|
||||
cancel_job,
|
||||
CANCEL_PENDING,
|
||||
CANCEL_RUNNING,
|
||||
MAX_JOB_IDS_FILTER,
|
||||
parse_ids_filter,
|
||||
JobIdsFilterError,
|
||||
)
|
||||
import uuid
|
||||
import urllib
|
||||
@ -817,28 +818,14 @@ class PromptServer():
|
||||
)
|
||||
|
||||
# Optional batch filter: narrow the result to a known set of job ids
|
||||
# (e.g. polling a submitted batch in one request). Absent/empty means
|
||||
# no filter. Cap the count before validating the full list so an
|
||||
# oversized request fails fast, then reject any malformed id with 400.
|
||||
ids_filter = None
|
||||
if ids_param:
|
||||
ids_filter = [i.strip() for i in ids_param.split(',') if i.strip()]
|
||||
if len(ids_filter) > MAX_JOB_IDS_FILTER:
|
||||
return web.json_response(
|
||||
{"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"},
|
||||
status=400
|
||||
)
|
||||
invalid_ids = []
|
||||
for jid in ids_filter:
|
||||
try:
|
||||
validate_job_id(jid)
|
||||
except (ValueError, AttributeError):
|
||||
invalid_ids.append(jid)
|
||||
if invalid_ids:
|
||||
return web.json_response(
|
||||
{"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids},
|
||||
status=400
|
||||
)
|
||||
# (e.g. polling a submitted batch in one request). Parsing/validation
|
||||
# lives in parse_ids_filter so this handler and its tests share one
|
||||
# implementation. Absent => no filter; present-but-empty (`?ids=`,
|
||||
# `?ids=,,`) => zero matches, not "everything".
|
||||
try:
|
||||
ids_filter = parse_ids_filter(ids_param)
|
||||
except JobIdsFilterError as e:
|
||||
return web.json_response(e.payload, status=400)
|
||||
|
||||
if sort_by not in {'created_at', 'execution_duration'}:
|
||||
return web.json_response(
|
||||
|
||||
@ -10,10 +10,11 @@ Covers both layers:
|
||||
(a valid set narrows the response, an oversized set or a malformed id is
|
||||
rejected with 400).
|
||||
|
||||
As in ``jobs_cancel_test``, the HTTP layer is exercised against a small
|
||||
aiohttp app whose handler is a faithful copy of the ``ids``-parsing wiring in
|
||||
``server.py``, driven by a fake queue. This keeps the test free of the heavy
|
||||
ComfyUI runtime (torch, nodes, ...) while still testing the real contract.
|
||||
The HTTP layer is exercised against a small aiohttp app whose handler calls the
|
||||
SAME ``parse_ids_filter`` that ``server.py`` uses (no hand-copied wiring, so it
|
||||
cannot drift), driven by a fake queue. This keeps the test free of the heavy
|
||||
ComfyUI runtime (torch, nodes, ...) while still testing the real parsing
|
||||
contract.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
@ -21,9 +22,10 @@ from aiohttp import web
|
||||
|
||||
from comfy_execution.jobs import (
|
||||
JobStatus,
|
||||
JobIdsFilterError,
|
||||
MAX_JOB_IDS_FILTER,
|
||||
get_all_jobs,
|
||||
validate_job_id,
|
||||
parse_ids_filter,
|
||||
)
|
||||
|
||||
# Canonical UUID ids (the endpoint validates UUID format).
|
||||
@ -76,14 +78,18 @@ def test_ids_filter_absent_returns_all():
|
||||
assert total == 3
|
||||
|
||||
|
||||
def test_ids_filter_empty_list_returns_all():
|
||||
"""An empty list behaves like no filter (matches how status/workflow_id behave)."""
|
||||
def test_ids_filter_empty_list_returns_none():
|
||||
"""A present-but-empty ids list is a zero-match filter, not "no filter".
|
||||
|
||||
``None`` means "no id filter"; ``[]`` means "restrict to nothing".
|
||||
"""
|
||||
running = [make_queue_item(_UUID_A)]
|
||||
queued = [make_queue_item(_UUID_B)]
|
||||
|
||||
jobs, _ = get_all_jobs(running, queued, {}, ids=[])
|
||||
jobs, total = get_all_jobs(running, queued, {}, ids=[])
|
||||
|
||||
assert {j["id"] for j in jobs} == {_UUID_A, _UUID_B}
|
||||
assert jobs == []
|
||||
assert total == 0
|
||||
|
||||
|
||||
def test_ids_filter_unknown_id_silently_absent():
|
||||
@ -113,6 +119,43 @@ def test_ids_filter_composes_with_status():
|
||||
assert total == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# parse_ids_filter -- the shared parsing/validation (server.py + these tests)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_ids_absent_is_none():
|
||||
assert parse_ids_filter(None) is None
|
||||
|
||||
|
||||
def test_parse_ids_present_but_empty_is_empty_list():
|
||||
# `?ids=` and `?ids=,,` parse to [] -> zero-match filter, not None.
|
||||
assert parse_ids_filter("") == []
|
||||
assert parse_ids_filter(",,") == []
|
||||
|
||||
|
||||
def test_parse_ids_dedupes_preserving_order():
|
||||
assert parse_ids_filter(f"{_UUID_A},{_UUID_B},{_UUID_A}") == [_UUID_A, _UUID_B]
|
||||
|
||||
|
||||
def test_parse_ids_cap_counts_distinct_not_duplicates():
|
||||
# A small distinct set repeated far past the cap is still under it.
|
||||
repeated = ",".join([_UUID_A, _UUID_B] * MAX_JOB_IDS_FILTER)
|
||||
assert parse_ids_filter(repeated) == [_UUID_A, _UUID_B]
|
||||
# But more than MAX distinct ids is rejected.
|
||||
distinct = ",".join(
|
||||
f"{i:08d}-0000-4000-8000-000000000000" for i in range(MAX_JOB_IDS_FILTER + 1)
|
||||
)
|
||||
with pytest.raises(JobIdsFilterError):
|
||||
parse_ids_filter(distinct)
|
||||
|
||||
|
||||
def test_parse_ids_invalid_raises_with_payload():
|
||||
with pytest.raises(JobIdsFilterError) as exc:
|
||||
parse_ids_filter(f"{_UUID_A},not-a-uuid")
|
||||
assert "not-a-uuid" in exc.value.payload["invalid_ids"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP contract for the ids query parameter
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -134,32 +177,17 @@ class FakePromptQueue:
|
||||
|
||||
|
||||
def make_app(prompt_queue):
|
||||
"""Build an aiohttp app whose handler mirrors server.py's get_jobs ids wiring."""
|
||||
"""Build an aiohttp app whose handler calls the REAL parse_ids_filter.
|
||||
|
||||
No hand-copied parsing wiring, so this test cannot stay green while the
|
||||
shipped parsing in server.py regresses -- both go through parse_ids_filter.
|
||||
"""
|
||||
|
||||
async def get_jobs(request):
|
||||
query = request.rel_url.query
|
||||
|
||||
ids_param = query.get('ids')
|
||||
|
||||
ids_filter = None
|
||||
if ids_param:
|
||||
ids_filter = [i.strip() for i in ids_param.split(',') if i.strip()]
|
||||
if len(ids_filter) > MAX_JOB_IDS_FILTER:
|
||||
return web.json_response(
|
||||
{"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"},
|
||||
status=400
|
||||
)
|
||||
invalid_ids = []
|
||||
for jid in ids_filter:
|
||||
try:
|
||||
validate_job_id(jid)
|
||||
except (ValueError, AttributeError):
|
||||
invalid_ids.append(jid)
|
||||
if invalid_ids:
|
||||
return web.json_response(
|
||||
{"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids},
|
||||
status=400
|
||||
)
|
||||
try:
|
||||
ids_filter = parse_ids_filter(request.rel_url.query.get('ids'))
|
||||
except JobIdsFilterError as e:
|
||||
return web.json_response(e.payload, status=400)
|
||||
|
||||
running, queued = prompt_queue.get_current_queue_volatile()
|
||||
history = prompt_queue.get_history()
|
||||
@ -209,7 +237,11 @@ async def test_http_ids_unknown_id_is_not_an_error(aiohttp_client, queue):
|
||||
async def test_http_ids_over_limit_returns_400(aiohttp_client, queue):
|
||||
client = await aiohttp_client(make_app(queue))
|
||||
|
||||
too_many = ",".join([_UUID_A] * (MAX_JOB_IDS_FILTER + 1))
|
||||
# Distinct ids past the cap. (Repeats of one id are de-duped and would NOT
|
||||
# trip the cap -- see test_parse_ids_cap_counts_distinct_not_duplicates.)
|
||||
too_many = ",".join(
|
||||
f"{i:08d}-0000-4000-8000-000000000000" for i in range(MAX_JOB_IDS_FILTER + 1)
|
||||
)
|
||||
resp = await client.get(f"/api/jobs?ids={too_many}")
|
||||
assert resp.status == 400
|
||||
|
||||
@ -232,3 +264,14 @@ async def test_http_ids_absent_returns_all(aiohttp_client, queue):
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_B, _UUID_C}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_ids_present_but_empty_returns_none(aiohttp_client, queue):
|
||||
"""`?ids=` (present but empty) is a zero-match filter, not "return all"."""
|
||||
client = await aiohttp_client(make_app(queue))
|
||||
|
||||
resp = await client.get("/api/jobs?ids=")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["jobs"] == []
|
||||
|
||||
Loading…
Reference in New Issue
Block a user