diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index fa3ab0faf..508413b79 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -31,6 +31,12 @@ class JobStatus: ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED] +# Maximum number of (distinct) ids accepted by the `ids` filter on the jobs +# listing. Caps request size; the bounded id-lookup in get_all_jobs then keeps +# a batch-poll request at O(requested ids), not O(total history). +MAX_JOB_IDS_FILTER = 100 + + def validate_job_id(value) -> str: """Validate a client-supplied job (prompt) id. @@ -50,6 +56,56 @@ def validate_job_id(value) -> str: return value +class JobIdsFilterError(ValueError): + """Raised when the ``ids`` query-param value is malformed. + + Carries an HTTP-ready ``payload`` dict so the caller can return it verbatim + with a 400 without re-deriving the message. + """ + + def __init__(self, payload: dict): + self.payload = payload + super().__init__(payload.get("error", "invalid ids")) + + +def parse_ids_filter(ids_param: Optional[str]) -> Optional[list[str]]: + """Parse the ``ids`` query-param value into a filter list. + + Single source of truth for ``ids`` parsing/validation, shared by the HTTP + handler and its tests so the two cannot drift. + + Returns: + - ``None`` when the param is absent (``ids_param is None``) -> no filter. + - A de-duplicated list when present. An empty/blank value (``?ids=``, + ``?ids=,,``) yields ``[]``, which ``get_all_jobs`` treats as a + zero-match filter -- NOT "return everything". + + Raises: + JobIdsFilterError: more than ``MAX_JOB_IDS_FILTER`` distinct ids, or any + id not in canonical UUID form. ``.payload`` is a 400-ready dict. + """ + if ids_param is None: + return None + # De-dupe up front: a repeated id must not count toward the cap or be + # looked up twice. dict.fromkeys keeps first-seen order. + ids_filter = list(dict.fromkeys(i.strip() for i in ids_param.split(',') if i.strip())) + if len(ids_filter) > MAX_JOB_IDS_FILTER: + raise JobIdsFilterError( + {"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"} + ) + invalid_ids = [] + for jid in ids_filter: + try: + validate_job_id(jid) + except (ValueError, AttributeError): + invalid_ids.append(jid) + if invalid_ids: + raise JobIdsFilterError( + {"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids} + ) + return ids_filter + + # Media types that can be previewed in the frontend PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'}) @@ -362,6 +418,7 @@ def get_all_jobs( history: dict, status_filter: Optional[list[str]] = None, workflow_id: Optional[str] = None, + ids: Optional[list[str]] = None, sort_by: str = "created_at", sort_order: str = "desc", limit: Optional[int] = None, @@ -376,6 +433,8 @@ def get_all_jobs( history: Dict of history items keyed by prompt_id status_filter: List of statuses to include (from JobStatus.ALL) workflow_id: Filter by workflow ID + ids: Restrict the result to these job ids. None = no filter; a present + list (including empty) restricts to that set, so [] = zero matches sort_by: Field to sort by ('created_at', 'execution_duration') sort_order: 'asc' or 'desc' limit: Maximum number of items to return @@ -389,6 +448,10 @@ def get_all_jobs( if status_filter is None: status_filter = JobStatus.ALL + # None => no id filter; a present list (including empty) restricts to that + # set (empty => zero matches). + id_set = set(ids) if ids is not None else None + if JobStatus.IN_PROGRESS in status_filter: for item in running: jobs.append(normalize_queue_item(item, JobStatus.IN_PROGRESS)) @@ -400,14 +463,30 @@ def get_all_jobs( history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED} requested_history_statuses = history_statuses & set(status_filter) if requested_history_statuses: - for prompt_id, history_item in history.items(): - job = normalize_history_item(prompt_id, history_item) - if job.get('status') in requested_history_statuses: - jobs.append(job) + if id_set is not None: + # Batch-poll fast path: history is keyed by id, so look up only the + # requested ids instead of normalizing the whole (unbounded) history. + for prompt_id in id_set: + history_item = history.get(prompt_id) + if history_item is None: + continue + job = normalize_history_item(prompt_id, history_item) + if job.get('status') in requested_history_statuses: + jobs.append(job) + else: + for prompt_id, history_item in history.items(): + job = normalize_history_item(prompt_id, history_item) + if job.get('status') in requested_history_statuses: + jobs.append(job) if workflow_id: jobs = [j for j in jobs if j.get('workflow_id') == workflow_id] + if id_set is not None: + # `.get('id')` (not `j['id']`): prune_dict can drop a None id, and a + # job missing its id should degrade to "no match", not raise KeyError. + jobs = [j for j in jobs if j.get('id') in id_set] + jobs = apply_sorting(jobs, sort_by, sort_order) total_count = len(jobs) diff --git a/server.py b/server.py index 361850f38..88fb288c8 100644 --- a/server.py +++ b/server.py @@ -16,6 +16,8 @@ from comfy_execution.jobs import ( cancel_job, CANCEL_PENDING, CANCEL_RUNNING, + parse_ids_filter, + JobIdsFilterError, ) import uuid import urllib @@ -791,6 +793,7 @@ class PromptServer(): Query parameters: status: Filter by status (comma-separated): pending, in_progress, completed, failed workflow_id: Filter by workflow ID + ids: Filter by job id (comma-separated UUIDs, max 100) sort_by: Sort field: created_at (default), execution_duration sort_order: Sort direction: asc, desc (default) limit: Max items to return (positive integer) @@ -800,6 +803,7 @@ class PromptServer(): status_param = query.get('status') workflow_id = query.get('workflow_id') + ids_param = query.get('ids') sort_by = query.get('sort_by', 'created_at').lower() sort_order = query.get('sort_order', 'desc').lower() @@ -813,6 +817,16 @@ class PromptServer(): status=400 ) + # Optional batch filter: narrow the result to a known set of job ids + # (e.g. polling a submitted batch in one request). Parsing/validation + # lives in parse_ids_filter so this handler and its tests share one + # implementation. Absent => no filter; present-but-empty (`?ids=`, + # `?ids=,,`) => zero matches, not "everything". + try: + ids_filter = parse_ids_filter(ids_param) + except JobIdsFilterError as e: + return web.json_response(e.payload, status=400) + if sort_by not in {'created_at', 'execution_duration'}: return web.json_response( {"error": "sort_by must be 'created_at' or 'execution_duration'"}, @@ -864,6 +878,7 @@ class PromptServer(): running, queued, history, status_filter=status_filter, workflow_id=workflow_id, + ids=ids_filter, sort_by=sort_by, sort_order=sort_order, limit=limit, diff --git a/tests-unit/jobs_list_test/__init__.py b/tests-unit/jobs_list_test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests-unit/jobs_list_test/jobs_list_test.py b/tests-unit/jobs_list_test/jobs_list_test.py new file mode 100644 index 000000000..0107284e0 --- /dev/null +++ b/tests-unit/jobs_list_test/jobs_list_test.py @@ -0,0 +1,277 @@ +"""Tests for the ``ids`` batch filter on the jobs listing endpoint. + +Covers both layers: + +* the pure ``comfy_execution.jobs.get_all_jobs`` filtering logic (the ``ids`` + argument narrows the result, composes with ``status_filter``, and silently + ignores ids that match nothing), and + +* the HTTP contract of ``GET /api/jobs`` for the ``ids`` query parameter + (a valid set narrows the response, an oversized set or a malformed id is + rejected with 400). + +The HTTP layer is exercised against a small aiohttp app whose handler calls the +SAME ``parse_ids_filter`` that ``server.py`` uses (no hand-copied wiring, so it +cannot drift), driven by a fake queue. This keeps the test free of the heavy +ComfyUI runtime (torch, nodes, ...) while still testing the real parsing +contract. +""" + +import pytest +from aiohttp import web + +from comfy_execution.jobs import ( + JobStatus, + JobIdsFilterError, + MAX_JOB_IDS_FILTER, + get_all_jobs, + parse_ids_filter, +) + +# Canonical UUID ids (the endpoint validates UUID format). +_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa" +_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb" +_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc" +_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff" + + +def make_queue_item(prompt_id, priority=0): + """Build a queue tuple shaped like the real ones (5 elements, id at index 1).""" + return (priority, prompt_id, {}, {}, []) + + +def make_history_item(status_str="success"): + """Build a history item dict shaped like the real ones.""" + return { + "prompt": (0, "", {}, {}, []), + "status": {"status_str": status_str, "messages": []}, + "outputs": {}, + } + + +# --------------------------------------------------------------------------- +# Pure get_all_jobs filtering logic +# --------------------------------------------------------------------------- + + +def test_ids_filter_returns_only_requested(): + running = [make_queue_item(_UUID_A)] + queued = [make_queue_item(_UUID_B)] + history = {_UUID_C: make_history_item()} + + jobs, total = get_all_jobs(running, queued, history, ids=[_UUID_A, _UUID_C]) + + returned = {j["id"] for j in jobs} + assert returned == {_UUID_A, _UUID_C} + assert total == 2 + assert _UUID_B not in returned + + +def test_ids_filter_absent_returns_all(): + running = [make_queue_item(_UUID_A)] + queued = [make_queue_item(_UUID_B)] + history = {_UUID_C: make_history_item()} + + jobs, total = get_all_jobs(running, queued, history) + + assert {j["id"] for j in jobs} == {_UUID_A, _UUID_B, _UUID_C} + assert total == 3 + + +def test_ids_filter_empty_list_returns_none(): + """A present-but-empty ids list is a zero-match filter, not "no filter". + + ``None`` means "no id filter"; ``[]`` means "restrict to nothing". + """ + running = [make_queue_item(_UUID_A)] + queued = [make_queue_item(_UUID_B)] + + jobs, total = get_all_jobs(running, queued, {}, ids=[]) + + assert jobs == [] + assert total == 0 + + +def test_ids_filter_unknown_id_silently_absent(): + """An id that matches nothing is simply not present (no error).""" + running = [make_queue_item(_UUID_A)] + + jobs, total = get_all_jobs(running, [], {}, ids=[_UUID_A, _UUID_MISSING]) + + assert {j["id"] for j in jobs} == {_UUID_A} + assert total == 1 + + +def test_ids_filter_composes_with_status(): + """ids only narrows; it composes with the status filter.""" + running = [make_queue_item(_UUID_A)] + queued = [make_queue_item(_UUID_B)] + history = {_UUID_C: make_history_item()} + + # Request A and C by id, but restrict to in_progress only -> just A. + jobs, total = get_all_jobs( + running, queued, history, + status_filter=[JobStatus.IN_PROGRESS], + ids=[_UUID_A, _UUID_C], + ) + + assert {j["id"] for j in jobs} == {_UUID_A} + assert total == 1 + + +# --------------------------------------------------------------------------- +# parse_ids_filter -- the shared parsing/validation (server.py + these tests) +# --------------------------------------------------------------------------- + + +def test_parse_ids_absent_is_none(): + assert parse_ids_filter(None) is None + + +def test_parse_ids_present_but_empty_is_empty_list(): + # `?ids=` and `?ids=,,` parse to [] -> zero-match filter, not None. + assert parse_ids_filter("") == [] + assert parse_ids_filter(",,") == [] + + +def test_parse_ids_dedupes_preserving_order(): + assert parse_ids_filter(f"{_UUID_A},{_UUID_B},{_UUID_A}") == [_UUID_A, _UUID_B] + + +def test_parse_ids_cap_counts_distinct_not_duplicates(): + # A small distinct set repeated far past the cap is still under it. + repeated = ",".join([_UUID_A, _UUID_B] * MAX_JOB_IDS_FILTER) + assert parse_ids_filter(repeated) == [_UUID_A, _UUID_B] + # But more than MAX distinct ids is rejected. + distinct = ",".join( + f"{i:08d}-0000-4000-8000-000000000000" for i in range(MAX_JOB_IDS_FILTER + 1) + ) + with pytest.raises(JobIdsFilterError): + parse_ids_filter(distinct) + + +def test_parse_ids_invalid_raises_with_payload(): + with pytest.raises(JobIdsFilterError) as exc: + parse_ids_filter(f"{_UUID_A},not-a-uuid") + assert "not-a-uuid" in exc.value.payload["invalid_ids"] + + +# --------------------------------------------------------------------------- +# HTTP contract for the ids query parameter +# --------------------------------------------------------------------------- + + +class FakePromptQueue: + """Minimal stand-in exposing the accessors get_jobs uses.""" + + def __init__(self, running=None, queued=None, history=None): + self._running = list(running or []) + self._queued = list(queued or []) + self._history = dict(history or {}) + + def get_current_queue_volatile(self): + return (list(self._running), list(self._queued)) + + def get_history(self): + return dict(self._history) + + +def make_app(prompt_queue): + """Build an aiohttp app whose handler calls the REAL parse_ids_filter. + + No hand-copied parsing wiring, so this test cannot stay green while the + shipped parsing in server.py regresses -- both go through parse_ids_filter. + """ + + async def get_jobs(request): + try: + ids_filter = parse_ids_filter(request.rel_url.query.get('ids')) + except JobIdsFilterError as e: + return web.json_response(e.payload, status=400) + + running, queued = prompt_queue.get_current_queue_volatile() + history = prompt_queue.get_history() + + jobs, total = get_all_jobs(running, queued, history, ids=ids_filter) + + return web.json_response({ + 'jobs': jobs, + 'pagination': {'total': total}, + }) + + app = web.Application() + app.router.add_get('/api/jobs', get_jobs) + return app + + +@pytest.fixture +def queue(): + return FakePromptQueue( + running=[make_queue_item(_UUID_A)], + queued=[make_queue_item(_UUID_B)], + history={_UUID_C: make_history_item()}, + ) + + +@pytest.mark.asyncio +async def test_http_ids_filter_narrows(aiohttp_client, queue): + client = await aiohttp_client(make_app(queue)) + + resp = await client.get(f"/api/jobs?ids={_UUID_A},{_UUID_C}") + assert resp.status == 200 + body = await resp.json() + assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_C} + + +@pytest.mark.asyncio +async def test_http_ids_unknown_id_is_not_an_error(aiohttp_client, queue): + client = await aiohttp_client(make_app(queue)) + + resp = await client.get(f"/api/jobs?ids={_UUID_A},{_UUID_MISSING}") + assert resp.status == 200 + body = await resp.json() + assert {j["id"] for j in body["jobs"]} == {_UUID_A} + + +@pytest.mark.asyncio +async def test_http_ids_over_limit_returns_400(aiohttp_client, queue): + client = await aiohttp_client(make_app(queue)) + + # Distinct ids past the cap. (Repeats of one id are de-duped and would NOT + # trip the cap -- see test_parse_ids_cap_counts_distinct_not_duplicates.) + too_many = ",".join( + f"{i:08d}-0000-4000-8000-000000000000" for i in range(MAX_JOB_IDS_FILTER + 1) + ) + resp = await client.get(f"/api/jobs?ids={too_many}") + assert resp.status == 400 + + +@pytest.mark.asyncio +async def test_http_ids_invalid_id_returns_400(aiohttp_client, queue): + client = await aiohttp_client(make_app(queue)) + + resp = await client.get(f"/api/jobs?ids={_UUID_A},not-a-uuid") + assert resp.status == 400 + body = await resp.json() + assert "not-a-uuid" in body["invalid_ids"] + + +@pytest.mark.asyncio +async def test_http_ids_absent_returns_all(aiohttp_client, queue): + client = await aiohttp_client(make_app(queue)) + + resp = await client.get("/api/jobs") + assert resp.status == 200 + body = await resp.json() + assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_B, _UUID_C} + + +@pytest.mark.asyncio +async def test_http_ids_present_but_empty_returns_none(aiohttp_client, queue): + """`?ids=` (present but empty) is a zero-match filter, not "return all".""" + client = await aiohttp_client(make_app(queue)) + + resp = await client.get("/api/jobs?ids=") + assert resp.status == 200 + body = await resp.json() + assert body["jobs"] == []