mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
feat: add ids filter to GET /api/jobs for batch polling
Add an optional comma-separated `ids` query parameter to GET /api/jobs so a caller can poll a known set of jobs in a single request instead of one call per job. The filter narrows the result to the requested job ids and composes with the existing status / workflow_id filters; an absent or empty `ids` means no filter. The handler caps the request at 100 ids (checked before validation) and validates each id with the existing validate_job_id helper, returning HTTP 400 on overflow or a malformed id. get_all_jobs gains an optional ids argument that narrows the normalized job list by id. Adds unit coverage for the filter logic and the endpoint's validation contract.
This commit is contained in:
parent
50e5270b86
commit
44fb02e510
@ -31,6 +31,11 @@ class JobStatus:
|
||||
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
|
||||
|
||||
|
||||
# Maximum number of ids accepted by the `ids` filter on the jobs listing.
|
||||
# Bounds the work a single batch-poll request can ask for.
|
||||
MAX_JOB_IDS_FILTER = 100
|
||||
|
||||
|
||||
def validate_job_id(value) -> str:
|
||||
"""Validate a client-supplied job (prompt) id.
|
||||
|
||||
@ -362,6 +367,7 @@ def get_all_jobs(
|
||||
history: dict,
|
||||
status_filter: Optional[list[str]] = None,
|
||||
workflow_id: Optional[str] = None,
|
||||
ids: Optional[list[str]] = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
limit: Optional[int] = None,
|
||||
@ -376,6 +382,7 @@ def get_all_jobs(
|
||||
history: Dict of history items keyed by prompt_id
|
||||
status_filter: List of statuses to include (from JobStatus.ALL)
|
||||
workflow_id: Filter by workflow ID
|
||||
ids: Restrict the result to these job ids (None/empty = no filter)
|
||||
sort_by: Field to sort by ('created_at', 'execution_duration')
|
||||
sort_order: 'asc' or 'desc'
|
||||
limit: Maximum number of items to return
|
||||
@ -408,6 +415,10 @@ def get_all_jobs(
|
||||
if workflow_id:
|
||||
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]
|
||||
|
||||
if ids:
|
||||
id_set = set(ids)
|
||||
jobs = [j for j in jobs if j['id'] in id_set]
|
||||
|
||||
jobs = apply_sorting(jobs, sort_by, sort_order)
|
||||
|
||||
total_count = len(jobs)
|
||||
|
||||
28
server.py
28
server.py
@ -16,6 +16,7 @@ from comfy_execution.jobs import (
|
||||
cancel_job,
|
||||
CANCEL_PENDING,
|
||||
CANCEL_RUNNING,
|
||||
MAX_JOB_IDS_FILTER,
|
||||
)
|
||||
import uuid
|
||||
import urllib
|
||||
@ -791,6 +792,7 @@ class PromptServer():
|
||||
Query parameters:
|
||||
status: Filter by status (comma-separated): pending, in_progress, completed, failed
|
||||
workflow_id: Filter by workflow ID
|
||||
ids: Filter by job id (comma-separated UUIDs, max 100)
|
||||
sort_by: Sort field: created_at (default), execution_duration
|
||||
sort_order: Sort direction: asc, desc (default)
|
||||
limit: Max items to return (positive integer)
|
||||
@ -800,6 +802,7 @@ class PromptServer():
|
||||
|
||||
status_param = query.get('status')
|
||||
workflow_id = query.get('workflow_id')
|
||||
ids_param = query.get('ids')
|
||||
sort_by = query.get('sort_by', 'created_at').lower()
|
||||
sort_order = query.get('sort_order', 'desc').lower()
|
||||
|
||||
@ -813,6 +816,30 @@ class PromptServer():
|
||||
status=400
|
||||
)
|
||||
|
||||
# Optional batch filter: narrow the result to a known set of job ids
|
||||
# (e.g. polling a submitted batch in one request). Absent/empty means
|
||||
# no filter. Cap the count before validating the full list so an
|
||||
# oversized request fails fast, then reject any malformed id with 400.
|
||||
ids_filter = None
|
||||
if ids_param:
|
||||
ids_filter = [i.strip() for i in ids_param.split(',') if i.strip()]
|
||||
if len(ids_filter) > MAX_JOB_IDS_FILTER:
|
||||
return web.json_response(
|
||||
{"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"},
|
||||
status=400
|
||||
)
|
||||
invalid_ids = []
|
||||
for jid in ids_filter:
|
||||
try:
|
||||
validate_job_id(jid)
|
||||
except (ValueError, AttributeError):
|
||||
invalid_ids.append(jid)
|
||||
if invalid_ids:
|
||||
return web.json_response(
|
||||
{"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids},
|
||||
status=400
|
||||
)
|
||||
|
||||
if sort_by not in {'created_at', 'execution_duration'}:
|
||||
return web.json_response(
|
||||
{"error": "sort_by must be 'created_at' or 'execution_duration'"},
|
||||
@ -864,6 +891,7 @@ class PromptServer():
|
||||
running, queued, history,
|
||||
status_filter=status_filter,
|
||||
workflow_id=workflow_id,
|
||||
ids=ids_filter,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
limit=limit,
|
||||
|
||||
0
tests-unit/jobs_list_test/__init__.py
Normal file
0
tests-unit/jobs_list_test/__init__.py
Normal file
234
tests-unit/jobs_list_test/jobs_list_test.py
Normal file
234
tests-unit/jobs_list_test/jobs_list_test.py
Normal file
@ -0,0 +1,234 @@
|
||||
"""Tests for the ``ids`` batch filter on the jobs listing endpoint.
|
||||
|
||||
Covers both layers:
|
||||
|
||||
* the pure ``comfy_execution.jobs.get_all_jobs`` filtering logic (the ``ids``
|
||||
argument narrows the result, composes with ``status_filter``, and silently
|
||||
ignores ids that match nothing), and
|
||||
|
||||
* the HTTP contract of ``GET /api/jobs`` for the ``ids`` query parameter
|
||||
(a valid set narrows the response, an oversized set or a malformed id is
|
||||
rejected with 400).
|
||||
|
||||
As in ``jobs_cancel_test``, the HTTP layer is exercised against a small
|
||||
aiohttp app whose handler is a faithful copy of the ``ids``-parsing wiring in
|
||||
``server.py``, driven by a fake queue. This keeps the test free of the heavy
|
||||
ComfyUI runtime (torch, nodes, ...) while still testing the real contract.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from comfy_execution.jobs import (
|
||||
JobStatus,
|
||||
MAX_JOB_IDS_FILTER,
|
||||
get_all_jobs,
|
||||
validate_job_id,
|
||||
)
|
||||
|
||||
# Canonical UUID ids (the endpoint validates UUID format).
|
||||
_UUID_A = "aaaaaaaa-aaaa-4aaa-aaaa-aaaaaaaaaaaa"
|
||||
_UUID_B = "bbbbbbbb-bbbb-4bbb-bbbb-bbbbbbbbbbbb"
|
||||
_UUID_C = "cccccccc-cccc-4ccc-cccc-cccccccccccc"
|
||||
_UUID_MISSING = "ffffffff-ffff-4fff-ffff-ffffffffffff"
|
||||
|
||||
|
||||
def make_queue_item(prompt_id, priority=0):
|
||||
"""Build a queue tuple shaped like the real ones (5 elements, id at index 1)."""
|
||||
return (priority, prompt_id, {}, {}, [])
|
||||
|
||||
|
||||
def make_history_item(status_str="success"):
|
||||
"""Build a history item dict shaped like the real ones."""
|
||||
return {
|
||||
"prompt": (0, "", {}, {}, []),
|
||||
"status": {"status_str": status_str, "messages": []},
|
||||
"outputs": {},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure get_all_jobs filtering logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_ids_filter_returns_only_requested():
|
||||
running = [make_queue_item(_UUID_A)]
|
||||
queued = [make_queue_item(_UUID_B)]
|
||||
history = {_UUID_C: make_history_item()}
|
||||
|
||||
jobs, total = get_all_jobs(running, queued, history, ids=[_UUID_A, _UUID_C])
|
||||
|
||||
returned = {j["id"] for j in jobs}
|
||||
assert returned == {_UUID_A, _UUID_C}
|
||||
assert total == 2
|
||||
assert _UUID_B not in returned
|
||||
|
||||
|
||||
def test_ids_filter_absent_returns_all():
|
||||
running = [make_queue_item(_UUID_A)]
|
||||
queued = [make_queue_item(_UUID_B)]
|
||||
history = {_UUID_C: make_history_item()}
|
||||
|
||||
jobs, total = get_all_jobs(running, queued, history)
|
||||
|
||||
assert {j["id"] for j in jobs} == {_UUID_A, _UUID_B, _UUID_C}
|
||||
assert total == 3
|
||||
|
||||
|
||||
def test_ids_filter_empty_list_returns_all():
|
||||
"""An empty list behaves like no filter (matches how status/workflow_id behave)."""
|
||||
running = [make_queue_item(_UUID_A)]
|
||||
queued = [make_queue_item(_UUID_B)]
|
||||
|
||||
jobs, _ = get_all_jobs(running, queued, {}, ids=[])
|
||||
|
||||
assert {j["id"] for j in jobs} == {_UUID_A, _UUID_B}
|
||||
|
||||
|
||||
def test_ids_filter_unknown_id_silently_absent():
|
||||
"""An id that matches nothing is simply not present (no error)."""
|
||||
running = [make_queue_item(_UUID_A)]
|
||||
|
||||
jobs, total = get_all_jobs(running, [], {}, ids=[_UUID_A, _UUID_MISSING])
|
||||
|
||||
assert {j["id"] for j in jobs} == {_UUID_A}
|
||||
assert total == 1
|
||||
|
||||
|
||||
def test_ids_filter_composes_with_status():
|
||||
"""ids only narrows; it composes with the status filter."""
|
||||
running = [make_queue_item(_UUID_A)]
|
||||
queued = [make_queue_item(_UUID_B)]
|
||||
history = {_UUID_C: make_history_item()}
|
||||
|
||||
# Request A and C by id, but restrict to in_progress only -> just A.
|
||||
jobs, total = get_all_jobs(
|
||||
running, queued, history,
|
||||
status_filter=[JobStatus.IN_PROGRESS],
|
||||
ids=[_UUID_A, _UUID_C],
|
||||
)
|
||||
|
||||
assert {j["id"] for j in jobs} == {_UUID_A}
|
||||
assert total == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP contract for the ids query parameter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FakePromptQueue:
|
||||
"""Minimal stand-in exposing the accessors get_jobs uses."""
|
||||
|
||||
def __init__(self, running=None, queued=None, history=None):
|
||||
self._running = list(running or [])
|
||||
self._queued = list(queued or [])
|
||||
self._history = dict(history or {})
|
||||
|
||||
def get_current_queue_volatile(self):
|
||||
return (list(self._running), list(self._queued))
|
||||
|
||||
def get_history(self):
|
||||
return dict(self._history)
|
||||
|
||||
|
||||
def make_app(prompt_queue):
|
||||
"""Build an aiohttp app whose handler mirrors server.py's get_jobs ids wiring."""
|
||||
|
||||
async def get_jobs(request):
|
||||
query = request.rel_url.query
|
||||
|
||||
ids_param = query.get('ids')
|
||||
|
||||
ids_filter = None
|
||||
if ids_param:
|
||||
ids_filter = [i.strip() for i in ids_param.split(',') if i.strip()]
|
||||
if len(ids_filter) > MAX_JOB_IDS_FILTER:
|
||||
return web.json_response(
|
||||
{"error": f"ids must contain at most {MAX_JOB_IDS_FILTER} values"},
|
||||
status=400
|
||||
)
|
||||
invalid_ids = []
|
||||
for jid in ids_filter:
|
||||
try:
|
||||
validate_job_id(jid)
|
||||
except (ValueError, AttributeError):
|
||||
invalid_ids.append(jid)
|
||||
if invalid_ids:
|
||||
return web.json_response(
|
||||
{"error": "ids contains invalid id(s)", "invalid_ids": invalid_ids},
|
||||
status=400
|
||||
)
|
||||
|
||||
running, queued = prompt_queue.get_current_queue_volatile()
|
||||
history = prompt_queue.get_history()
|
||||
|
||||
jobs, total = get_all_jobs(running, queued, history, ids=ids_filter)
|
||||
|
||||
return web.json_response({
|
||||
'jobs': jobs,
|
||||
'pagination': {'total': total},
|
||||
})
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get('/api/jobs', get_jobs)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue():
|
||||
return FakePromptQueue(
|
||||
running=[make_queue_item(_UUID_A)],
|
||||
queued=[make_queue_item(_UUID_B)],
|
||||
history={_UUID_C: make_history_item()},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_ids_filter_narrows(aiohttp_client, queue):
|
||||
client = await aiohttp_client(make_app(queue))
|
||||
|
||||
resp = await client.get(f"/api/jobs?ids={_UUID_A},{_UUID_C}")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_C}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_ids_unknown_id_is_not_an_error(aiohttp_client, queue):
|
||||
client = await aiohttp_client(make_app(queue))
|
||||
|
||||
resp = await client.get(f"/api/jobs?ids={_UUID_A},{_UUID_MISSING}")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert {j["id"] for j in body["jobs"]} == {_UUID_A}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_ids_over_limit_returns_400(aiohttp_client, queue):
|
||||
client = await aiohttp_client(make_app(queue))
|
||||
|
||||
too_many = ",".join([_UUID_A] * (MAX_JOB_IDS_FILTER + 1))
|
||||
resp = await client.get(f"/api/jobs?ids={too_many}")
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_ids_invalid_id_returns_400(aiohttp_client, queue):
|
||||
client = await aiohttp_client(make_app(queue))
|
||||
|
||||
resp = await client.get(f"/api/jobs?ids={_UUID_A},not-a-uuid")
|
||||
assert resp.status == 400
|
||||
body = await resp.json()
|
||||
assert "not-a-uuid" in body["invalid_ids"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_ids_absent_returns_all(aiohttp_client, queue):
|
||||
client = await aiohttp_client(make_app(queue))
|
||||
|
||||
resp = await client.get("/api/jobs")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert {j["id"] for j in body["jobs"]} == {_UUID_A, _UUID_B, _UUID_C}
|
||||
Loading…
Reference in New Issue
Block a user