Key metadata by token, not prompt_id, to survive id collisions

Adversarial review caught that a LIFO stack keyed by ``prompt_id`` still
mis-attributes events when queue execution order differs from registration
order: a second submission with the same ``prompt_id`` lands on top of the
stack, so the first prompt's events read the wrong workflow_id while it
runs, and the first's ``unregister`` then pops the second prompt's entry.

Replace the stack with an internal monotonic token. ``post_prompt``
registers metadata and stashes the returned token on ``extra_data`` under
``PROMPT_METADATA_TOKEN_KEY``. ``main.py``'s queue worker pulls the token
out, pins it on ``PromptServer.active_prompt_metadata_token`` for the
prompt's execution, and clears + unregisters in ``finally``. The merge in
``send_sync`` reads the active token, so each prompt's events are merged
with its own metadata regardless of ``prompt_id`` collisions.

- comfy_execution/metadata.py: ``merge_prompt_metadata`` now takes an
  active token; registry is ``dict[int, PromptMetadata]``; new
  ``PROMPT_METADATA_TOKEN_KEY`` constant for the extra_data carrier.
- server.py: ``register_prompt_metadata`` returns a token (or ``None``
  when no metadata applies); ``unregister`` takes a token;
  ``get_active_prompt_metadata`` snapshots the pinned entry.
- main.py: pops the token from extra_data, pins on the server, clears
  after the terminal "executing: {node: None}" send.
- execution.py ``PromptQueue``: wipe_queue / delete_queue_item now
  unregister by token extracted from each item's extra_data.
- comfy_execution/progress.py: reads workflow_id via
  ``get_active_prompt_metadata`` rather than per-prompt_id lookup.
- tests: unit tests updated for the token signature, plus a real E2E
  test (test_prompt_metadata_e2e.py) that instantiates the actual
  PromptServer and verifies same-prompt_id-different-workflow_id
  submissions don't cross-attribute.

Verified end-to-end against a live ComfyUI server: two submissions with
identical client-supplied prompt_id but different workflow_id each emit
their full execution event stream (execution_start, execution_cached,
executing, executed, execution_success, progress_state, terminal executing)
with the correct workflow_id top-level. 68 / 68 tests pass.
This commit is contained in:
dante01yoon 2026-05-19 22:19:15 +09:00
parent b0c05af67f
commit 85a12d0a83
7 changed files with 343 additions and 119 deletions

View File

@ -1,10 +1,19 @@
"""Per-prompt metadata propagated alongside execution WebSocket events. """Per-prompt metadata propagated alongside execution WebSocket events.
The execution layer (``execution.py``) is intentionally kept agnostic of The execution layer (``execution.py``) is intentionally kept agnostic of
workflow-level concepts. Instead, ``PromptServer`` registers per-``prompt_id`` workflow-level concepts. ``PromptServer`` registers metadata at submission time
metadata at submission time and merges it onto outgoing WebSocket payloads in and merges it onto outgoing WebSocket payloads in ``send_sync``. Today only
``send_sync``. Today only ``workflow_id`` is propagated; the structure is a ``workflow_id`` is propagated; the structure is a ``TypedDict`` so additional
``TypedDict`` so additional keys can be added without churn at the call sites. keys can be added without churn at the call sites.
Identity model registry is keyed by an internal monotonic token, NOT by
``prompt_id``. ``post_prompt`` accepts a client-supplied ``prompt_id``
verbatim, so two prompts can be queued with the same id and a registry keyed
only by ``prompt_id`` would misattribute events when queue order differs from
registration order. ``main.py``'s queue worker pins the active token on the
server for the duration of one prompt's execution and ``send_sync`` reads that
token so each prompt's events get its own registered metadata regardless of
``prompt_id`` collisions.
""" """
import threading import threading
@ -13,6 +22,9 @@ from typing import Optional, TypedDict
from comfy_execution.jobs import extract_workflow_id from comfy_execution.jobs import extract_workflow_id
PROMPT_METADATA_TOKEN_KEY = "__prompt_metadata_token"
class PromptMetadata(TypedDict, total=False): class PromptMetadata(TypedDict, total=False):
workflow_id: Optional[str] workflow_id: Optional[str]
@ -48,24 +60,27 @@ def resolve_progress_text_sid(sid, default_sid):
def merge_prompt_metadata( def merge_prompt_metadata(
registry: dict, registry: dict,
lock: threading.Lock, lock: threading.Lock,
active_token: Optional[int],
data, data,
): ):
"""Return ``data`` with the registered metadata for its ``prompt_id`` merged """Return ``data`` with the metadata for the currently-active token merged
top-level. The event payload wins on conflict, and non-dict payloads (e.g. top-level. The event payload wins on conflict, and non-dict payloads (e.g.
the binary preview tuple) pass through untouched. the binary preview tuple) pass through untouched.
The registry is a stack per ``prompt_id`` (``dict[str, list[PromptMetadata]]``) The active token is set by ``main.py``'s queue worker around each
so duplicate submissions of the same ``prompt_id`` don't clobber each ``e.execute(...)``. The merge happens only when a token is currently active
other's metadata; the most recently registered entry wins. *and* the payload carries a ``prompt_id`` using ``prompt_id`` purely as a
marker that this is an execution event meant to receive metadata, not as
the lookup key. This makes the merge immune to ``prompt_id`` collisions.
""" """
if not isinstance(data, dict): if not isinstance(data, dict):
return data return data
prompt_id = data.get("prompt_id") if not data.get("prompt_id"):
if not prompt_id: return data
if active_token is None:
return data return data
with lock: with lock:
stack = registry.get(prompt_id) meta = registry.get(active_token)
meta = stack[-1] if stack else None
if not meta: if not meta:
return data return data
return {**meta, **data} return {**meta, **data}

