This commit is contained in:
Matt Miller 2026-07-02 12:07:39 -07:00 committed by GitHub
commit 626ae96a13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 375 additions and 4 deletions

View File

@ -31,6 +31,12 @@ class JobStatus:
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
# Maximum number of distinct ids accepted by the `ids` filter on the jobs
# listing; parse_ids_filter rejects a larger set. The cap bounds request size,
# and together with the id-keyed history lookup in get_all_jobs it keeps the
# history work of a batch-poll request at O(requested ids), not O(total history).
MAX_JOB_IDS_FILTER = 100
def validate_job_id(value) -> str:
"""Validate a client-supplied job (prompt) id.
@ -50,6 +56,56 @@ def validate_job_id(value) -> str:
return value
class JobIdsFilterError(ValueError):
    """Signal that the ``ids`` query-param value could not be accepted.

    The exception keeps the response body on ``payload`` so a caller can send
    it back unchanged alongside a 400 status instead of rebuilding the message.
    The exception text is the payload's ``"error"`` entry, falling back to
    ``"invalid ids"`` when that key is missing.
    """

    def __init__(self, payload: dict):
        message = payload.get("error", "invalid ids")
        self.payload = payload
        super().__init__(message)
def parse_ids_filter(ids_param: Optional[str]) -> Optional[list[str]]:
    """Turn the raw ``ids`` query-param string into a list of job ids.

    This is the one place where ``ids`` is parsed and validated; the HTTP
    handler and its tests both call it so their behavior stays aligned.

    Returns:
        - ``None`` if ``ids_param`` is ``None`` (parameter absent): no filter.
        - Otherwise a list of the distinct, whitespace-trimmed ids in
          first-seen order. A value with no non-blank entries (``?ids=``,
          ``?ids=,,``) gives ``[]``, which ``get_all_jobs`` treats as
          "match nothing", NOT as "return everything".

    Raises:
        JobIdsFilterError: when there are more than ``MAX_JOB_IDS_FILTER``
            distinct ids, or when ``validate_job_id`` rejects any of them.
            ``.payload`` holds a dict ready to be returned with a 400.
    """
    if ids_param is None:
        return None

    # De-duplicate before anything else: a repeated id must neither count
    # toward the cap nor be looked up twice. A dict keeps insertion order, so
    # the first occurrence of each id fixes its position.
    seen: dict[str, None] = {}
    for piece in ids_param.split(','):
        candidate = piece.strip()
        if candidate:
            seen.setdefault(candidate, None)
    unique_ids = list(seen)

    if len(unique_ids) > MAX_JOB_IDS_FILTER:
        raise JobIdsFilterError(
            {"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"}
        )

    # Gather every rejected id so the client sees all of them in one response.
    rejected = []
    for candidate in unique_ids:
        try:
            validate_job_id(candidate)
        except (ValueError, AttributeError):
            rejected.append(candidate)

    if rejected:
        raise JobIdsFilterError(
            {"error": "ids contains invalid id(s)", "invalid_ids": rejected}
        )
    return unique_ids
# Media types that can be previewed in the frontend
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
@ -362,6 +418,7 @@ def get_all_jobs(
history: dict,
status_filter: Optional[list[str]] = None,
workflow_id: Optional[str] = None,
ids: Optional[list[str]] = None,
sort_by: str = "created_at",
sort_order: str = "desc",
limit: Optional[int] = None,
@ -376,6 +433,8 @@ def get_all_jobs(
history: Dict of history items keyed by prompt_id
status_filter: List of statuses to include (from JobStatus.ALL)
workflow_id: Filter by workflow ID
ids: Restrict the result to these job ids. None = no filter; a present
list (including empty) restricts to that set, so [] = zero matches
sort_by: Field to sort by ('created_at', 'execution_duration')
sort_order: 'asc' or 'desc'
limit: Maximum number of items to return
@ -389,6 +448,10 @@ def get_all_jobs(
if status_filter is None:
status_filter = JobStatus.ALL
# None => no id filter; a present list (including empty) restricts to that
# set (empty => zero matches).
id_set = set(ids) if ids is not None else None
if JobStatus.IN_PROGRESS in status_filter:
for item in running:
jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS))
@ -400,14 +463,30 @@ def get_all_jobs(
history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED}
requested_history_statuses = history_statuses & set(status_filter)
if requested_history_statuses:
for prompt_id, history_item in history.items():
job = normalize_history_item(prompt_id, history_item)
if job.get('status') in requested_history_statuses:
jobs.append(job)
if id_set is not None:
# Batch-poll fast path: history is keyed by id, so look up only the
# requested ids instead of normalizing the whole (unbounded) history.
for prompt_id in id_set:
history_item = history.get(prompt_id)
if history_item is None:
continue
job = normalize_history_item(prompt_id, history_item)
if job.get('status') in requested_history_statuses:
jobs.append(job)
else:
for prompt_id, history_item in history.items():
job = normalize_history_item(prompt_id, history_item)
if job.get('status') in requested_history_statuses:
jobs.append(job)
if workflow_id:
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]
if id_set is not None:
# `.get('id')` (not `j['id']`): prune_dict can drop a None id, and a
# job missing its id should degrade to "no match", not raise KeyError.
jobs = [j for j in jobs if j.get('id') in id_set]
jobs = apply_sorting(jobs, sort_by, sort_order)
total_count = len(jobs)

View File

@ -16,6 +16,8 @@ from comfy_execution.jobs import (
cancel_job,
CANCEL_PENDING,
CANCEL_RUNNING,
parse_ids_filter,
JobIdsFilterError,
)
import uuid
import urllib
@ -791,6 +793,7 @@ class PromptServer():
Query parameters:
status: Filter by status (comma-separated): pending, in_progress, completed, failed
workflow_id: Filter by workflow ID
ids: Filter by job id (comma-separated UUIDs, max 100)
sort_by: Sort field: created_at (default), execution_duration
sort_order: Sort direction: asc, desc (default)
limit: Max items to return (positive integer)
@ -800,6 +803,7 @@ class PromptServer():
status_param = query.get('status')
workflow_id = query.get('workflow_id')
ids_param = query.get('ids')
sort_by = query.get('sort_by', 'created_at').lower()
sort_order = query.get('sort_order', 'desc').lower()
@ -813,6 +817,16 @@ class PromptServer():
status=400
)
# Optional batch filter: narrow the result to a known set of job ids
# (e.g. polling a submitted batch in one request). Parsing/validation
# lives in parse_ids_filter so this handler and its tests share one
# implementation. Absent => no filter; present-but-empty (`?ids=`,
# `?ids=,,`) => zero matches, not "everything".
try:
ids_filter = parse_ids_filter(ids_param)
except JobIdsFilterError as e:
return web.json_response(e.payload, status=400)
if sort_by not in {'created_at', 'execution_duration'}:
return web.json_response(
{"error": "sort_by must be 'created_at' or 'execution_duration'"},
@ -864,6 +878,7 @@ class PromptServer():
running, queued, history,
status_filter=status_filter,
workflow_id=workflow_id,
ids=ids_filter,
sort_by=sort_by,
sort_order=sort_order,
limit=limit,

View File

View File

@ -0,0 +1,277 @@
"""Tests for the ``ids`` batch filter on the jobs listing endpoint.
Covers both layers:
* the pure ``comfy_execution.jobs.get_all_jobs`` filtering logic (the ``ids``
argument narrows the result, composes with ``status_filter``, and silently
ignores ids that match nothing), and
* the HTTP contract of ``GET /api/jobs`` for the ``ids`` query parameter
(a valid set narrows the response, an oversized set or a malformed id is
rejected with 400).
The HTTP layer is exercised against a small aiohttp app whose handler calls the
SAME ``parse_ids_filter`` that ``server.py`` uses (no hand-copied wiring, so it
cannot drift), driven by a fake queue. This keeps the test free of the heavy
ComfyUI runtime (torch, nodes, ...) while still testing the real parsing
contract.
"""
import pytest
from aiohttp import web
from comfy_execution.jobs import (
JobStatus,
JobIdsFilterError,
MAX_JOB_IDS_FILTER,
get_all_jobs,
parse_ids_filter,
)
# Canonical UUID ids (the endpoint validates UUID format).
# In the tests below A is used as a running job, B as a queued job and C as a
# history entry; MISSING is never placed in any of them, so it stands for an
# id that matches nothing.
_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa"
_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb"
_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc"
_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff"
def make_queue_item(prompt_id, priority=0):
    """Return a 5-element queue tuple with ``priority`` first and the id second.

    The remaining three slots are empty placeholders (two dicts and a list).
    """
    placeholders = ({}, {}, [])
    return (priority, prompt_id, *placeholders)
def make_history_item(status_str="success"):
    """Return a history-entry dict whose status carries ``status_str``.

    The entry has a placeholder ``prompt`` tuple, a ``status`` dict with an
    empty ``messages`` list, and empty ``outputs``.
    """
    status = {"status_str": status_str, "messages": []}
    return dict(prompt=(0, "", {}, {}, []), status=status, outputs={})
# ---------------------------------------------------------------------------
# Pure get_all_jobs filtering logic
# ---------------------------------------------------------------------------
def test_ids_filter_returns_only_requested():
    """Only the jobs named in ``ids`` come back; the unrequested one is dropped."""
    wanted = [_UUID_A, _UUID_C]
    jobs, total = get_all_jobs(
        [make_queue_item(_UUID_A)],
        [make_queue_item(_UUID_B)],
        {_UUID_C: make_history_item()},
        ids=wanted,
    )
    found_ids = {job["id"] for job in jobs}
    assert found_ids == set(wanted)
    assert total == 2
    assert _UUID_B not in found_ids
def test_ids_filter_absent_returns_all():
    """Without an ``ids`` argument every running, queued and history job is listed."""
    jobs, total = get_all_jobs(
        [make_queue_item(_UUID_A)],
        [make_queue_item(_UUID_B)],
        {_UUID_C: make_history_item()},
    )
    found_ids = {job["id"] for job in jobs}
    assert found_ids == {_UUID_A, _UUID_B, _UUID_C}
    assert total == 3
def test_ids_filter_empty_list_returns_none():
    """A present-but-empty ids list is a zero-match filter, not "no filter".

    ``None`` means "no id filter"; ``[]`` means "restrict to nothing". The
    fixture deliberately holds a running job, a queued job AND a history entry
    so the zero-match contract is checked against all three sources, including
    the id-keyed history lookup.
    """
    running = [make_queue_item(_UUID_A)]
    queued = [make_queue_item(_UUID_B)]
    history = {_UUID_C: make_history_item()}
    jobs, total = get_all_jobs(running, queued, history, ids=[])
    assert jobs == []
    assert total == 0
def test_ids_filter_unknown_id_silently_absent():
    """Requesting an id that no job has raises nothing; it is just left out."""
    requested = [_UUID_A, _UUID_MISSING]
    jobs, total = get_all_jobs([make_queue_item(_UUID_A)], [], {}, ids=requested)
    found_ids = {job["id"] for job in jobs}
    assert found_ids == {_UUID_A}
    assert total == 1
def test_ids_filter_composes_with_status():
    """The ids filter and the status filter are both applied (intersection)."""
    history = {_UUID_C: make_history_item()}
    # A and C are requested by id, but only in_progress is allowed, so the
    # history job C is excluded and just the running job A remains.
    jobs, total = get_all_jobs(
        [make_queue_item(_UUID_A)],
        [make_queue_item(_UUID_B)],
        history,
        ids=[_UUID_A, _UUID_C],
        status_filter=[JobStatus.IN_PROGRESS],
    )
    found_ids = {job["id"] for job in jobs}
    assert found_ids == {_UUID_A}
    assert total == 1
# ---------------------------------------------------------------------------
# parse_ids_filter -- the shared parsing/validation (server.py + these tests)
# ---------------------------------------------------------------------------
def test_parse_ids_absent_is_none():
    """A missing ``ids`` parameter (``None``) parses to ``None``: no filter."""
    parsed = parse_ids_filter(None)
    assert parsed is None
def test_parse_ids_present_but_empty_is_empty_list():
    """`?ids=` and `?ids=,,` both parse to ``[]`` (zero-match filter), not ``None``."""
    for raw_value in ("", ",,"):
        assert parse_ids_filter(raw_value) == []
def test_parse_ids_dedupes_preserving_order():
    """A repeated id is kept once, at the position of its first occurrence."""
    raw_value = ",".join([_UUID_A, _UUID_B, _UUID_A])
    assert parse_ids_filter(raw_value) == [_UUID_A, _UUID_B]
def test_parse_ids_cap_counts_distinct_not_duplicates():
    """The cap applies to distinct ids: repeats never trip it, extra ids do."""
    # Two distinct ids repeated far past the cap still count as two.
    pair = [_UUID_A, _UUID_B]
    assert parse_ids_filter(",".join(pair * MAX_JOB_IDS_FILTER)) == pair

    # One distinct id over the cap is rejected.
    over_cap = [
        f"{n:08d}-0000-4000-8000-000000000000"
        for n in range(MAX_JOB_IDS_FILTER + 1)
    ]
    with pytest.raises(JobIdsFilterError):
        parse_ids_filter(",".join(over_cap))
def test_parse_ids_invalid_raises_with_payload():
    """A malformed id raises, and the payload lists the offending value."""
    bad_id = "not-a-uuid"
    with pytest.raises(JobIdsFilterError) as caught:
        parse_ids_filter(",".join([_UUID_A, bad_id]))
    assert bad_id in caught.value.payload["invalid_ids"]
# ---------------------------------------------------------------------------
# HTTP contract for the ids query parameter
# ---------------------------------------------------------------------------
class FakePromptQueue:
    """In-memory stand-in for the prompt queue used by the test app's handler.

    It offers only the two read accessors the handler calls. Both the
    constructor and the accessors copy their containers, so callers cannot
    mutate the fake's stored state through what they pass in or get back.
    """

    def __init__(self, running=None, queued=None, history=None):
        # A falsy argument (None or empty) becomes a fresh empty container.
        self._running = list(running) if running else []
        self._queued = list(queued) if queued else []
        self._history = dict(history) if history else {}

    def get_current_queue_volatile(self):
        """Return ``(running, queued)`` as a tuple of fresh list copies."""
        return self._running.copy(), self._queued.copy()

    def get_history(self):
        """Return a fresh shallow copy of the stored history dict."""
        return self._history.copy()
def make_app(prompt_queue):
    """Create an aiohttp app serving ``GET /api/jobs`` from ``prompt_queue``.

    The handler parses ``ids`` through the REAL ``parse_ids_filter`` rather
    than a copy of its logic, so a regression in that shared parser shows up
    here too.
    """

    async def get_jobs(request):
        raw_ids = request.rel_url.query.get('ids')
        try:
            ids_filter = parse_ids_filter(raw_ids)
        except JobIdsFilterError as err:
            # The exception already carries the response body for the 400.
            return web.json_response(err.payload, status=400)

        running, queued = prompt_queue.get_current_queue_volatile()
        history = prompt_queue.get_history()
        jobs, total = get_all_jobs(running, queued, history, ids=ids_filter)

        body = {'jobs': jobs, 'pagination': {'total': total}}
        return web.json_response(body)

    application = web.Application()
    application.router.add_get('/api/jobs', get_jobs)
    return application
@pytest.fixture
def queue():
    """Fake queue with A running, B queued and C in history."""
    running = [make_queue_item(_UUID_A)]
    queued = [make_queue_item(_UUID_B)]
    history = {_UUID_C: make_history_item()}
    return FakePromptQueue(running=running, queued=queued, history=history)
@pytest.mark.asyncio
async def test_http_ids_filter_narrows(aiohttp_client, queue):
    """A valid ``ids`` value yields 200 and only the requested jobs."""
    wanted = (_UUID_A, _UUID_C)
    client = await aiohttp_client(make_app(queue))
    resp = await client.get("/api/jobs?ids=" + ",".join(wanted))
    assert resp.status == 200
    payload = await resp.json()
    returned_ids = {job["id"] for job in payload["jobs"]}
    assert returned_ids == set(wanted)
@pytest.mark.asyncio
async def test_http_ids_unknown_id_is_not_an_error(aiohttp_client, queue):
    """A well-formed id that matches no job still yields 200; it is just absent."""
    requested = ",".join((_UUID_A, _UUID_MISSING))
    client = await aiohttp_client(make_app(queue))
    resp = await client.get("/api/jobs?ids=" + requested)
    assert resp.status == 200
    payload = await resp.json()
    returned_ids = {job["id"] for job in payload["jobs"]}
    assert returned_ids == {_UUID_A}
@pytest.mark.asyncio
async def test_http_ids_over_limit_returns_400(aiohttp_client, queue):
    """More than ``MAX_JOB_IDS_FILTER`` distinct ids is rejected with 400."""
    client = await aiohttp_client(make_app(queue))
    # The ids must be distinct: repeats of one id are de-duplicated and would
    # NOT trip the cap (see test_parse_ids_cap_counts_distinct_not_duplicates).
    over_cap = [
        f"{n:08d}-0000-4000-8000-000000000000"
        for n in range(MAX_JOB_IDS_FILTER + 1)
    ]
    resp = await client.get("/api/jobs?ids=" + ",".join(over_cap))
    assert resp.status == 400
@pytest.mark.asyncio
async def test_http_ids_invalid_id_returns_400(aiohttp_client, queue):
    """A malformed id yields 400 and is echoed back under ``invalid_ids``."""
    bad_id = "not-a-uuid"
    client = await aiohttp_client(make_app(queue))
    resp = await client.get("/api/jobs?ids=" + ",".join((_UUID_A, bad_id)))
    assert resp.status == 400
    payload = await resp.json()
    assert bad_id in payload["invalid_ids"]
@pytest.mark.asyncio
async def test_http_ids_absent_returns_all(aiohttp_client, queue):
    """With no ``ids`` parameter the listing is unfiltered."""
    client = await aiohttp_client(make_app(queue))
    resp = await client.get("/api/jobs")
    assert resp.status == 200
    payload = await resp.json()
    returned_ids = {job["id"] for job in payload["jobs"]}
    assert returned_ids == {_UUID_A, _UUID_B, _UUID_C}
@pytest.mark.asyncio
async def test_http_ids_present_but_empty_returns_none(aiohttp_client, queue):
    """`?ids=` (present but empty) matches nothing; it does not mean "return all"."""
    client = await aiohttp_client(make_app(queue))
    resp = await client.get("/api/jobs?ids=")
    assert resp.status == 200
    payload = await resp.json()
    assert payload["jobs"] == []