From 44fb02e5105aef45987be1f131cd822305195b86 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Tue, 30 Jun 2026 14:58:15 -0700 Subject: [PATCH] feat: add ids filter to GET /api/jobs for batch polling Add an optional comma-separated `ids` query parameter to GET /api/jobs so a caller can poll a known set of jobs in a single request instead of one call per job. The filter narrows the result to the requested job ids and composes with the existing status / workflow_id filters; an absent or empty `ids` means no filter. The handler caps the request at 100 ids (checked before validation) and validates each id with the existing validate_job_id helper, returning HTTP 400 on overflow or a malformed id. get_all_jobs gains an optional ids argument that narrows the normalized job list by id. Adds unit coverage for the filter logic and the endpoint's validation contract. --- comfy_execution/jobs.py | 11 + server.py | 28 +++ tests-unit/jobs_list_test/__init__.py | 0 tests-unit/jobs_list_test/jobs_list_test.py | 234 ++++++++++++++++++++ 4 files changed, 273 insertions(+) create mode 100644 tests-unit/jobs_list_test/__init__.py create mode 100644 tests-unit/jobs_list_test/jobs_list_test.py diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index fa3ab0faf..f0f9b32ec 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -31,6 +31,11 @@ class JobStatus: ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED] +# Maximum number of ids accepted by the `ids` filter on the jobs listing. +# Bounds the work a single batch-poll request can ask for. +MAX_JOB_IDS_FILTER = 100 + + def validate_job_id(value) -> str: """Validate a client-supplied job (prompt) id. @@ -362,6 +367,7 @@ def get_all_jobs( history: dict, status_filter: Optional[list[str]] = None, workflow_id: Optional[str] = None, + ids: Optional[list[str]] = None, sort_by: str = "created_at", sort_order: str = "desc", limit: Optional[int] = None, @@ -376,6 +382,7 @@ def get_all_jobs( history: Dict of history items keyed by prompt_id status_filter: List of statuses to include (from JobStatus.ALL) workflow_id: Filter by workflow ID + ids: Restrict the result to these job ids (None/empty = no filter) sort_by: Field to sort by ('created_at', 'execution_duration') sort_order: 'asc' or 'desc' limit: Maximum number of items to return @@ -408,6 +415,10 @@ def get_all_jobs( if workflow_id: jobs = [j for j in jobs if j.get('workflow_id') == workflow_id] + if ids: + id_set = set(ids) + jobs = [j for j in jobs if j['id'] in id_set] + jobs = apply_sorting(jobs, sort_by, sort_order) total_count = len(jobs) diff --git a/server.py b/server.py index 361850f38..cb7916d6b 100644 --- a/server.py +++ b/server.py @@ -16,6 +16,7 @@ from comfy_execution.jobs import ( cancel_job, CANCEL_PENDING, CANCEL_RUNNING, + MAX_JOB_IDS_FILTER, ) import uuid import urllib @@ -791,6 +792,7 @@ class PromptServer(): Query parameters: status: Filter by status (comma-separated): pending, in_progress, completed, failed workflow_id: Filter by workflow ID + ids: Filter by job id (comma-separated UUIDs, max 100) sort_by: Sort field: created_at (default), execution_duration sort_order: Sort direction: asc, desc (default) limit: Max items to return (positive integer) @@ -800,6 +802,7 @@ class PromptServer(): status_param = query.get('status') workflow_id = query.get('workflow_id') + ids_param = query.get('ids') sort_by = query.get('sort_by', 'created_at').lower() sort_order = query.get('sort_order', 'desc').lower() @@ -813,6 +816,30 @@ class PromptServer(): status=400 ) + # Optional batch filter: narrow the result to a known set of job ids + # (e.g. polling a submitted batch in one request). Absent/empty means + # no filter. Cap the count before validating the full list so an + # oversized request fails fast, then reject any malformed id with 400. + ids_filter = None + if ids_param: + ids_filter = [i.strip() for i in ids_param.split(',') if i.strip()] + if len(ids_filter) > MAX_JOB_IDS_FILTER: + return web.json_response( + {"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"}, + status=400 + ) + invalid_ids = [] + for jid in ids_filter: + try: + validate_job_id(jid) + except (ValueError, AttributeError): + invalid_ids.append(jid) + if invalid_ids: + return web.json_response( + {"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids}, + status=400 + ) + if sort_by not in {'created_at', 'execution_duration'}: return web.json_response( {"error": "sort_by must be 'created_at' or 'execution_duration'"}, @@ -864,6 +891,7 @@ class PromptServer(): running, queued, history, status_filter=status_filter, workflow_id=workflow_id, + ids=ids_filter, sort_by=sort_by, sort_order=sort_order, limit=limit, diff --git a/tests-unit/jobs_list_test/__init__.py b/tests-unit/jobs_list_test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests-unit/jobs_list_test/jobs_list_test.py b/tests-unit/jobs_list_test/jobs_list_test.py new file mode 100644 index 000000000..bfeb4d727 --- /dev/null +++ b/tests-unit/jobs_list_test/jobs_list_test.py @@ -0,0 +1,234 @@ +"""Tests for the ``ids`` batch filter on the jobs listing endpoint. + +Covers both layers: + +* the pure ``comfy_execution.jobs.get_all_jobs`` filtering logic (the ``ids`` + argument narrows the result, composes with ``status_filter``, and silently + ignores ids that match nothing), and + +* the HTTP contract of ``GET /api/jobs`` for the ``ids`` query parameter + (a valid set narrows the response, an oversized set or a malformed id is + rejected with 400). + +As in ``jobs_cancel_test``, the HTTP layer is exercised against a small +aiohttp app whose handler is a faithful copy of the ``ids``-parsing wiring in +``server.py``, driven by a fake queue. This keeps the test free of the heavy +ComfyUI runtime (torch, nodes, ...) while still testing the real contract. +""" + +import pytest +from aiohttp import web + +from comfy_execution.jobs import ( + JobStatus, + MAX_JOB_IDS_FILTER, + get_all_jobs, + validate_job_id, +) + +# Canonical UUID ids (the endpoint validates UUID format). +_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa" +_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb" +_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc" +_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff" + + +def make_queue_item(prompt_id, priority=0): + """Build a queue tuple shaped like the real ones (5 elements, id at index 1).""" + return (priority, prompt_id, {}, {}, []) + + +def make_history_item(status_str="success"): + """Build a history item dict shaped like the real ones.""" + return { + "prompt": (0, "", {}, {}, []), + "status": {"status_str": status_str, "messages": []}, + "outputs": {}, + } + + +# --------------------------------------------------------------------------- +# Pure get_all_jobs filtering logic +# --------------------------------------------------------------------------- + + +def test_ids_filter_returns_only_requested(): + running = [make_queue_item(_UUID_A)] + queued = [make_queue_item(_UUID_B)] + history = {_UUID_C: make_history_item()} + + jobs, total = get_all_jobs(running, queued, history, ids=[_UUID_A, _UUID_C]) + + returned = {j["id"] for j in jobs} + assert returned == {_UUID_A, _UUID_C} + assert total == 2 + assert _UUID_B not in returned + + +def test_ids_filter_absent_returns_all(): + running = [make_queue_item(_UUID_A)] + queued = [make_queue_item(_UUID_B)] + history = {_UUID_C: make_history_item()} + + jobs, total = get_all_jobs(running, queued, history) + + assert {j["id"] for j in jobs} == {_UUID_A, _UUID_B, _UUID_C} + assert total == 3 + + +def test_ids_filter_empty_list_returns_all(): + """An empty list behaves like no filter (matches how status/workflow_id behave).""" + running = [make_queue_item(_UUID_A)] + queued = [make_queue_item(_UUID_B)] + + jobs, _ = get_all_jobs(running, queued, {}, ids=[]) + + assert {j["id"] for j in jobs} == {_UUID_A, _UUID_B} + + +def test_ids_filter_unknown_id_silently_absent(): + """An id that matches nothing is simply not present (no error).""" + running = [make_queue_item(_UUID_A)] + + jobs, total = get_all_jobs(running, [], {}, ids=[_UUID_A, _UUID_MISSING]) + + assert {j["id"] for j in jobs} == {_UUID_A} + assert total == 1 + + +def test_ids_filter_composes_with_status(): + """ids only narrows; it composes with the status filter.""" + running = [make_queue_item(_UUID_A)] + queued = [make_queue_item(_UUID_B)] + history = {_UUID_C: make_history_item()} + + # Request A and C by id, but restrict to in_progress only -> just A. + jobs, total = get_all_jobs( + running, queued, history, + status_filter=[JobStatus.IN_PROGRESS], + ids=[_UUID_A, _UUID_C], + ) + + assert {j["id"] for j in jobs} == {_UUID_A} + assert total == 1 + + +# --------------------------------------------------------------------------- +# HTTP contract for the ids query parameter +# --------------------------------------------------------------------------- + + +class FakePromptQueue: + """Minimal stand-in exposing the accessors get_jobs uses.""" + + def __init__(self, running=None, queued=None, history=None): + self._running = list(running or []) + self._queued = list(queued or []) + self._history = dict(history or {}) + + def get_current_queue_volatile(self): + return (list(self._running), list(self._queued)) + + def get_history(self): + return dict(self._history) + + +def make_app(prompt_queue): + """Build an aiohttp app whose handler mirrors server.py's get_jobs ids wiring.""" + + async def get_jobs(request): + query = request.rel_url.query + + ids_param = query.get('ids') + + ids_filter = None + if ids_param: + ids_filter = [i.strip() for i in ids_param.split(',') if i.strip()] + if len(ids_filter) > MAX_JOB_IDS_FILTER: + return web.json_response( + {"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"}, + status=400 + ) + invalid_ids = [] + for jid in ids_filter: + try: + validate_job_id(jid) + except (ValueError, AttributeError): + invalid_ids.append(jid) + if invalid_ids: + return web.json_response( + {"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids}, + status=400 + ) + + running, queued = prompt_queue.get_current_queue_volatile() + history = prompt_queue.get_history() + + jobs, total = get_all_jobs(running, queued, history, ids=ids_filter) + + return web.json_response({ + 'jobs': jobs, + 'pagination': {'total': total}, + }) + + app = web.Application() + app.router.add_get('/api/jobs', get_jobs) + return app + + +@pytest.fixture +def queue(): + return FakePromptQueue( + running=[make_queue_item(_UUID_A)], + queued=[make_queue_item(_UUID_B)], + history={_UUID_C: make_history_item()}, + ) + + +@pytest.mark.asyncio +async def test_http_ids_filter_narrows(aiohttp_client, queue): + client = await aiohttp_client(make_app(queue)) + + resp = await client.get(f"/api/jobs?ids={_UUID_A},{_UUID_C}") + assert resp.status == 200 + body = await resp.json() + assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_C} + + +@pytest.mark.asyncio +async def test_http_ids_unknown_id_is_not_an_error(aiohttp_client, queue): + client = await aiohttp_client(make_app(queue)) + + resp = await client.get(f"/api/jobs?ids={_UUID_A},{_UUID_MISSING}") + assert resp.status == 200 + body = await resp.json() + assert {j["id"] for j in body["jobs"]} == {_UUID_A} + + +@pytest.mark.asyncio +async def test_http_ids_over_limit_returns_400(aiohttp_client, queue): + client = await aiohttp_client(make_app(queue)) + + too_many = ",".join([_UUID_A] * (MAX_JOB_IDS_FILTER + 1)) + resp = await client.get(f"/api/jobs?ids={too_many}") + assert resp.status == 400 + + +@pytest.mark.asyncio +async def test_http_ids_invalid_id_returns_400(aiohttp_client, queue): + client = await aiohttp_client(make_app(queue)) + + resp = await client.get(f"/api/jobs?ids={_UUID_A},not-a-uuid") + assert resp.status == 400 + body = await resp.json() + assert "not-a-uuid" in body["invalid_ids"] + + +@pytest.mark.asyncio +async def test_http_ids_absent_returns_all(aiohttp_client, queue): + client = await aiohttp_client(make_app(queue)) + + resp = await client.get("/api/jobs") + assert resp.status == 200 + body = await resp.json() + assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_B, _UUID_C}