diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index fcd7ef735..3fbcc3eb0 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -3,6 +3,7 @@ Job utilities for the /api/jobs endpoint. Provides normalization and helper functions for job status tracking. """ +import uuid from typing import Optional from comfy_api.internal import prune_dict @@ -19,6 +20,26 @@ class JobStatus: ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED] +def validate_job_id(value) -> str: + """Validate a client-supplied job (prompt) id. + + Job ids must be UUIDs in the canonical lowercase hyphenated form. The id + is stored and compared verbatim everywhere downstream — history keys, + websocket events, /interrupt matching, and the assets ``job_ids`` filter + (a String(36) column matched exactly) — so accepting another spelling + would either rewrite the client's id behind its back or mint a job whose + outputs the filter can never find. Rejecting loudly beats both. + + Returns the id unchanged. Raises ValueError when the value is not a + string in canonical UUID form. + """ + if not isinstance(value, str): + raise ValueError(f"job id must be a string, got {type(value).__name__}") + if str(uuid.UUID(value)) != value: + raise ValueError("job id must be a UUID in canonical lowercase hyphenated form") + return value + + # Media types that can be previewed in the frontend PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'}) diff --git a/openapi.yaml b/openapi.yaml index c27ed7adf..58614103a 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -896,6 +896,11 @@ components: additionalProperties: true description: The workflow graph to execute type: object + prompt_id: + description: Optional client-supplied job id. Must be a UUID in canonical lowercase hyphenated form; it is echoed back in the response. Omitted or null means the server generates one. + format: uuid + nullable: true + type: string workflow_id: description: UUID identifying the cloud workflow entity to associate with this job type: string diff --git a/server.py b/server.py index a85c1e591..cc3b33a5c 100644 --- a/server.py +++ b/server.py @@ -8,7 +8,7 @@ import time import nodes import folder_paths import execution -from comfy_execution.jobs import JobStatus, get_job, get_all_jobs +from comfy_execution.jobs import JobStatus, get_job, get_all_jobs, validate_job_id import uuid import urllib import json @@ -942,7 +942,21 @@ class PromptServer(): if "prompt" in json_data: prompt = json_data["prompt"] - prompt_id = str(json_data.get("prompt_id", uuid.uuid4())) + client_prompt_id = json_data.get("prompt_id") + if client_prompt_id is None: + # Absent or explicit null: the server mints the id. + prompt_id = str(uuid.uuid4()) + else: + try: + prompt_id = validate_job_id(client_prompt_id) + except ValueError: + error = { + "type": "invalid_prompt_id", + "message": "prompt_id must be a valid UUID", + "details": "prompt_id must be a UUID string in canonical lowercase hyphenated form; omit it to let the server generate one", + "extra_info": {} + } + return web.json_response({"error": error, "node_errors": {}}, status=400) partial_execution_targets = None if "partial_execution_targets" in json_data: diff --git a/tests-unit/assets_test/test_prompt_id_enforcement.py b/tests-unit/assets_test/test_prompt_id_enforcement.py new file mode 100644 index 000000000..fb961beae --- /dev/null +++ b/tests-unit/assets_test/test_prompt_id_enforcement.py @@ -0,0 +1,69 @@ +"""POST /prompt enforces canonical-UUID job ids at creation time. + +Lives in assets_test because it uses this suite's booted-server fixture and +because the invariant exists for the assets pipeline: the GET /api/assets +``job_ids`` filter matches stored job ids exactly, so a job minted with a +non-canonical id would produce assets the filter can never find. + +The prompt bodies here are intentionally invalid workflows — prompt_id +validation happens before workflow validation, so a rejected id returns +``invalid_prompt_id`` while an accepted id falls through to the ordinary +workflow-validation error (proving it cleared the id check). +""" +import requests + + +def _post_prompt(http: requests.Session, api_base: str, body: dict) -> requests.Response: + return http.post(api_base + "/prompt", json=body, timeout=30) + + +def _error_type(r: requests.Response) -> str: + return r.json()["error"]["type"] + + +def test_non_uuid_prompt_id_rejected(http: requests.Session, api_base: str): + r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": "not-a-uuid"}) + assert r.status_code == 400, r.text + assert _error_type(r) == "invalid_prompt_id" + + +def test_non_string_prompt_id_rejected(http: requests.Session, api_base: str): + # Previously str()-coerced (123 became the job id "123"); must now be a 400, + # not a 500 from uuid.UUID choking on a non-string. + r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": 123}) + assert r.status_code == 400, r.text + assert _error_type(r) == "invalid_prompt_id" + + +def test_non_canonical_uuid_rejected(http: requests.Session, api_base: str): + # Parseable as a UUID, but not the canonical lowercase form: rejected + # loudly rather than silently rewritten (downstream lookups match the + # stored id exactly). + r = _post_prompt( + http, + api_base, + {"prompt": {}, "prompt_id": "AAAAAAAA-BBBB-4CCC-8DDD-EEEEEEEEEEEE"}, + ) + assert r.status_code == 400, r.text + assert _error_type(r) == "invalid_prompt_id" + + +def test_canonical_uuid_accepted(http: requests.Session, api_base: str): + # The id clears validation; the empty workflow then fails ordinary prompt + # validation, proving the request got past the id check. + r = _post_prompt( + http, + api_base, + {"prompt": {}, "prompt_id": "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee"}, + ) + assert r.status_code == 400, r.text + assert _error_type(r) != "invalid_prompt_id" + + +def test_null_prompt_id_not_rejected(http: requests.Session, api_base: str): + # Explicit null means "server generates" and must not be rejected as an + # invalid id. (The minted id itself is not observable here because the + # workflow is invalid; unit tests cover validate_job_id directly.) + r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": None}) + assert r.status_code == 400, r.text + assert _error_type(r) != "invalid_prompt_id" diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py index 814af5c13..30e47071d 100644 --- a/tests/execution/test_jobs.py +++ b/tests/execution/test_jobs.py @@ -1,5 +1,7 @@ """Unit tests for comfy_execution/jobs.py""" +import pytest + from comfy_execution.jobs import ( JobStatus, is_previewable, @@ -10,9 +12,50 @@ from comfy_execution.jobs import ( get_outputs_summary, apply_sorting, has_3d_extension, + validate_job_id, ) +class TestValidateJobId: + """validate_job_id guards job creation: POST /prompt rejects ids it raises on.""" + + def test_canonical_form_passes_through(self): + cid = "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7" + assert validate_job_id(cid) == cid + + @pytest.mark.parametrize( + "variant", + [ + "A1B2C3D4-E5F6-7A89-B0C1-D2E3F4A5B6C7", # uppercase + "{a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7}", # braced + "urn:uuid:a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7", # URN + "a1b2c3d4e5f67a89b0c1d2e3f4a5b6c7", # bare hex + " a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7 ", # padded + ], + ) + def test_non_canonical_spellings_rejected(self, variant): + # uuid.UUID parses all of these, but accepting them would silently + # rewrite the client's id (history keys, websocket events, and the + # assets job_ids filter all match the stored form exactly). + with pytest.raises(ValueError): + validate_job_id(variant) + + @pytest.mark.parametrize( + "bad", + ["", "not-a-uuid", "prompt-123", "a1b2c3d4-e5f6-7a89-b0c1", "None"], + ) + def test_non_uuid_strings_rejected(self, bad): + with pytest.raises(ValueError): + validate_job_id(bad) + + @pytest.mark.parametrize("bad", [123, 1.5, True, None, ["a"], {"id": "x"}]) + def test_non_strings_rejected(self, bad): + # uuid.UUID raises AttributeError/TypeError on non-strings; the helper + # must normalize those to ValueError so callers need one except clause. + with pytest.raises(ValueError): + validate_job_id(bad) + + class TestJobStatus: """Test JobStatus constants."""