View File

@ -160,10 +160,10 @@ class WebUIProgressHandler(ProgressHandler):
self.registry = registry self.registry = registry
def _lookup_workflow_id(self, prompt_id: str) -> Optional[str]: def _lookup_workflow_id(self, prompt_id: str) -> Optional[str]:
get_meta = getattr(self.server_instance, "get_prompt_metadata", None) get_meta = getattr(self.server_instance, "get_active_prompt_metadata", None)
if get_meta is None: if get_meta is None:
return None return None
return get_meta(prompt_id).get("workflow_id") return get_meta().get("workflow_id")
def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]): def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
"""Send the current progress state to the client""" """Send the current progress state to the client"""

View File

@ -36,6 +36,7 @@ from comfy_execution.graph import (
get_input_info, get_input_info,
) )
from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.metadata import PROMPT_METADATA_TOKEN_KEY
from comfy_execution.validation import validate_node_input from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext from comfy_execution.utils import CurrentNodeContext
@ -1294,25 +1295,32 @@ class PromptQueue:
with self.mutex: with self.mutex:
return len(self.queue) + len(self.currently_running) return len(self.queue) + len(self.currently_running)
def _extract_metadata_token(self, item):
# Queue item shape: (number, prompt_id, prompt, extra_data, outputs, sensitive)
extra_data = item[3] if len(item) > 3 and isinstance(item[3], dict) else None
if not extra_data:
return None
return extra_data.get(PROMPT_METADATA_TOKEN_KEY)
def wipe_queue(self): def wipe_queue(self):
with self.mutex: with self.mutex:
cancelled_ids = [item[1] for item in self.queue] cancelled_tokens = [self._extract_metadata_token(item) for item in self.queue]
self.queue = [] self.queue = []
self.server.queue_updated() self.server.queue_updated()
for prompt_id in cancelled_ids: for token in cancelled_tokens:
self.server.unregister_prompt_metadata(prompt_id) self.server.unregister_prompt_metadata(token)
def delete_queue_item(self, function): def delete_queue_item(self, function):
with self.mutex: with self.mutex:
for x in range(len(self.queue)): for x in range(len(self.queue)):
if function(self.queue[x]): if function(self.queue[x]):
cancelled_id = self.queue[x][1] cancelled_token = self._extract_metadata_token(self.queue[x])
if len(self.queue) == 1: if len(self.queue) == 1:
self.wipe_queue() self.wipe_queue()
else: else:
self.queue.pop(x) self.queue.pop(x)
heapq.heapify(self.queue) heapq.heapify(self.queue)
self.server.unregister_prompt_metadata(cancelled_id) self.server.unregister_prompt_metadata(cancelled_token)
self.server.queue_updated() self.server.queue_updated()
return True return True
return False return False

