diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index fcd7ef735..24dd1ffd0 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -93,6 +93,27 @@ def _create_text_preview(value: str) -> dict: } +def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]: + """Extract the workflow id from a prompt's ``extra_data``. + + The frontend stores the id at ``extra_data["extra_pnginfo"]["workflow"]["id"]`` + when a prompt is queued. Any value that is not a non-empty string is treated as + missing so callers can rely on the return being either ``None`` or a string. + """ + if not isinstance(extra_data, dict): + return None + extra_pnginfo = extra_data.get('extra_pnginfo') + if not isinstance(extra_pnginfo, dict): + return None + workflow = extra_pnginfo.get('workflow') + if not isinstance(workflow, dict): + return None + workflow_id = workflow.get('id') + if isinstance(workflow_id, str) and workflow_id: + return workflow_id + return None + + def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]: """Extract create_time and workflow_id from extra_data. @@ -100,8 +121,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str tuple: (create_time, workflow_id) """ create_time = extra_data.get('create_time') - extra_pnginfo = extra_data.get('extra_pnginfo', {}) - workflow_id = extra_pnginfo.get('workflow', {}).get('id') + workflow_id = extract_workflow_id(extra_data) return create_time, workflow_id diff --git a/comfy_execution/metadata.py b/comfy_execution/metadata.py new file mode 100644 index 000000000..583e2bdb8 --- /dev/null +++ b/comfy_execution/metadata.py @@ -0,0 +1,86 @@ +"""Per-prompt metadata propagated alongside execution WebSocket events. + +The execution layer (``execution.py``) is intentionally kept agnostic of +workflow-level concepts. ``PromptServer`` registers metadata at submission time +and merges it onto outgoing WebSocket payloads in ``send_sync``. Today only +``workflow_id`` is propagated; the structure is a ``TypedDict`` so additional +keys can be added without churn at the call sites. + +Identity model — registry is keyed by an internal monotonic token, NOT by +``prompt_id``. ``post_prompt`` accepts a client-supplied ``prompt_id`` +verbatim, so two prompts can be queued with the same id and a registry keyed +only by ``prompt_id`` would misattribute events when queue order differs from +registration order. ``main.py``'s queue worker pins the active token on the +server for the duration of one prompt's execution and ``send_sync`` reads that +token — so each prompt's events get its own registered metadata regardless of +``prompt_id`` collisions. +""" + +import threading +from typing import Optional, TypedDict + +from comfy_execution.jobs import extract_workflow_id + + +PROMPT_METADATA_TOKEN_KEY = "__prompt_metadata_token" + + +class PromptMetadata(TypedDict, total=False): + workflow_id: Optional[str] + + +def build_prompt_metadata(extra_data: Optional[dict]) -> PromptMetadata: + """Build a ``PromptMetadata`` snapshot from a prompt's ``extra_data``. + + Returns an empty dict when no recognized metadata is present so callers can + skip registering anything in that case. + """ + meta: PromptMetadata = {} + workflow_id = extract_workflow_id(extra_data) + if workflow_id is not None: + meta["workflow_id"] = workflow_id + return meta + + +def resolve_progress_text_sid(sid, default_sid): + """Pick the recipient for a ``send_progress_text`` binary frame. + + Returns ``default_sid`` (typically ``PromptServer.client_id`` — the client + that submitted the active prompt) when the caller didn't pin a specific + socket. This narrows the audience for text status updates from "every + connected client" to "the client running this prompt", matching the + cross-client isolation other execution events already have. + + Splitting this out keeps the unit test independent of the full ``server`` + import chain. + """ + return default_sid if sid is None else sid + + +def merge_prompt_metadata( + registry: dict, + lock: threading.Lock, + active_token: Optional[int], + data, +): + """Return ``data`` with the metadata for the currently-active token merged + top-level. The event payload wins on conflict, and non-dict payloads (e.g. + the binary preview tuple) pass through untouched. + + The active token is set by ``main.py``'s queue worker around each + ``e.execute(...)``. The merge happens only when a token is currently active + *and* the payload carries a ``prompt_id`` — using ``prompt_id`` purely as a + marker that this is an execution event meant to receive metadata, not as + the lookup key. This makes the merge immune to ``prompt_id`` collisions. + """ + if not isinstance(data, dict): + return data + if not data.get("prompt_id"): + return data + if active_token is None: + return data + with lock: + meta = registry.get(active_token) + if not meta: + return data + return {**meta, **data} diff --git a/comfy_execution/progress.py b/comfy_execution/progress.py index f951a3350..a4d8c56b6 100644 --- a/comfy_execution/progress.py +++ b/comfy_execution/progress.py @@ -159,11 +159,19 @@ class WebUIProgressHandler(ProgressHandler): def set_registry(self, registry: "ProgressRegistry"): self.registry = registry + def _lookup_workflow_id(self, prompt_id: str) -> Optional[str]: + get_meta = getattr(self.server_instance, "get_active_prompt_metadata", None) + if get_meta is None: + return None + return get_meta().get("workflow_id") + def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]): """Send the current progress state to the client""" if self.server_instance is None: return + workflow_id = self._lookup_workflow_id(prompt_id) + # Only send info for non-pending nodes active_nodes = { node_id: { @@ -172,6 +180,7 @@ class WebUIProgressHandler(ProgressHandler): "state": state["state"].value, "node_id": node_id, "prompt_id": prompt_id, + "workflow_id": workflow_id, "display_node_id": self.registry.dynprompt.get_display_node_id(node_id), "parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id), "real_node_id": self.registry.dynprompt.get_real_node_id(node_id), @@ -181,7 +190,10 @@ class WebUIProgressHandler(ProgressHandler): } # Send a combined progress_state message with all node states - # Include client_id to ensure message is only sent to the initiating client + # Include client_id to ensure message is only sent to the initiating client. + # The outer ``workflow_id`` is merged in by ``PromptServer.send_sync`` via + # the per-prompt metadata registry; the nested copy on each node entry + # mirrors the wire shape consumed by the frontend. self.server_instance.send_sync( "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id ) @@ -215,6 +227,7 @@ class WebUIProgressHandler(ProgressHandler): metadata = { "node_id": node_id, "prompt_id": prompt_id, + "workflow_id": self._lookup_workflow_id(prompt_id), "display_node_id": self.registry.dynprompt.get_display_node_id( node_id ), diff --git a/execution.py b/execution.py index 4c7de2e84..00a81e335 100644 --- a/execution.py +++ b/execution.py @@ -36,6 +36,7 @@ from comfy_execution.graph import ( get_input_info, ) from comfy_execution.graph_utils import GraphBuilder, is_link +from comfy_execution.metadata import PROMPT_METADATA_TOKEN_KEY from comfy_execution.validation import validate_node_input from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler from comfy_execution.utils import CurrentNodeContext @@ -1294,20 +1295,32 @@ class PromptQueue: with self.mutex: return len(self.queue) + len(self.currently_running) + def _extract_metadata_token(self, item): + # Queue item shape: (number, prompt_id, prompt, extra_data, outputs, sensitive) + extra_data = item[3] if len(item) > 3 and isinstance(item[3], dict) else None + if not extra_data: + return None + return extra_data.get(PROMPT_METADATA_TOKEN_KEY) + def wipe_queue(self): with self.mutex: + cancelled_tokens = [self._extract_metadata_token(item) for item in self.queue] self.queue = [] self.server.queue_updated() + for token in cancelled_tokens: + self.server.unregister_prompt_metadata(token) def delete_queue_item(self, function): with self.mutex: for x in range(len(self.queue)): if function(self.queue[x]): + cancelled_token = self._extract_metadata_token(self.queue[x]) if len(self.queue) == 1: self.wipe_queue() else: self.queue.pop(x) heapq.heapify(self.queue) + self.server.unregister_prompt_metadata(cancelled_token) self.server.queue_updated() return True return False diff --git a/main.py b/main.py index a6fdaf43c..5ccd4daae 100644 --- a/main.py +++ b/main.py @@ -27,6 +27,7 @@ from utils.mime_types import init_mime_types import faulthandler import logging import sys +from comfy_execution.metadata import PROMPT_METADATA_TOKEN_KEY from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags @@ -317,20 +318,34 @@ def prompt_worker(q, server_instance): for k in sensitive: extra_data[k] = sensitive[k] + # Pin the metadata token registered for this exact prompt while it + # runs so ``send_sync`` can decorate its frames with the right + # ``workflow_id`` even if another submission shares the same + # ``prompt_id``. + metadata_token = extra_data.pop(PROMPT_METADATA_TOKEN_KEY, None) + server_instance.active_prompt_metadata_token = metadata_token + asset_seeder.pause() - e.execute(item[2], prompt_id, extra_data, item[4]) + try: + e.execute(item[2], prompt_id, extra_data, item[4]) - need_gc = True + need_gc = True - remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] - q.task_done(item_id, - e.history_result, - status=execution.PromptQueue.ExecutionStatus( - status_str='success' if e.success else 'error', - completed=e.success, - messages=e.status_messages), process_item=remove_sensitive) - if server_instance.client_id is not None: - server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) + remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] + q.task_done(item_id, + e.history_result, + status=execution.PromptQueue.ExecutionStatus( + status_str='success' if e.success else 'error', + completed=e.success, + messages=e.status_messages), process_item=remove_sensitive) + if server_instance.client_id is not None: + server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) + finally: + # Drop the per-prompt metadata only AFTER the terminal "executing" + # send so the registered workflow_id is merged onto that frame. + # This is what eliminates the #13684 finally-clear race. + server_instance.active_prompt_metadata_token = None + server_instance.unregister_prompt_metadata(metadata_token) current_time = time.perf_counter() execution_time = current_time - execution_start_time diff --git a/server.py b/server.py index 44470b904..8e55990bb 100644 --- a/server.py +++ b/server.py @@ -8,7 +8,15 @@ import time import nodes import folder_paths import execution +import threading from comfy_execution.jobs import JobStatus, get_job, get_all_jobs +from comfy_execution.metadata import ( + PROMPT_METADATA_TOKEN_KEY, + PromptMetadata, + build_prompt_metadata, + merge_prompt_metadata, + resolve_progress_text_sid, +) import uuid import urllib import json @@ -252,6 +260,15 @@ class PromptServer(): self.last_node_id = None self.client_id = None + # Keyed by an internal monotonic token rather than by ``prompt_id`` + # because clients can supply the same ``prompt_id`` on retry/dedupe. + self.prompt_metadata: dict[int, PromptMetadata] = {} + self._prompt_metadata_lock = threading.Lock() + self._prompt_metadata_counter = 0 + # Set by ``main.py``'s queue worker for the duration of one prompt's + # execution; read by ``send_sync`` to merge the right metadata. + self.active_prompt_metadata_token: Optional[int] = None + self.on_prompt_handlers = [] @routes.get('/ws') @@ -275,7 +292,12 @@ class PromptServer(): await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid) # On reconnect if we are the currently executing client send the current node if self.client_id == sid and self.last_node_id is not None: - await self.send("executing", { "node": self.last_node_id }, sid) + last_prompt_id = getattr(self, "last_prompt_id", None) + payload: dict = {"node": self.last_node_id} + if last_prompt_id: + payload["prompt_id"] = last_prompt_id + payload.update(self.get_active_prompt_metadata()) + await self.send("executing", payload, sid) # Flag to track if we've received the first message first_message = True @@ -955,6 +977,9 @@ class PromptServer(): if sensitive_val in extra_data: sensitive[sensitive_val] = extra_data.pop(sensitive_val) extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds + token = self.register_prompt_metadata(extra_data) + if token is not None: + extra_data[PROMPT_METADATA_TOKEN_KEY] = token self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) @@ -1216,7 +1241,50 @@ class PromptServer(): elif sid in self.sockets: await send_socket_catch_exception(self.sockets[sid].send_json, message) + def register_prompt_metadata(self, extra_data: dict) -> Optional[int]: + """Capture per-prompt metadata at submission time and return a token + identifying this registration. + + Returns ``None`` when there is no recognized metadata, signalling that + no token needs to be threaded through the queue. Otherwise the token + must be stored on the queue item (typically via + :data:`PROMPT_METADATA_TOKEN_KEY` in ``extra_data``) and pinned on the + server as ``active_prompt_metadata_token`` while the prompt runs, so + the merge in ``send_sync`` picks up this prompt's metadata even when + another prompt is registered under the same ``prompt_id``. + """ + meta = build_prompt_metadata(extra_data) + if not meta: + return None + with self._prompt_metadata_lock: + self._prompt_metadata_counter += 1 + token = self._prompt_metadata_counter + self.prompt_metadata[token] = meta + return token + + def unregister_prompt_metadata(self, token: Optional[int]) -> None: + if token is None: + return + with self._prompt_metadata_lock: + self.prompt_metadata.pop(token, None) + + def get_prompt_metadata_by_token(self, token: Optional[int]) -> PromptMetadata: + if token is None: + return {} + with self._prompt_metadata_lock: + return dict(self.prompt_metadata.get(token, {})) + + def get_active_prompt_metadata(self) -> PromptMetadata: + """Snapshot of the metadata for the currently-executing prompt.""" + return self.get_prompt_metadata_by_token(self.active_prompt_metadata_token) + def send_sync(self, event, data, sid=None): + data = merge_prompt_metadata( + self.prompt_metadata, + self._prompt_metadata_lock, + self.active_prompt_metadata_token, + data, + ) self.loop.call_soon_threadsafe( self.messages.put_nowait, (event, data, sid)) @@ -1285,7 +1353,10 @@ class PromptServer(): return json_data def send_progress_text( - self, text: Union[bytes, bytearray, str], node_id: str, sid=None + self, + text: Union[bytes, bytearray, str], + node_id: str, + sid=None, ): if isinstance(text, str): text = text.encode("utf-8") @@ -1294,4 +1365,11 @@ class PromptServer(): # Pack the node_id length as a 4-byte unsigned integer, followed by the node_id bytes message = struct.pack(">I", len(node_id_bytes)) + node_id_bytes + text + # Default routing to the active prompt's client so other clients don't + # silently receive untagged text frames. The binary wire format does + # not yet carry prompt_id/workflow_id, so cross-tab filtering inside a + # single client still depends on a follow-up wire-format change with a + # feature flag. + sid = resolve_progress_text_sid(sid, self.client_id) + self.send_sync(BinaryEventTypes.TEXT, message, sid) diff --git a/tests-unit/server_test/test_prompt_metadata.py b/tests-unit/server_test/test_prompt_metadata.py new file mode 100644 index 000000000..1f9b9dc9b --- /dev/null +++ b/tests-unit/server_test/test_prompt_metadata.py @@ -0,0 +1,182 @@ +"""Tests for the per-prompt metadata registry used to propagate ``workflow_id`` +through WebSocket events without coupling ``execution.py`` to workflow-level +concepts. + +The registry is keyed by an internal monotonic token (NOT by ``prompt_id``) +because ``post_prompt`` accepts a client-supplied ``prompt_id`` verbatim and +two prompts can share an id. ``main.py``'s queue worker pins the active token +on the server around each ``e.execute(...)`` and the merge in ``send_sync`` +reads that pinned token, so each prompt's events get its own metadata +regardless of ``prompt_id`` collisions or queue-vs-stack ordering. +""" + +import threading + +import pytest + +from comfy_execution.metadata import ( + build_prompt_metadata, + merge_prompt_metadata, + resolve_progress_text_sid, +) + + +@pytest.fixture +def registry(): + return {} + + +@pytest.fixture +def lock(): + return threading.Lock() + + +class TestBuildPromptMetadata: + def test_returns_workflow_id_when_present(self): + extra_data = {"extra_pnginfo": {"workflow": {"id": "wf-1"}}} + assert build_prompt_metadata(extra_data) == {"workflow_id": "wf-1"} + + def test_empty_when_workflow_id_missing(self): + assert build_prompt_metadata({}) == {} + assert build_prompt_metadata({"extra_pnginfo": {}}) == {} + assert build_prompt_metadata({"extra_pnginfo": {"workflow": {}}}) == {} + + def test_empty_when_workflow_id_not_a_non_empty_string(self): + assert build_prompt_metadata({"extra_pnginfo": {"workflow": {"id": ""}}}) == {} + assert build_prompt_metadata({"extra_pnginfo": {"workflow": {"id": 42}}}) == {} + assert build_prompt_metadata({"extra_pnginfo": {"workflow": {"id": None}}}) == {} + + def test_empty_on_non_dict_input(self): + assert build_prompt_metadata(None) == {} + assert build_prompt_metadata("not a dict") == {} + + +class TestMergeMetadata: + """``merge_prompt_metadata`` decorates execution events with the metadata + for the currently-active token. Event payload fields win on conflict, + binary payloads pass through, and the merge is gated on a ``prompt_id`` + marker to avoid decorating server-status events like ``status`` / + ``queue_updated``.""" + + def test_merges_for_active_token_when_payload_has_prompt_id(self, registry, lock): + registry[42] = {"workflow_id": "wf-1"} + merged = merge_prompt_metadata(registry, lock, 42, {"node": "n1", "prompt_id": "p1"}) + assert merged == {"node": "n1", "prompt_id": "p1", "workflow_id": "wf-1"} + + def test_passthrough_when_no_active_token(self, registry, lock): + registry[42] = {"workflow_id": "wf-1"} + merged = merge_prompt_metadata(registry, lock, None, {"node": "n1", "prompt_id": "p1"}) + assert merged == {"node": "n1", "prompt_id": "p1"} + + def test_passthrough_when_active_token_unknown(self, registry, lock): + # Token was unregistered already (or never registered) — merge is a no-op. + merged = merge_prompt_metadata(registry, lock, 99, {"prompt_id": "p1"}) + assert merged == {"prompt_id": "p1"} + + def test_passthrough_when_no_prompt_id(self, registry, lock): + # Server-status frames (status, queue_updated, etc.) carry no prompt_id + # and must not be decorated. + registry[42] = {"workflow_id": "wf-1"} + merged = merge_prompt_metadata(registry, lock, 42, {"status": {"queue_remaining": 0}}) + assert merged == {"status": {"queue_remaining": 0}} + + def test_passthrough_for_non_dict_payload(self, registry, lock): + registry[42] = {"workflow_id": "wf-1"} + binary = (b"image-bytes", {"prompt_id": "p1"}) + assert merge_prompt_metadata(registry, lock, 42, binary) is binary + + def test_event_payload_wins_over_registered_metadata(self, registry, lock): + registry[42] = {"workflow_id": "wf-registered"} + merged = merge_prompt_metadata( + registry, lock, 42, {"prompt_id": "p1", "workflow_id": "wf-caller"} + ) + assert merged["workflow_id"] == "wf-caller" + + +class TestProgressTextSidResolution: + """``BinaryEventTypes.TEXT`` frames don't yet carry ``prompt_id`` / + ``workflow_id`` in their wire shape, so cross-client routing has to happen + at the ``sid`` level. The default sid pins the broadcast to the active + prompt's client. + """ + + def test_explicit_sid_passes_through(self): + assert resolve_progress_text_sid("client-explicit", "client-active") == "client-explicit" + + def test_none_sid_defaults_to_active_client(self): + assert resolve_progress_text_sid(None, "client-active") == "client-active" + + def test_none_sid_with_no_active_client_stays_none(self): + assert resolve_progress_text_sid(None, None) is None + + +class TestPromptIdCollisionWithTokens: + """Two prompts can be queued with the same client-supplied ``prompt_id``. + With a registry keyed by ``prompt_id`` the second registration would + overwrite the first or be erased by the first's unregister. The token model + makes each registration independent.""" + + def test_two_submissions_get_distinct_tokens_and_each_merges_correctly(self, registry, lock): + # Two submissions of the same prompt_id with different workflow_ids. + registry[1] = {"workflow_id": "wf-A"} # token from submission #1 + registry[2] = {"workflow_id": "wf-B"} # token from submission #2 + + # Worker is currently running submission #1. + merged = merge_prompt_metadata(registry, lock, 1, {"prompt_id": "P", "node": "x"}) + assert merged["workflow_id"] == "wf-A" + + # Worker switches to submission #2 (queue ordering, retry, whatever). + merged = merge_prompt_metadata(registry, lock, 2, {"prompt_id": "P", "node": "y"}) + assert merged["workflow_id"] == "wf-B" + + def test_unregister_by_token_does_not_drop_concurrent_submission(self, registry, lock): + registry[1] = {"workflow_id": "wf-A"} + registry[2] = {"workflow_id": "wf-B"} + + # Submission #1 finishes — drop its token only. + registry.pop(1, None) + + # Submission #2 still has its metadata. + merged = merge_prompt_metadata(registry, lock, 2, {"prompt_id": "P"}) + assert merged["workflow_id"] == "wf-B" + + def test_execution_order_independent_of_registration_order(self, registry, lock): + """Regression for the LIFO-stack failure mode: queue executes #1 first + but the previous stack design would have made the merge pick the + latest-registered (#2) metadata. Token model is immune.""" + registry[1] = {"workflow_id": "wf-A"} + registry[2] = {"workflow_id": "wf-B"} + + # Even though #2 was registered after #1, executing #1 still sees wf-A. + merged_first_running = merge_prompt_metadata(registry, lock, 1, {"prompt_id": "P"}) + assert merged_first_running["workflow_id"] == "wf-A" + + +class TestRaceRegressionForTerminalExecutingFrame: + """Regression for the PR #13684 finally-clear race. + + Executor's ``finally`` previously cleared the workflow_id source, so the + post-completion terminal frame shipped ``workflow_id=None``. With the + token model, the active token stays pinned until ``main.py`` clears it + *after* the terminal send. + """ + + def test_terminal_executing_frame_includes_workflow_id(self, registry, lock): + registry[7] = {"workflow_id": "wf-1"} + active_token = 7 + + # main.py emits the terminal frame BEFORE clearing the active token. + terminal = merge_prompt_metadata( + registry, lock, active_token, {"node": None, "prompt_id": "p1"} + ) + # main.py's finally: clear active token + unregister. + active_token = None + registry.pop(7, None) + + assert terminal == {"node": None, "prompt_id": "p1", "workflow_id": "wf-1"} + + # After cleanup, straggler events get no metadata. + straggler = merge_prompt_metadata( + registry, lock, active_token, {"node": None, "prompt_id": "p1"} + ) + assert "workflow_id" not in straggler diff --git a/tests-unit/server_test/test_prompt_metadata_e2e.py b/tests-unit/server_test/test_prompt_metadata_e2e.py new file mode 100644 index 000000000..c76688ac2 --- /dev/null +++ b/tests-unit/server_test/test_prompt_metadata_e2e.py @@ -0,0 +1,163 @@ +"""End-to-end checks of the prompt-metadata propagation. + +These tests instantiate the real ``PromptServer`` (no heavy mocking) and drive +the same call sequence ``main.py`` would: register metadata, pin the token, +emit a bunch of execution-shaped frames through ``send_sync``, then clear the +token + unregister. We then drain the server's message queue and inspect the +exact frames that would have been written to the WebSocket — proving the wire +shape and the ``prompt_id`` collision invariants without needing a browser. +""" + +import asyncio +import os +import sys + +import pytest + +# Ensure the repo root is the first thing on sys.path so ``utils`` resolves to +# the ComfyUI package, not a site-packages namespace package. ``utils`` must +# load before ``app.frontend_management`` imports it. +REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) +import utils # noqa: F401 -- pin the package before later imports + +from comfy_execution.metadata import PROMPT_METADATA_TOKEN_KEY + +server_module = pytest.importorskip("server") + + +@pytest.fixture +def prompt_server(): + loop = asyncio.new_event_loop() + try: + server = server_module.PromptServer(loop) + yield server + finally: + loop.close() + + +def _drain(server) -> list: + """Empty the server's internal asyncio.Queue and return every (event, data, sid).""" + frames = [] + queue = server.messages + while not queue.empty(): + frames.append(queue.get_nowait()) + return frames + + +def _run_pending_callbacks(server): + """``send_sync`` schedules ``messages.put_nowait`` via ``loop.call_soon_threadsafe``. + Pump the loop a single iteration so those callbacks land before we drain.""" + server.loop.call_soon(server.loop.stop) + server.loop.run_forever() + + +def _simulate_prompt_lifecycle(server, extra_data: dict, prompt_id: str): + """Replicates what ``main.py``'s queue worker does around ``e.execute(...)`` + plus a representative sample of the events ``execution.py`` emits.""" + token = server.register_prompt_metadata(extra_data) + if token is not None: + extra_data[PROMPT_METADATA_TOKEN_KEY] = token + + # Worker picks up the item, pins the token. + server.active_prompt_metadata_token = extra_data.pop(PROMPT_METADATA_TOKEN_KEY, None) + client_id = "client-A" + server.client_id = client_id + + # A representative sample of the eight execution events listed in PR #13684. + server.send_sync("execution_start", {"prompt_id": prompt_id}, client_id) + server.send_sync("execution_cached", {"nodes": [], "prompt_id": prompt_id}, client_id) + server.send_sync("executing", {"node": "n1", "display_node": "n1", "prompt_id": prompt_id}, client_id) + server.send_sync( + "executed", + {"node": "n1", "display_node": "n1", "output": {}, "prompt_id": prompt_id}, + client_id, + ) + server.send_sync("execution_success", {"prompt_id": prompt_id}, client_id) + # The terminal frame ``main.py`` itself emits, just before clearing state. + server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, client_id) + + # Worker finally clause: clear token then unregister. + server.active_prompt_metadata_token = None + server.unregister_prompt_metadata(token) + + +class TestEndToEndWorkflowIdOnAllExecutionEvents: + def test_every_execution_event_carries_workflow_id(self, prompt_server): + _simulate_prompt_lifecycle( + prompt_server, + extra_data={"extra_pnginfo": {"workflow": {"id": "wf-xyz"}}}, + prompt_id="p-1", + ) + _run_pending_callbacks(prompt_server) + frames = _drain(prompt_server) + + assert frames, "no frames emitted" + # Every frame with a prompt_id payload must carry workflow_id top-level. + for event, data, _sid in frames: + if isinstance(data, dict) and data.get("prompt_id"): + assert data.get("workflow_id") == "wf-xyz", (event, data) + + # And specifically verify the terminal frame — the #13684 race victim. + terminal = [ + (e, d) for e, d, _ in frames + if e == "executing" and isinstance(d, dict) and d.get("node") is None + ] + assert terminal, "no terminal executing frame emitted" + _, terminal_payload = terminal[-1] + assert terminal_payload["workflow_id"] == "wf-xyz" + + def test_status_frame_is_not_decorated(self, prompt_server): + prompt_server.register_prompt_metadata({"extra_pnginfo": {"workflow": {"id": "wf-1"}}}) + prompt_server.active_prompt_metadata_token = 1 + + # A status frame has no prompt_id and must remain untouched. + prompt_server.send_sync("status", {"status": {"queue_remaining": 0}}, None) + _run_pending_callbacks(prompt_server) + frames = _drain(prompt_server) + + assert any(e == "status" for e, _, _ in frames) + for event, data, _ in frames: + if event == "status": + assert "workflow_id" not in data + + +class TestEndToEndPromptIdCollision: + """Drive two prompts with the same client-supplied ``prompt_id`` but + different ``workflow_id`` values. The token model must keep their event + streams attributed correctly.""" + + def test_same_prompt_id_two_workflows_each_stream_attributed_correctly(self, prompt_server): + _simulate_prompt_lifecycle( + prompt_server, + extra_data={"extra_pnginfo": {"workflow": {"id": "wf-A"}}}, + prompt_id="P-shared", + ) + _simulate_prompt_lifecycle( + prompt_server, + extra_data={"extra_pnginfo": {"workflow": {"id": "wf-B"}}}, + prompt_id="P-shared", + ) + _run_pending_callbacks(prompt_server) + frames = _drain(prompt_server) + + # Partition the frames by lifecycle (execution_start..terminal-executing). + runs: list[list[dict]] = [[]] + for event, data, _ in frames: + if not isinstance(data, dict): + continue + runs[-1].append({"event": event, **data}) + if event == "executing" and data.get("node") is None: + runs.append([]) + # The trailing empty bucket from the final split is fine. + runs = [r for r in runs if r] + assert len(runs) == 2, runs + + run_a, run_b = runs + for entry in run_a: + if entry.get("prompt_id"): + assert entry.get("workflow_id") == "wf-A", entry + for entry in run_b: + if entry.get("prompt_id"): + assert entry.get("workflow_id") == "wf-B", entry