feat: add ids filter to GET /api/jobs for batch polling

Add an optional comma-separated `ids` query parameter to GET /api/jobs so a
caller can poll a known set of jobs in a single request instead of one call
per job. The filter narrows the result to the requested job ids and composes
with the existing status / workflow_id filters; an absent or empty `ids` means
no filter.

The handler caps the request at 100 ids (checked before validation) and
validates each id with the existing validate_job_id helper, returning HTTP 400
on overflow or a malformed id. get_all_jobs gains an optional ids argument that
narrows the normalized job list by id.

Adds unit coverage for the filter logic and the endpoint's validation contract.
This commit is contained in:
Matt Miller 2026-06-30 14:58:15 -07:00
parent 50e5270b86
commit 44fb02e510
4 changed files with 273 additions and 0 deletions

View File

@ -31,6 +31,11 @@ class JobStatus:
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
# Maximum number of ids accepted by the `ids` filter on the jobs listing.
# Bounds the work a single batch-poll request can ask for.
MAX_JOB_IDS_FILTER = 100
def validate_job_id(value) -> str:
"""Validate a client-supplied job (prompt) id.
@ -362,6 +367,7 @@ def get_all_jobs(
history: dict,
status_filter: Optional[list[str]] = None,
workflow_id: Optional[str] = None,
ids: Optional[list[str]] = None,
sort_by: str = "created_at",
sort_order: str = "desc",
limit: Optional[int] = None,
@ -376,6 +382,7 @@ def get_all_jobs(
history: Dict of history items keyed by prompt_id
status_filter: List of statuses to include (from JobStatus.ALL)
workflow_id: Filter by workflow ID
ids: Restrict the result to these job ids (None/empty = no filter)
sort_by: Field to sort by ('created_at', 'execution_duration')
sort_order: 'asc' or 'desc'
limit: Maximum number of items to return
@ -408,6 +415,10 @@ def get_all_jobs(
if workflow_id:
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]
if ids:
id_set = set(ids)
jobs = [j for j in jobs if j['id'] in id_set]
jobs = apply_sorting(jobs, sort_by, sort_order)
total_count = len(jobs)

View File

@ -16,6 +16,7 @@ from comfy_execution.jobs import (
cancel_job,
CANCEL_PENDING,
CANCEL_RUNNING,
MAX_JOB_IDS_FILTER,
)
import uuid
import urllib
@ -791,6 +792,7 @@ class PromptServer():
Query parameters:
status: Filter by status (comma-separated): pending, in_progress, completed, failed
workflow_id: Filter by workflow ID
ids: Filter by job id (comma-separated UUIDs, max 100)
sort_by: Sort field: created_at (default), execution_duration
sort_order: Sort direction: asc, desc (default)
limit: Max items to return (positive integer)
@ -800,6 +802,7 @@ class PromptServer():
status_param = query.get('status')
workflow_id = query.get('workflow_id')
ids_param = query.get('ids')
sort_by = query.get('sort_by', 'created_at').lower()
sort_order = query.get('sort_order', 'desc').lower()
@ -813,6 +816,30 @@ class PromptServer():
status=400
)
# Optional batch filter: narrow the result to a known set of job ids
# (e.g. polling a submitted batch in one request). Absent/empty means
# no filter. Cap the count before validating the full list so an
# oversized request fails fast, then reject any malformed id with 400.
ids_filter = None
if ids_param:
ids_filter = [i.strip() for i in ids_param.split(',') if i.strip()]
if len(ids_filter) > MAX_JOB_IDS_FILTER:
return web.json_response(
{"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"},
status=400
)
invalid_ids = []
for jid in ids_filter:
try:
validate_job_id(jid)
except (ValueError, AttributeError):
invalid_ids.append(jid)
if invalid_ids:
return web.json_response(
{"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids},
status=400
)
if sort_by not in {'created_at', 'execution_duration'}:
return web.json_response(
{"error": "sort_by must be 'created_at' or 'execution_duration'"},
@ -864,6 +891,7 @@ class PromptServer():
running, queued, history,
status_filter=status_filter,
workflow_id=workflow_id,
ids=ids_filter,
sort_by=sort_by,
sort_order=sort_order,
limit=limit,

View File

View File

@ -0,0 +1,234 @@
"""Tests for the ``ids`` batch filter on the jobs listing endpoint.
Covers both layers:
* the pure ``comfy_execution.jobs.get_all_jobs`` filtering logic (the ``ids``
argument narrows the result, composes with ``status_filter``, and silently
ignores ids that match nothing), and
* the HTTP contract of ``GET /api/jobs`` for the ``ids`` query parameter
(a valid set narrows the response, an oversized set or a malformed id is
rejected with 400).
As in ``jobs_cancel_test``, the HTTP layer is exercised against a small
aiohttp app whose handler is a faithful copy of the ``ids``-parsing wiring in
``server.py``, driven by a fake queue. This keeps the test free of the heavy
ComfyUI runtime (torch, nodes, ...) while still testing the real contract.
"""
import pytest
from aiohttp import web
from comfy_execution.jobs import (
JobStatus,
MAX_JOB_IDS_FILTER,
get_all_jobs,
validate_job_id,
)
# Canonical UUID ids (the endpoint validates UUID format).
_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa"
_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb"
_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc"
_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff"
def make_queue_item(prompt_id, priority=0):
"""Build a queue tuple shaped like the real ones (5 elements, id at index 1)."""
return (priority, prompt_id, {}, {}, [])
def make_history_item(status_str="success"):
"""Build a history item dict shaped like the real ones."""
return {
"prompt": (0, "", {}, {}, []),
"status": {"status_str": status_str, "messages": []},
"outputs": {},
}
# ---------------------------------------------------------------------------
# Pure get_all_jobs filtering logic
# ---------------------------------------------------------------------------
def test_ids_filter_returns_only_requested():
running = [make_queue_item(_UUID_A)]
queued = [make_queue_item(_UUID_B)]
history = {_UUID_C: make_history_item()}
jobs, total = get_all_jobs(running, queued, history, ids=[_UUID_A, _UUID_C])
returned = {j["id"] for j in jobs}
assert returned == {_UUID_A, _UUID_C}
assert total == 2
assert _UUID_B not in returned
def test_ids_filter_absent_returns_all():
running = [make_queue_item(_UUID_A)]
queued = [make_queue_item(_UUID_B)]
history = {_UUID_C: make_history_item()}
jobs, total = get_all_jobs(running, queued, history)
assert {j["id"] for j in jobs} == {_UUID_A, _UUID_B, _UUID_C}
assert total == 3
def test_ids_filter_empty_list_returns_all():
"""An empty list behaves like no filter (matches how status/workflow_id behave)."""
running = [make_queue_item(_UUID_A)]
queued = [make_queue_item(_UUID_B)]
jobs, _ = get_all_jobs(running, queued, {}, ids=[])
assert {j["id"] for j in jobs} == {_UUID_A, _UUID_B}
def test_ids_filter_unknown_id_silently_absent():
"""An id that matches nothing is simply not present (no error)."""
running = [make_queue_item(_UUID_A)]
jobs, total = get_all_jobs(running, [], {}, ids=[_UUID_A, _UUID_MISSING])
assert {j["id"] for j in jobs} == {_UUID_A}
assert total == 1
def test_ids_filter_composes_with_status():
"""ids only narrows; it composes with the status filter."""
running = [make_queue_item(_UUID_A)]
queued = [make_queue_item(_UUID_B)]
history = {_UUID_C: make_history_item()}
# Request A and C by id, but restrict to in_progress only -> just A.
jobs, total = get_all_jobs(
running, queued, history,
status_filter=[JobStatus.IN_PROGRESS],
ids=[_UUID_A, _UUID_C],
)
assert {j["id"] for j in jobs} == {_UUID_A}
assert total == 1
# ---------------------------------------------------------------------------
# HTTP contract for the ids query parameter
# ---------------------------------------------------------------------------
class FakePromptQueue:
"""Minimal stand-in exposing the accessors get_jobs uses."""
def __init__(self, running=None, queued=None, history=None):
self._running = list(running or [])
self._queued = list(queued or [])
self._history = dict(history or {})
def get_current_queue_volatile(self):
return (list(self._running), list(self._queued))
def get_history(self):
return dict(self._history)
def make_app(prompt_queue):
"""Build an aiohttp app whose handler mirrors server.py's get_jobs ids wiring."""
async def get_jobs(request):
query = request.rel_url.query
ids_param = query.get('ids')
ids_filter = None
if ids_param:
ids_filter = [i.strip() for i in ids_param.split(',') if i.strip()]
if len(ids_filter) > MAX_JOB_IDS_FILTER:
return web.json_response(
{"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"},
status=400
)
invalid_ids = []
for jid in ids_filter:
try:
validate_job_id(jid)
except (ValueError, AttributeError):
invalid_ids.append(jid)
if invalid_ids:
return web.json_response(
{"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids},
status=400
)
running, queued = prompt_queue.get_current_queue_volatile()
history = prompt_queue.get_history()
jobs, total = get_all_jobs(running, queued, history, ids=ids_filter)
return web.json_response({
'jobs': jobs,
'pagination': {'total': total},
})
app = web.Application()
app.router.add_get('/api/jobs', get_jobs)
return app
@pytest.fixture
def queue():
return FakePromptQueue(
running=[make_queue_item(_UUID_A)],
queued=[make_queue_item(_UUID_B)],
history={_UUID_C: make_history_item()},
)
@pytest.mark.asyncio
async def test_http_ids_filter_narrows(aiohttp_client, queue):
client = await aiohttp_client(make_app(queue))
resp = await client.get(f"/api/jobs?ids={_UUID_A},{_UUID_C}")
assert resp.status == 200
body = await resp.json()
assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_C}
@pytest.mark.asyncio
async def test_http_ids_unknown_id_is_not_an_error(aiohttp_client, queue):
client = await aiohttp_client(make_app(queue))
resp = await client.get(f"/api/jobs?ids={_UUID_A},{_UUID_MISSING}")
assert resp.status == 200
body = await resp.json()
assert {j["id"] for j in body["jobs"]} == {_UUID_A}
@pytest.mark.asyncio
async def test_http_ids_over_limit_returns_400(aiohttp_client, queue):
client = await aiohttp_client(make_app(queue))
too_many = ",".join([_UUID_A] * (MAX_JOB_IDS_FILTER + 1))
resp = await client.get(f"/api/jobs?ids={too_many}")
assert resp.status == 400
@pytest.mark.asyncio
async def test_http_ids_invalid_id_returns_400(aiohttp_client, queue):
client = await aiohttp_client(make_app(queue))
resp = await client.get(f"/api/jobs?ids={_UUID_A},not-a-uuid")
assert resp.status == 400
body = await resp.json()
assert "not-a-uuid" in body["invalid_ids"]
@pytest.mark.asyncio
async def test_http_ids_absent_returns_all(aiohttp_client, queue):
client = await aiohttp_client(make_app(queue))
resp = await client.get("/api/jobs")
assert resp.status == 200
body = await resp.json()
assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_B, _UUID_C}