11
main.py
View File

@ -27,6 +27,7 @@ from utils.mime_types import init_mime_types
import faulthandler import faulthandler
import logging import logging
import sys import sys
from comfy_execution.metadata import PROMPT_METADATA_TOKEN_KEY
from comfy_execution.progress import get_progress_state from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags from comfy_api import feature_flags
@ -317,6 +318,13 @@ def prompt_worker(q, server_instance):
for k in sensitive: for k in sensitive:
extra_data[k] = sensitive[k] extra_data[k] = sensitive[k]
# Pin the metadata token registered for this exact prompt while it
# runs so ``send_sync`` can decorate its frames with the right
# ``workflow_id`` even if another submission shares the same
# ``prompt_id``.
metadata_token = extra_data.pop(PROMPT_METADATA_TOKEN_KEY, None)
server_instance.active_prompt_metadata_token = metadata_token
asset_seeder.pause() asset_seeder.pause()
try: try:
e.execute(item[2], prompt_id, extra_data, item[4]) e.execute(item[2], prompt_id, extra_data, item[4])
@ -336,7 +344,8 @@ def prompt_worker(q, server_instance):
# Drop the per-prompt metadata only AFTER the terminal "executing" # Drop the per-prompt metadata only AFTER the terminal "executing"
# send so the registered workflow_id is merged onto that frame. # send so the registered workflow_id is merged onto that frame.
# This is what eliminates the #13684 finally-clear race. # This is what eliminates the #13684 finally-clear race.
server_instance.unregister_prompt_metadata(prompt_id) server_instance.active_prompt_metadata_token = None
server_instance.unregister_prompt_metadata(metadata_token)
current_time = time.perf_counter() current_time = time.perf_counter()
execution_time = current_time - execution_start_time execution_time = current_time - execution_start_time

View File

