diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py
index 20ebae155..7d7362774 100644
--- a/comfy_execution/jobs.py
+++ b/comfy_execution/jobs.py
@@ -4,11 +4,22 @@ Provides normalization and helper functions for job status tracking.
 """
 
 import uuid
-from typing import Optional
+from typing import Callable, Optional
 
 from comfy_api.internal import prune_dict
 
 
+# Result of classifying a job for cancellation.
+# 'running' -> job is currently executing (interrupt it)
+# 'pending' -> job is queued but not started (dequeue it)
+# 'terminal' -> job already finished (present in history); cancel is a no-op
+# 'unknown' -> job id is not present anywhere
+CANCEL_RUNNING = 'running'
+CANCEL_PENDING = 'pending'
+CANCEL_TERMINAL = 'terminal'
+CANCEL_UNKNOWN = 'unknown'
+
+
 class JobStatus:
     """Job status constants."""
     PENDING = 'pending'
@@ -407,3 +418,52 @@ def get_all_jobs(
         jobs = jobs[:limit]
 
     return (jobs, total_count)
+
+
+def classify_job_for_cancel(prompt_id: str, running: list, queued: list, history: dict) -> str:
+    """Classify a job id for cancellation.
+
+    Returns one of CANCEL_RUNNING, CANCEL_PENDING, CANCEL_TERMINAL, CANCEL_UNKNOWN.
+
+    Queue items are tuples whose second element (index 1) is the prompt_id.
+    History is a dict keyed by prompt_id, so a job present there has already
+    finished and cancelling it is a no-op. Lookup order: running, queued, history.
+    """
+    for item in running:
+        if item[1] == prompt_id:
+            return CANCEL_RUNNING
+    for item in queued:
+        if item[1] == prompt_id:
+            return CANCEL_PENDING
+    if prompt_id in history:  # NOTE(review): raises TypeError for an unhashable id (e.g. a list); callers must pass str
+        return CANCEL_TERMINAL
+    return CANCEL_UNKNOWN
+
+
+def cancel_job(
+    prompt_id: str,
+    running: list,
+    queued: list,
+    history: dict,
+    interrupt: Callable[[], None],
+    dequeue: Callable[[str], bool],
+) -> str:
+    """Cancel a single job by id, regardless of state.
+
+    Maps the cancel onto the runtime's existing mechanics:
+      - a running job is interrupted via ``interrupt``
+      - a pending job is removed from the queue via ``dequeue``
+      - a job that already finished (terminal) is a no-op
+      - an unknown id is a no-op (callers that need fail-fast behaviour should
+        validate ids up front with ``classify_job_for_cancel``)
+
+    Returns the classification that was acted on (one of the CANCEL_* values),
+    so callers can log or report what happened.
+    """
+    classification = classify_job_for_cancel(prompt_id, running, queued, history)
+    if classification == CANCEL_RUNNING:
+        interrupt()
+    elif classification == CANCEL_PENDING:
+        dequeue(prompt_id)  # NOTE(review): the bool result is discarded, so a failed dequeue still returns CANCEL_PENDING
+    # CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops.
+    return classification
diff --git a/openapi.yaml b/openapi.yaml
index 6e203b1cd..c1263e7ed 100644
--- a/openapi.yaml
+++ b/openapi.yaml
@@ -673,6 +673,49 @@ components:
         - created_at
         - updated_at
       type: object
+    JobsBatchCancelNotFoundResponse:
+      description: |
+        Returned with 404 from POST /api/jobs/cancel when one or more
+        requested job ids are unknown. The batch is fail-fast, so no job
+        was cancelled.
+      properties:
+        error:
+          description: Human-readable error message
+          type: string
+        unknown_ids:
+          description: The subset of requested job ids that were not found
+          items:
+            type: string
+          type: array
+      required:
+        - error
+        - unknown_ids
+      type: object
+    JobsBatchCancelRequest:
+      additionalProperties: false
+      description: Request body for batch job cancellation
+      properties:
+        job_ids:
+          description: Ids (UUIDs) of the jobs to cancel
+          items:
+            format: uuid  # NOTE(review): the server.py handlers in this patch do not check UUID format — confirm
+            type: string
+          type: array
+      required:
+        - job_ids
+      type: object
+    JobsBatchCancelResponse:
+      description: Response for POST /api/jobs/cancel when all requested jobs were known.
+      properties:
+        cancelled:
+          description: |
+            True when a cancel event was dispatched for at least one job in
+            the batch. False when every requested job was already in a
+            terminal state (the call is still 200 — idempotent).
+          type: boolean
+      required:
+        - cancelled
+      type: object
     JobsListResponse:
       description: Paginated list of jobs for the authenticated user.
       properties:
@@ -2855,6 +2898,56 @@ paths:
       summary: List jobs with pagination and filtering
       tags:
         - workflow
+  /api/jobs/cancel:
+    post:
+      description: |
+        Cancel a batch of jobs by id, regardless of each job's state.
+
+        Fail-fast: if any provided id is unknown (not running, pending, or
+        present in history) the request returns 404 and no job is cancelled,
+        so the call has no partial side effects. When every id is known, all
+        jobs are cancelled and the call returns 200.
+      operationId: cancelJobs
+      requestBody:
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/JobsBatchCancelRequest'
+        required: true
+      responses:
+        "200":
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/JobsBatchCancelResponse'
+          description: Success - All requested jobs were cancelled (or were already terminal)
+        "400":
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/ErrorResponse'
+          description: Bad Request - body is not valid JSON or job_ids is missing/not a list
+        "401":
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/ErrorResponse'
+          description: Unauthorized - Authentication required
+        "404":
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/JobsBatchCancelNotFoundResponse'
+          description: Not Found - one or more job ids are unknown; no job was cancelled
+        "500":
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/ErrorResponse'
+          description: Internal server error - cancellation failed
+      summary: Cancel a batch of jobs
+      tags:
+        - workflow
   /api/jobs/{job_id}:
     get:
       description: |
diff --git a/server.py b/server.py
index 6b0029adf..f5adb25c4 100644
--- a/server.py
+++ b/server.py
@@ -8,7 +8,17 @@ import time
 import nodes
 import folder_paths
 import execution
-from comfy_execution.jobs import JobStatus, get_job, get_all_jobs, validate_job_id
+from comfy_execution.jobs import (
+    JobStatus,
+    get_job,
+    get_all_jobs,
+    validate_job_id,
+    cancel_job,
+    classify_job_for_cancel,
+    CANCEL_PENDING,
+    CANCEL_RUNNING,
+    CANCEL_UNKNOWN,
+)
 import uuid
 import urllib
 import json
@@ -899,6 +909,96 @@ class PromptServer():
 
             return web.json_response(job)
 
+        def _cancel_job_by_id(job_id):
+            """Cancel a single job by id using the queue's existing mechanics.
+
+            Running jobs are interrupted (same mechanism as /interrupt); pending
+            jobs are dequeued (same mechanism as /queue {"delete": [...]}).
+            Already-finished or unknown ids are no-ops. State-agnostic.
+
+            Returns True when a cancel was actually dispatched (running or
+            pending job), False when the call was a no-op (terminal/unknown id).
+            """
+            running, queued = self.prompt_queue.get_current_queue()
+            history = self.prompt_queue.get_history()
+
+            def interrupt():
+                logging.info(f"Cancelling running prompt {job_id}")
+                nodes.interrupt_processing()
+
+            def dequeue(prompt_id):
+                logging.info(f"Cancelling pending prompt {prompt_id}")
+                return self.prompt_queue.delete_queue_item(lambda a: a[1] == prompt_id)
+
+            classification = cancel_job(job_id, running, queued, history, interrupt, dequeue)
+            return classification in (CANCEL_RUNNING, CANCEL_PENDING)
+
+        @routes.post("/api/jobs/{job_id}/cancel")
+        async def cancel_job_by_id(request):
+            """Cancel a single job by id, regardless of state.
+
+            Idempotent: cancelling a job that has already finished, or an id
+            that is not known, returns 200 with {"cancelled": false} rather
+            than an error.
+            """
+            job_id = request.match_info.get("job_id", None)
+            if not job_id:
+                return web.json_response(
+                    {"error": "job_id is required"},
+                    status=400
+                )
+
+            cancelled = _cancel_job_by_id(job_id)
+            return web.json_response({"cancelled": cancelled})
+
+        @routes.post("/api/jobs/cancel")
+        async def cancel_jobs_batch(request):
+            """Cancel a batch of jobs by id.
+
+            Body: {"job_ids": ["<job_id>", ...]}; every id must be a string, else 400.
+
+            Fail-fast: if any provided id is unknown (not running, pending, or
+            in history) the request returns 404 and no job is cancelled, so the
+            call has no partial side effects.
+            """
+            try:
+                json_data = await request.json()
+            except json.JSONDecodeError:
+                return web.json_response(
+                    {"error": "Request body must be valid JSON"},
+                    status=400
+                )
+
+            job_ids = json_data.get("job_ids") if isinstance(json_data, dict) else None
+            # Ids must be strings: a nested list/dict is unhashable and would raise
+            # TypeError in the history membership check (HTTP 500 instead of 400).
+            valid = isinstance(job_ids, list) and all(isinstance(jid, str) for jid in job_ids)
+            if not valid:
+                return web.json_response({"error": "job_ids must be a list of strings"}, status=400)
+
+            # Validate every id exists before cancelling anything. A snapshot of
+            # the queue + history is taken once so the membership check is
+            # consistent for the whole batch.
+            running, queued = self.prompt_queue.get_current_queue()
+            history = self.prompt_queue.get_history()
+
+            unknown_ids = [
+                jid for jid in job_ids
+                if classify_job_for_cancel(jid, running, queued, history) == CANCEL_UNKNOWN
+            ]
+            if unknown_ids:
+                return web.json_response(
+                    {"error": "Job(s) not found", "unknown_ids": unknown_ids},
+                    status=404
+                )
+
+            cancelled = False  # becomes True once any job is interrupted or dequeued
+            for jid in job_ids:  # each call re-reads queue + history rather than reusing the snapshot above
+                if _cancel_job_by_id(jid):
+                    cancelled = True
+
+            return web.json_response({"cancelled": cancelled})
+
         @routes.get("/history")
         async def get_history(request):
             max_items = request.rel_url.query.get("max_items", None)
diff --git a/tests-unit/jobs_cancel_test/__init__.py b/tests-unit/jobs_cancel_test/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests-unit/jobs_cancel_test/jobs_cancel_test.py b/tests-unit/jobs_cancel_test/jobs_cancel_test.py
new file mode 100644
index 000000000..ee88cefdd
--- /dev/null
+++ b/tests-unit/jobs_cancel_test/jobs_cancel_test.py
@@ -0,0 +1,336 @@
+"""Tests for the jobs-namespace cancel endpoints.
+
+Covers both layers:
+
+* the pure cancel helpers in ``comfy_execution.jobs``
+  (``classify_job_for_cancel`` / ``cancel_job``), which hold the business
+  logic of mapping a cancel onto interrupt-vs-dequeue, and
+
+* the HTTP contract of ``POST /api/jobs/{job_id}/cancel`` and
+  ``POST /api/jobs/cancel`` (status codes, single-cancel idempotency, and
+  batch fail-fast on an unknown id with no partial side effects).
+
+The HTTP layer is exercised against a small aiohttp app whose handlers are a
+faithful copy of the wiring in ``server.py`` driven by a fake queue that
+mirrors ``execution.PromptQueue`` (``get_current_queue`` / ``get_history`` /
+``delete_queue_item``). This keeps the test free of the heavy ComfyUI runtime
+(torch, nodes, ...) while still testing the real cancel logic.
+"""
+
+import json
+
+import pytest
+from aiohttp import web
+
+from comfy_execution.jobs import (
+    CANCEL_PENDING,
+    CANCEL_RUNNING,
+    CANCEL_TERMINAL,
+    CANCEL_UNKNOWN,
+    cancel_job,
+    classify_job_for_cancel,
+)
+
+# Classifications for which a cancel was actually dispatched (vs a no-op).
+_CANCELLED = (CANCEL_RUNNING, CANCEL_PENDING)
+
+
+def make_queue_item(prompt_id, number=0):
+    """Build a queue tuple shaped like the real ones: index 1 is the id."""
+    return (number, prompt_id, {}, {}, [])
+
+
+class FakePromptQueue:
+    """Minimal stand-in for execution.PromptQueue for the cancel paths.
+
+    Counts interrupts and mutates its pending list so tests can assert side effects.
+    """
+
+    def __init__(self, running=None, pending=None, history=None):
+        self._running = list(running or [])
+        self._pending = list(pending or [])
+        self._history = dict(history or {})
+        self.interrupt_count = 0
+
+    def get_current_queue(self):
+        return (list(self._running), list(self._pending))  # copies: callers cannot mutate the fake's state
+
+    def get_history(self, prompt_id=None):
+        if prompt_id is None:
+            return dict(self._history)
+        if prompt_id in self._history:
+            return {prompt_id: self._history[prompt_id]}
+        return {}
+
+    def delete_queue_item(self, function):
+        for i, item in enumerate(self._pending):
+            if function(item):
+                self._pending.pop(i)  # removes only the first match, then stops
+                return True
+        return False
+
+    def interrupt_processing(self):  # stands in for nodes.interrupt_processing(), which server.py calls
+        self.interrupt_count += 1
+
+
+def build_app(queue):
+    """Build an aiohttp app exposing the cancel routes against ``queue``.
+
+    Handler bodies follow server.py's wiring, minus logging; interrupts go to ``queue``.
+    """
+
+    def _cancel_job_by_id(job_id):
+        running, pending = queue.get_current_queue()
+        history = queue.get_history()
+
+        def interrupt():
+            queue.interrupt_processing()
+
+        def dequeue(prompt_id):
+            return queue.delete_queue_item(lambda a: a[1] == prompt_id)
+
+        classification = cancel_job(
+            job_id, running, pending, history, interrupt, dequeue
+        )
+        return classification in _CANCELLED  # True only when an interrupt or dequeue was dispatched
+
+    async def cancel_job_by_id(request):
+        job_id = request.match_info.get("job_id", None)
+        if not job_id:
+            return web.json_response({"error": "job_id is required"}, status=400)
+        cancelled = _cancel_job_by_id(job_id)
+        return web.json_response({"cancelled": cancelled})
+
+    async def cancel_jobs_batch(request):
+        try:
+            json_data = await request.json()
+        except json.JSONDecodeError:
+            return web.json_response(
+                {"error": "Request body must be valid JSON"}, status=400
+            )
+
+        job_ids = json_data.get("job_ids") if isinstance(json_data, dict) else None
+        if not isinstance(job_ids, list):
+            return web.json_response({"error": "job_ids must be a list"}, status=400)
+
+        # Classify every id against one snapshot before cancelling anything (fail-fast).
+        running, pending = queue.get_current_queue()
+        history = queue.get_history()
+        unknown_ids = [
+            jid
+            for jid in job_ids
+            if classify_job_for_cancel(jid, running, pending, history) == CANCEL_UNKNOWN
+        ]
+        if unknown_ids:
+            return web.json_response(
+                {"error": "Job(s) not found", "unknown_ids": unknown_ids}, status=404
+            )
+
+        cancelled = False
+        for jid in job_ids:
+            if _cancel_job_by_id(jid):
+                cancelled = True
+        return web.json_response({"cancelled": cancelled})
+
+    app = web.Application()
+    app.router.add_post("/api/jobs/{job_id}/cancel", cancel_job_by_id)
+    app.router.add_post("/api/jobs/cancel", cancel_jobs_batch)
+    return app
+
+
+# ---------------------------------------------------------------------------
+# Pure helper tests: classification + cancel side effects
+# ---------------------------------------------------------------------------
+
+
+class TestClassifyJobForCancel:
+    def test_running(self):
+        running = [make_queue_item("a")]
+        assert classify_job_for_cancel("a", running, [], {}) == CANCEL_RUNNING
+
+    def test_pending(self):
+        pending = [make_queue_item("b")]
+        assert classify_job_for_cancel("b", [], pending, {}) == CANCEL_PENDING
+
+    def test_terminal(self):
+        history = {"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}}
+        assert classify_job_for_cancel("c", [], [], history) == CANCEL_TERMINAL
+
+    def test_unknown(self):
+        assert classify_job_for_cancel("z", [], [], {}) == CANCEL_UNKNOWN
+
+
+class TestCancelJobHelper:
+    def test_running_is_interrupted_not_dequeued(self):
+        running = [make_queue_item("a")]
+        interrupts = []
+        dequeues = []
+        result = cancel_job(
+            "a", running, [], {},
+            interrupt=lambda: interrupts.append(True),
+            dequeue=lambda pid: dequeues.append(pid) or True,  # append() returns None, so "or True" reports success
+        )
+        assert result == CANCEL_RUNNING
+        assert interrupts == [True]
+        assert dequeues == []
+
+    def test_pending_is_dequeued_not_interrupted(self):
+        pending = [make_queue_item("b")]
+        interrupts = []
+        dequeues = []
+        result = cancel_job(
+            "b", [], pending, {},
+            interrupt=lambda: interrupts.append(True),
+            dequeue=lambda pid: dequeues.append(pid) or True,
+        )
+        assert result == CANCEL_PENDING
+        assert dequeues == ["b"]
+        assert interrupts == []
+
+    def test_terminal_is_noop(self):
+        history = {"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}}
+        interrupts = []
+        dequeues = []
+        result = cancel_job(
+            "c", [], [], history,
+            interrupt=lambda: interrupts.append(True),
+            dequeue=lambda pid: dequeues.append(pid) or True,
+        )
+        assert result == CANCEL_TERMINAL
+        assert interrupts == []
+        assert dequeues == []
+
+    def test_unknown_is_noop(self):
+        interrupts = []
+        dequeues = []
+        result = cancel_job(
+            "z", [], [], {},
+            interrupt=lambda: interrupts.append(True),
+            dequeue=lambda pid: dequeues.append(pid) or True,
+        )
+        assert result == CANCEL_UNKNOWN
+        assert interrupts == []
+        assert dequeues == []
+
+
+# ---------------------------------------------------------------------------
+# HTTP contract tests: POST /api/jobs/{job_id}/cancel
+# ---------------------------------------------------------------------------
+
+
+class TestSingleCancelEndpoint:
+    @pytest.mark.asyncio
+    async def test_cancel_running_job_interrupts(self, aiohttp_client):
+        queue = FakePromptQueue(running=[make_queue_item("a")])
+        client = await aiohttp_client(build_app(queue))
+
+        resp = await client.post("/api/jobs/a/cancel")
+
+        assert resp.status == 200
+        assert (await resp.json()) == {"cancelled": True}
+        assert queue.interrupt_count == 1
+
+    @pytest.mark.asyncio
+    async def test_cancel_pending_job_dequeues(self, aiohttp_client):
+        queue = FakePromptQueue(pending=[make_queue_item("b")])
+        client = await aiohttp_client(build_app(queue))
+
+        resp = await client.post("/api/jobs/b/cancel")
+
+        assert resp.status == 200
+        assert (await resp.json()) == {"cancelled": True}
+        # Pending job removed from the queue; nothing interrupted.
+        assert queue.get_current_queue()[1] == []
+        assert queue.interrupt_count == 0
+
+    @pytest.mark.asyncio
+    async def test_cancel_terminal_job_is_idempotent_noop(self, aiohttp_client):
+        history = {"c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}}}
+        queue = FakePromptQueue(history=history)
+        client = await aiohttp_client(build_app(queue))
+
+        resp = await client.post("/api/jobs/c/cancel")
+
+        # Already-finished job: 200 no-op (cancelled=false), not an error.
+        assert resp.status == 200
+        assert (await resp.json()) == {"cancelled": False}
+        assert queue.interrupt_count == 0
+
+    @pytest.mark.asyncio
+    async def test_cancel_unknown_id_is_200_noop(self, aiohttp_client):
+        queue = FakePromptQueue()
+        client = await aiohttp_client(build_app(queue))
+
+        resp = await client.post("/api/jobs/does-not-exist/cancel")
+
+        # Single-cancel of an unknown id is treated as an idempotent no-op.
+        assert resp.status == 200
+        assert (await resp.json()) == {"cancelled": False}
+        assert queue.interrupt_count == 0
+
+
+# ---------------------------------------------------------------------------
+# HTTP contract tests: POST /api/jobs/cancel (batch)
+# ---------------------------------------------------------------------------
+
+
+class TestBatchCancelEndpoint:
+    @pytest.mark.asyncio
+    async def test_batch_happy_path(self, aiohttp_client):
+        queue = FakePromptQueue(
+            running=[make_queue_item("a")],
+            pending=[make_queue_item("b", number=1)],
+        )
+        client = await aiohttp_client(build_app(queue))
+
+        resp = await client.post("/api/jobs/cancel", json={"job_ids": ["a", "b"]})
+
+        assert resp.status == 200
+        assert (await resp.json()) == {"cancelled": True}
+        assert queue.interrupt_count == 1  # running job interrupted
+        assert queue.get_current_queue()[1] == []  # pending job dequeued
+
+    @pytest.mark.asyncio
+    async def test_batch_fail_fast_404_on_unknown_id_no_side_effects(
+        self, aiohttp_client
+    ):
+        queue = FakePromptQueue(
+            running=[make_queue_item("a")],
+            pending=[make_queue_item("b", number=1)],
+        )
+        client = await aiohttp_client(build_app(queue))
+
+        resp = await client.post(
+            "/api/jobs/cancel", json={"job_ids": ["a", "missing", "b"]}
+        )
+
+        assert resp.status == 404
+        body = await resp.json()
+        assert body["unknown_ids"] == ["missing"]  # the known ids "a" and "b" are not reported
+        # Fail-fast: nothing was cancelled — no partial side effects.
+        assert queue.interrupt_count == 0
+        assert len(queue.get_current_queue()[1]) == 1
+
+    @pytest.mark.asyncio
+    async def test_batch_all_terminal_is_idempotent_noop(self, aiohttp_client):
+        history = {
+            "c": {"prompt": make_queue_item("c"), "outputs": {}, "status": {}},
+            "d": {"prompt": make_queue_item("d"), "outputs": {}, "status": {}},
+        }
+        queue = FakePromptQueue(history=history)
+        client = await aiohttp_client(build_app(queue))
+
+        resp = await client.post("/api/jobs/cancel", json={"job_ids": ["c", "d"]})
+
+        # All known but terminal: 200 with cancelled=false, nothing dispatched.
+        assert resp.status == 200
+        assert (await resp.json()) == {"cancelled": False}
+        assert queue.interrupt_count == 0
+
+    @pytest.mark.asyncio
+    async def test_batch_missing_job_ids_is_400(self, aiohttp_client):
+        queue = FakePromptQueue()
+        client = await aiohttp_client(build_app(queue))
+
+        resp = await client.post("/api/jobs/cancel", json={})  # valid JSON object, but no "job_ids" key
+
+        assert resp.status == 400