@ -11,6 +11,7 @@ import execution
import threading import threading
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
from comfy_execution.metadata import ( from comfy_execution.metadata import (
PROMPT_METADATA_TOKEN_KEY,
PromptMetadata, PromptMetadata,
build_prompt_metadata, build_prompt_metadata,
merge_prompt_metadata, merge_prompt_metadata,
@ -259,8 +260,14 @@ class PromptServer():
self.last_node_id = None self.last_node_id = None
self.client_id = None self.client_id = None
self.prompt_metadata: dict[str, list[PromptMetadata]] = {} # Keyed by an internal monotonic token rather than by ``prompt_id``
# because clients can supply the same ``prompt_id`` on retry/dedupe.
self.prompt_metadata: dict[int, PromptMetadata] = {}
self._prompt_metadata_lock = threading.Lock() self._prompt_metadata_lock = threading.Lock()
self._prompt_metadata_counter = 0
# Set by ``main.py``'s queue worker for the duration of one prompt's
# execution; read by ``send_sync`` to merge the right metadata.
self.active_prompt_metadata_token: Optional[int] = None
self.on_prompt_handlers = [] self.on_prompt_handlers = []
@ -289,7 +296,7 @@ class PromptServer():
payload: dict = {"node": self.last_node_id} payload: dict = {"node": self.last_node_id}
if last_prompt_id: if last_prompt_id:
payload["prompt_id"] = last_prompt_id payload["prompt_id"] = last_prompt_id
payload.update(self.get_prompt_metadata(last_prompt_id)) payload.update(self.get_active_prompt_metadata())
await self.send("executing", payload, sid) await self.send("executing", payload, sid)
# Flag to track if we've received the first message # Flag to track if we've received the first message
@ -970,7 +977,9 @@ class PromptServer():
if sensitive_val in extra_data: if sensitive_val in extra_data:
sensitive[sensitive_val] = extra_data.pop(sensitive_val) sensitive[sensitive_val] = extra_data.pop(sensitive_val)
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
self.register_prompt_metadata(prompt_id, extra_data) token = self.register_prompt_metadata(extra_data)
if token is not None:
extra_data[PROMPT_METADATA_TOKEN_KEY] = token
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive)) self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response) return web.json_response(response)
@ -1232,40 +1241,50 @@ class PromptServer():
elif sid in self.sockets: elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_json, message) await send_socket_catch_exception(self.sockets[sid].send_json, message)
def register_prompt_metadata(self, prompt_id: str, extra_data: dict) -> None: def register_prompt_metadata(self, extra_data: dict) -> Optional[int]:
"""Capture per-prompt metadata at submission time. """Capture per-prompt metadata at submission time and return a token
identifying this registration.
Stored on the server (not the executor) so it survives independent of Returns ``None`` when there is no recognized metadata, signalling that
the execution thread and can be merged onto outbound WebSocket payloads no token needs to be threaded through the queue. Otherwise the token
in ``send_sync`` without coupling the execution layer to workflow-level must be stored on the queue item (typically via
concepts. :data:`PROMPT_METADATA_TOKEN_KEY` in ``extra_data``) and pinned on the
server as ``active_prompt_metadata_token`` while the prompt runs, so
Stacked per ``prompt_id`` so a client retrying or colliding with the the merge in ``send_sync`` picks up this prompt's metadata even when
same id doesn't have its metadata clobbered or, worse, removed by the another prompt is registered under the same ``prompt_id``.
other prompt's unregister.
""" """
meta = build_prompt_metadata(extra_data) meta = build_prompt_metadata(extra_data)
if not meta: if not meta:
return None
with self._prompt_metadata_lock:
self._prompt_metadata_counter += 1
token = self._prompt_metadata_counter
self.prompt_metadata[token] = meta
return token
def unregister_prompt_metadata(self, token: Optional[int]) -> None:
if token is None:
return return
with self._prompt_metadata_lock: with self._prompt_metadata_lock:
self.prompt_metadata.setdefault(prompt_id, []).append(meta) self.prompt_metadata.pop(token, None)
def unregister_prompt_metadata(self, prompt_id: str) -> None: def get_prompt_metadata_by_token(self, token: Optional[int]) -> PromptMetadata:
if token is None:
return {}
with self._prompt_metadata_lock: with self._prompt_metadata_lock:
stack = self.prompt_metadata.get(prompt_id) return dict(self.prompt_metadata.get(token, {}))
if not stack:
return
stack.pop()
if not stack:
self.prompt_metadata.pop(prompt_id, None)
def get_prompt_metadata(self, prompt_id: str) -> PromptMetadata: def get_active_prompt_metadata(self) -> PromptMetadata:
with self._prompt_metadata_lock: """Snapshot of the metadata for the currently-executing prompt."""
stack = self.prompt_metadata.get(prompt_id) return self.get_prompt_metadata_by_token(self.active_prompt_metadata_token)
return dict(stack[-1]) if stack else {}
def send_sync(self, event, data, sid=None): def send_sync(self, event, data, sid=None):
data = merge_prompt_metadata(self.prompt_metadata, self._prompt_metadata_lock, data) data = merge_prompt_metadata(
self.prompt_metadata,
self._prompt_metadata_lock,
self.active_prompt_metadata_token,
data,
)
self.loop.call_soon_threadsafe( self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid)) self.messages.put_nowait, (event, data, sid))

View File

@ -2,11 +2,12 @@
through WebSocket events without coupling ``execution.py`` to workflow-level through WebSocket events without coupling ``execution.py`` to workflow-level
concepts. concepts.
The registry is a dict on ``PromptServer`` (keyed by ``prompt_id``) registered The registry is keyed by an internal monotonic token (NOT by ``prompt_id``)
at submission time (``post_prompt``), merged onto outbound payloads in because ``post_prompt`` accepts a client-supplied ``prompt_id`` verbatim and
``send_sync`` via ``merge_prompt_metadata``, and dropped in ``main.py`` *after* two prompts can share an id. ``main.py``'s queue worker pins the active token
the terminal ``executing: {node: None}`` send so the final frame carries the on the server around each ``e.execute(...)`` and the merge in ``send_sync``
same ``workflow_id`` as the rest of the prompt. reads that pinned token, so each prompt's events get its own metadata
regardless of ``prompt_id`` collisions or queue-vs-stack ordering.
""" """
import threading import threading
@ -51,34 +52,43 @@ class TestBuildPromptMetadata:
class TestMergeMetadata: class TestMergeMetadata:
"""``merge_prompt_metadata`` is the transparent injection point used by """``merge_prompt_metadata`` decorates execution events with the metadata
``PromptServer.send_sync``. Event payload fields win on conflict so the for the currently-active token. Event payload fields win on conflict,
executor can never be overridden by stale registry data, and binary payloads binary payloads pass through, and the merge is gated on a ``prompt_id``
(the preview tuple) pass through untouched.""" marker to avoid decorating server-status events like ``status`` /
``queue_updated``."""
def test_merges_workflow_id_when_prompt_id_known(self, registry, lock): def test_merges_for_active_token_when_payload_has_prompt_id(self, registry, lock):
registry["p1"] = [{"workflow_id": "wf-1"}] registry[42] = {"workflow_id": "wf-1"}
merged = merge_prompt_metadata(registry, lock, {"node": "n1", "prompt_id": "p1"}) merged = merge_prompt_metadata(registry, lock, 42, {"node": "n1", "prompt_id": "p1"})
assert merged == {"node": "n1", "prompt_id": "p1", "workflow_id": "wf-1"} assert merged == {"node": "n1", "prompt_id": "p1", "workflow_id": "wf-1"}
def test_passthrough_when_prompt_id_unknown(self, registry, lock): def test_passthrough_when_no_active_token(self, registry, lock):
merged = merge_prompt_metadata(registry, lock, {"node": "n1", "prompt_id": "missing"}) registry[42] = {"workflow_id": "wf-1"}
assert merged == {"node": "n1", "prompt_id": "missing"} merged = merge_prompt_metadata(registry, lock, None, {"node": "n1", "prompt_id": "p1"})
assert merged == {"node": "n1", "prompt_id": "p1"}
def test_passthrough_when_active_token_unknown(self, registry, lock):
# Token was unregistered already (or never registered) — merge is a no-op.
merged = merge_prompt_metadata(registry, lock, 99, {"prompt_id": "p1"})
assert merged == {"prompt_id": "p1"}
def test_passthrough_when_no_prompt_id(self, registry, lock): def test_passthrough_when_no_prompt_id(self, registry, lock):
registry["p1"] = [{"workflow_id": "wf-1"}] # Server-status frames (status, queue_updated, etc.) carry no prompt_id
merged = merge_prompt_metadata(registry, lock, {"status": {"queue_remaining": 0}}) # and must not be decorated.
registry[42] = {"workflow_id": "wf-1"}
merged = merge_prompt_metadata(registry, lock, 42, {"status": {"queue_remaining": 0}})
assert merged == {"status": {"queue_remaining": 0}} assert merged == {"status": {"queue_remaining": 0}}
def test_passthrough_for_non_dict_payload(self, registry, lock): def test_passthrough_for_non_dict_payload(self, registry, lock):
registry["p1"] = [{"workflow_id": "wf-1"}] registry[42] = {"workflow_id": "wf-1"}
binary = (b"image-bytes", {"prompt_id": "p1"}) binary = (b"image-bytes", {"prompt_id": "p1"})
assert merge_prompt_metadata(registry, lock, binary) is binary assert merge_prompt_metadata(registry, lock, 42, binary) is binary
def test_event_payload_wins_over_registered_metadata(self, registry, lock): def test_event_payload_wins_over_registered_metadata(self, registry, lock):
registry["p1"] = [{"workflow_id": "wf-registered"}] registry[42] = {"workflow_id": "wf-registered"}
merged = merge_prompt_metadata( merged = merge_prompt_metadata(
registry, lock, {"prompt_id": "p1", "workflow_id": "wf-caller"} registry, lock, 42, {"prompt_id": "p1", "workflow_id": "wf-caller"}
) )
assert merged["workflow_id"] == "wf-caller" assert merged["workflow_id"] == "wf-caller"
@ -86,8 +96,8 @@ class TestMergeMetadata:
class TestProgressTextSidResolution: class TestProgressTextSidResolution:
"""``BinaryEventTypes.TEXT`` frames don't yet carry ``prompt_id`` / """``BinaryEventTypes.TEXT`` frames don't yet carry ``prompt_id`` /
``workflow_id`` in their wire shape, so cross-client routing has to happen ``workflow_id`` in their wire shape, so cross-client routing has to happen
at the ``sid`` level instead of via the metadata merge. The default sid at the ``sid`` level. The default sid pins the broadcast to the active
pins the broadcast to the active prompt's client. prompt's client.
""" """
def test_explicit_sid_passes_through(self): def test_explicit_sid_passes_through(self):
@ -97,77 +107,76 @@ class TestProgressTextSidResolution:
assert resolve_progress_text_sid(None, "client-active") == "client-active" assert resolve_progress_text_sid(None, "client-active") == "client-active"
def test_none_sid_with_no_active_client_stays_none(self): def test_none_sid_with_no_active_client_stays_none(self):
# No active prompt means there is no sensible recipient — fall through
# to broadcast rather than fabricate a target.
assert resolve_progress_text_sid(None, None) is None assert resolve_progress_text_sid(None, None) is None
class TestPromptIdCollision: class TestPromptIdCollisionWithTokens:
"""Two prompts may be submitted with the same ``prompt_id`` (client retry, """Two prompts can be queued with the same client-supplied ``prompt_id``.
forced custom id, partner-integration deduplication, etc.). With a flat With a registry keyed by ``prompt_id`` the second registration would
dict-keyed registry the second registration would clobber the first and a overwrite the first or be erased by the first's unregister. The token model
single ``unregister`` call would erase metadata still needed by the other makes each registration independent."""
prompt. The stack-based registry resolves both cases."""
def test_duplicate_register_does_not_clobber_prior_entry(self, registry, lock): def test_two_submissions_get_distinct_tokens_and_each_merges_correctly(self, registry, lock):
# Caller B clobbers A in the merge view (last-wins), but A's metadata # Two submissions of the same prompt_id with different workflow_ids.
# is still in the stack and reappears after B unregisters. registry[1] = {"workflow_id": "wf-A"} # token from submission #1
registry.setdefault("p1", []).append({"workflow_id": "wf-A"}) registry[2] = {"workflow_id": "wf-B"} # token from submission #2
registry["p1"].append({"workflow_id": "wf-B"})
merged = merge_prompt_metadata(registry, lock, {"prompt_id": "p1"}) # Worker is currently running submission #1.
assert merged["workflow_id"] == "wf-B" merged = merge_prompt_metadata(registry, lock, 1, {"prompt_id": "P", "node": "x"})
registry["p1"].pop()
merged = merge_prompt_metadata(registry, lock, {"prompt_id": "p1"})
assert merged["workflow_id"] == "wf-A" assert merged["workflow_id"] == "wf-A"
def test_single_unregister_does_not_drop_concurrent_submission(self, registry, lock): # Worker switches to submission #2 (queue ordering, retry, whatever).
registry.setdefault("p1", []).append({"workflow_id": "wf-A"}) merged = merge_prompt_metadata(registry, lock, 2, {"prompt_id": "P", "node": "y"})
registry["p1"].append({"workflow_id": "wf-B"}) assert merged["workflow_id"] == "wf-B"
# Only one of the two prompts finished — pop once. def test_unregister_by_token_does_not_drop_concurrent_submission(self, registry, lock):
registry["p1"].pop() registry[1] = {"workflow_id": "wf-A"}
registry[2] = {"workflow_id": "wf-B"}
merged = merge_prompt_metadata(registry, lock, {"prompt_id": "p1"}) # Submission #1 finishes — drop its token only.
assert "workflow_id" in merged registry.pop(1, None)
def test_full_drain_clears_registry(self, registry, lock): # Submission #2 still has its metadata.
registry.setdefault("p1", []).append({"workflow_id": "wf-A"}) merged = merge_prompt_metadata(registry, lock, 2, {"prompt_id": "P"})
registry["p1"].append({"workflow_id": "wf-B"}) assert merged["workflow_id"] == "wf-B"
registry["p1"].pop()
registry["p1"].pop()
merged = merge_prompt_metadata(registry, lock, {"prompt_id": "p1"}) def test_execution_order_independent_of_registration_order(self, registry, lock):
assert "workflow_id" not in merged """Regression for the LIFO-stack failure mode: queue executes #1 first
but the previous stack design would have made the merge pick the
latest-registered (#2) metadata. Token model is immune."""
registry[1] = {"workflow_id": "wf-A"}
registry[2] = {"workflow_id": "wf-B"}
# Even though #2 was registered after #1, executing #1 still sees wf-A.
merged_first_running = merge_prompt_metadata(registry, lock, 1, {"prompt_id": "P"})
assert merged_first_running["workflow_id"] == "wf-A"
class TestRaceRegressionForTerminalExecutingFrame: class TestRaceRegressionForTerminalExecutingFrame:
"""Regression for the PR #13684 finally-clear race. """Regression for the PR #13684 finally-clear race.
In the reverted PR, the executor's ``finally`` cleared ``last_workflow_id`` Executor's ``finally`` previously cleared the workflow_id source, so the
before ``main.py`` emitted the post-completion ``executing: {node: None}`` post-completion terminal frame shipped ``workflow_id=None``. With the
frame so that terminal frame shipped ``workflow_id=None``. token model, the active token stays pinned until ``main.py`` clears it
*after* the terminal send.
With the registry approach, ``main.py`` unregisters *after* the terminal
send, so the merge sees the registered metadata. This test simulates that
ordering to lock in the contract.
""" """
def test_terminal_executing_frame_includes_workflow_id(self, registry, lock): def test_terminal_executing_frame_includes_workflow_id(self, registry, lock):
registry["p1"] = [{"workflow_id": "wf-1"}] registry[7] = {"workflow_id": "wf-1"}
active_token = 7
# main.py emits the terminal frame BEFORE unregistering. # main.py emits the terminal frame BEFORE clearing the active token.
terminal = merge_prompt_metadata(registry, lock, {"node": None, "prompt_id": "p1"}) terminal = merge_prompt_metadata(
registry["p1"].pop() registry, lock, active_token, {"node": None, "prompt_id": "p1"}
if not registry["p1"]: )
registry.pop("p1", None) # main.py's finally: clear active token + unregister.
active_token = None
registry.pop(7, None)
assert terminal == {"node": None, "prompt_id": "p1", "workflow_id": "wf-1"} assert terminal == {"node": None, "prompt_id": "p1", "workflow_id": "wf-1"}
# After unregister, any straggler events emitted by extensions after # After cleanup, straggler events get no metadata.
# completion are no longer decorated. Verifies the registry is actually straggler = merge_prompt_metadata(
# released, not just shadowed. registry, lock, active_token, {"node": None, "prompt_id": "p1"}
straggler = merge_prompt_metadata(registry, lock, {"node": None, "prompt_id": "p1"}) )
assert "workflow_id" not in straggler assert "workflow_id" not in straggler

View File

@ -0,0 +1,164 @@
"""End-to-end checks of the prompt-metadata propagation.
These tests instantiate the real ``PromptServer`` (no heavy mocking) and drive
the same call sequence ``main.py`` would: register metadata, pin the token,
emit a bunch of execution-shaped frames through ``send_sync``, then clear the
token + unregister. We then drain the server's message queue and inspect the
exact frames that would have been written to the WebSocket proving the wire
shape and the ``prompt_id`` collision invariants without needing a browser.
"""
import asyncio
import os
import sys
import threading
import pytest
# Ensure the repo root is the first thing on sys.path so ``utils`` resolves to
# the ComfyUI package, not a site-packages namespace package. ``utils`` must
# load before ``app.frontend_management`` imports it.
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if REPO_ROOT not in sys.path:
sys.path.insert(0, REPO_ROOT)
import utils # noqa: F401 -- pin the package before later imports
from comfy_execution.metadata import PROMPT_METADATA_TOKEN_KEY
server_module = pytest.importorskip("server")
@pytest.fixture
def prompt_server():
loop = asyncio.new_event_loop()
try:
server = server_module.PromptServer(loop)
yield server
finally:
loop.close()
def _drain(server) -> list:
"""Empty the server's internal asyncio.Queue and return every (event, data, sid)."""
frames = []
queue = server.messages
while not queue.empty():
frames.append(queue.get_nowait())
return frames
def _run_pending_callbacks(server):
"""``send_sync`` schedules ``messages.put_nowait`` via ``loop.call_soon_threadsafe``.
Pump the loop a single iteration so those callbacks land before we drain."""
server.loop.call_soon(server.loop.stop)
server.loop.run_forever()
def _simulate_prompt_lifecycle(server, extra_data: dict, prompt_id: str):
"""Replicates what ``main.py``'s queue worker does around ``e.execute(...)``
plus a representative sample of the events ``execution.py`` emits."""
token = server.register_prompt_metadata(extra_data)
if token is not None:
extra_data[PROMPT_METADATA_TOKEN_KEY] = token
# Worker picks up the item, pins the token.
server.active_prompt_metadata_token = extra_data.pop(PROMPT_METADATA_TOKEN_KEY, None)
client_id = "client-A"
server.client_id = client_id
# A representative sample of the eight execution events listed in PR #13684.
server.send_sync("execution_start", {"prompt_id": prompt_id}, client_id)
server.send_sync("execution_cached", {"nodes": [], "prompt_id": prompt_id}, client_id)
server.send_sync("executing", {"node": "n1", "display_node": "n1", "prompt_id": prompt_id}, client_id)
server.send_sync(
"executed",
{"node": "n1", "display_node": "n1", "output": {}, "prompt_id": prompt_id},
client_id,
)
server.send_sync("execution_success", {"prompt_id": prompt_id}, client_id)
# The terminal frame ``main.py`` itself emits, just before clearing state.
server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, client_id)
# Worker finally clause: clear token then unregister.
server.active_prompt_metadata_token = None
server.unregister_prompt_metadata(token)
class TestEndToEndWorkflowIdOnAllExecutionEvents:
def test_every_execution_event_carries_workflow_id(self, prompt_server):
_simulate_prompt_lifecycle(
prompt_server,
extra_data={"extra_pnginfo": {"workflow": {"id": "wf-xyz"}}},
prompt_id="p-1",
)
_run_pending_callbacks(prompt_server)
frames = _drain(prompt_server)
assert frames, "no frames emitted"
# Every frame with a prompt_id payload must carry workflow_id top-level.
for event, data, _sid in frames:
if isinstance(data, dict) and data.get("prompt_id"):
assert data.get("workflow_id") == "wf-xyz", (event, data)
# And specifically verify the terminal frame — the #13684 race victim.
terminal = [
(e, d) for e, d, _ in frames
if e == "executing" and isinstance(d, dict) and d.get("node") is None
]
assert terminal, "no terminal executing frame emitted"
_, terminal_payload = terminal[-1]
assert terminal_payload["workflow_id"] == "wf-xyz"
def test_status_frame_is_not_decorated(self, prompt_server):
prompt_server.register_prompt_metadata({"extra_pnginfo": {"workflow": {"id": "wf-1"}}})
prompt_server.active_prompt_metadata_token = 1
# A status frame has no prompt_id and must remain untouched.
prompt_server.send_sync("status", {"status": {"queue_remaining": 0}}, None)
_run_pending_callbacks(prompt_server)
frames = _drain(prompt_server)
assert any(e == "status" for e, _, _ in frames)
for event, data, _ in frames:
if event == "status":
assert "workflow_id" not in data
class TestEndToEndPromptIdCollision:
"""Drive two prompts with the same client-supplied ``prompt_id`` but
different ``workflow_id`` values. The token model must keep their event
streams attributed correctly."""
def test_same_prompt_id_two_workflows_each_stream_attributed_correctly(self, prompt_server):
_simulate_prompt_lifecycle(
prompt_server,
extra_data={"extra_pnginfo": {"workflow": {"id": "wf-A"}}},
prompt_id="P-shared",
)
_simulate_prompt_lifecycle(
prompt_server,
extra_data={"extra_pnginfo": {"workflow": {"id": "wf-B"}}},
prompt_id="P-shared",
)
_run_pending_callbacks(prompt_server)
frames = _drain(prompt_server)
# Partition the frames by lifecycle (execution_start..terminal-executing).
runs: list[list[dict]] = [[]]
for event, data, _ in frames:
if not isinstance(data, dict):
continue
runs[-1].append({"event": event, **data})
if event == "executing" and data.get("node") is None:
runs.append([])
# The trailing empty bucket from the final split is fine.
runs = [r for r in runs if r]
assert len(runs) == 2, runs
run_a, run_b = runs
for entry in run_a:
if entry.get("prompt_id"):
assert entry.get("workflow_id") == "wf-A", entry
for entry in run_b:
if entry.get("prompt_id"):
assert entry.get("workflow_id") == "wf-B", entry