Propagate workflow_id via per-prompt metadata registry (FE-745)

PR #13684 added workflow_id directly to ~9 dict literals across execution.py,
progress.py and main.py, along with executor.workflow_id and
server.last_workflow_id state. It was reverted because the execution layer
should not know about workflow concepts and because a finally-clear race
emitted workflow_id=None on the terminal "executing" frame.

Instead, register per-prompt metadata on PromptServer at submission time
and merge it onto outbound WebSocket payloads inside send_sync. The merge
keys off prompt_id (already present on every execution event), so
execution.py stays workflow-agnostic. Metadata is unregistered in main.py's
queue loop AFTER the terminal executing send, which structurally removes
the race.

- New comfy_execution/metadata.py: PromptMetadata TypedDict +
  build_prompt_metadata + merge_prompt_metadata helpers.
- PromptServer: prompt_metadata registry (lock-protected), register on
  post_prompt, merge in send_sync, expose get_prompt_metadata.
- jobs.py: extracted extract_workflow_id with strict isinstance guards;
  _extract_job_metadata delegates.
- main.py: try/finally around the queue iteration; unregister after the
  terminal "executing: {node: None}" send.
- execution.py PromptQueue: drop registry entries on wipe_queue /
  delete_queue_item so cancellations don't leak.
- progress.py: look up workflow_id from the server registry for the
  per-node nested copies and the binary preview metadata, matching #13684's
  wire shape so the frontend needs no changes.
- Tests: tests-unit/server_test/test_prompt_metadata.py covers the merge,
  the passthrough cases (no prompt_id, unknown prompt_id, binary payloads),
  and the terminal-frame race regression.
This commit is contained in:
dante01yoon 2026-05-19 17:16:11 +09:00
parent 6b61918a16
commit 5396b4fe67
7 changed files with 254 additions and 15 deletions

View File

@ -93,6 +93,27 @@ def _create_text_preview(value: str) -> dict:
}
def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]:
"""Extract the workflow id from a prompt's ``extra_data``.
The frontend stores the id at ``extra_data["extra_pnginfo"]["workflow"]["id"]``
when a prompt is queued. Any value that is not a non-empty string is treated as
missing so callers can rely on the return being either ``None`` or a string.
"""
if not isinstance(extra_data, dict):
return None
extra_pnginfo = extra_data.get('extra_pnginfo')
if not isinstance(extra_pnginfo, dict):
return None
workflow = extra_pnginfo.get('workflow')
if not isinstance(workflow, dict):
return None
workflow_id = workflow.get('id')
if isinstance(workflow_id, str) and workflow_id:
return workflow_id
return None
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
"""Extract create_time and workflow_id from extra_data.
@ -100,8 +121,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str
tuple: (create_time, workflow_id)
"""
create_time = extra_data.get('create_time')
extra_pnginfo = extra_data.get('extra_pnginfo', {})
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
workflow_id = extract_workflow_id(extra_data)
return create_time, workflow_id

View File

@ -0,0 +1,51 @@
"""Per-prompt metadata propagated alongside execution WebSocket events.
The execution layer (``execution.py``) is intentionally kept agnostic of
workflow-level concepts. Instead, ``PromptServer`` registers per-``prompt_id``
metadata at submission time and merges it onto outgoing WebSocket payloads in
``send_sync``. Today only ``workflow_id`` is propagated; the structure is a
``TypedDict`` so additional keys can be added without churn at the call sites.
"""
import threading
from typing import Optional, TypedDict
from comfy_execution.jobs import extract_workflow_id
class PromptMetadata(TypedDict, total=False):
workflow_id: Optional[str]
def build_prompt_metadata(extra_data: Optional[dict]) -> PromptMetadata:
"""Build a ``PromptMetadata`` snapshot from a prompt's ``extra_data``.
Returns an empty dict when no recognized metadata is present so callers can
skip registering anything in that case.
"""
meta: PromptMetadata = {}
workflow_id = extract_workflow_id(extra_data)
if workflow_id is not None:
meta["workflow_id"] = workflow_id
return meta
def merge_prompt_metadata(
registry: dict,
lock: threading.Lock,
data,
):
"""Return ``data`` with the registered metadata for its ``prompt_id`` merged
top-level. The event payload wins on conflict, and non-dict payloads (e.g.
the binary preview tuple) pass through untouched.
"""
if not isinstance(data, dict):
return data
prompt_id = data.get("prompt_id")
if not prompt_id:
return data
with lock:
meta = registry.get(prompt_id)
if not meta:
return data
return {**meta, **data}

View File

@ -159,11 +159,19 @@ class WebUIProgressHandler(ProgressHandler):
def set_registry(self, registry: "ProgressRegistry"):
self.registry = registry
def _lookup_workflow_id(self, prompt_id: str) -> Optional[str]:
get_meta = getattr(self.server_instance, "get_prompt_metadata", None)
if get_meta is None:
return None
return get_meta(prompt_id).get("workflow_id")
def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
"""Send the current progress state to the client"""
if self.server_instance is None:
return
workflow_id = self._lookup_workflow_id(prompt_id)
# Only send info for non-pending nodes
active_nodes = {
node_id: {
@ -172,6 +180,7 @@ class WebUIProgressHandler(ProgressHandler):
"state": state["state"].value,
"node_id": node_id,
"prompt_id": prompt_id,
"workflow_id": workflow_id,
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
@ -181,7 +190,10 @@ class WebUIProgressHandler(ProgressHandler):
}
# Send a combined progress_state message with all node states
# Include client_id to ensure message is only sent to the initiating client
# Include client_id to ensure message is only sent to the initiating client.
# The outer ``workflow_id`` is merged in by ``PromptServer.send_sync`` via
# the per-prompt metadata registry; the nested copy on each node entry
# mirrors the wire shape consumed by the frontend.
self.server_instance.send_sync(
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
)
@ -215,6 +227,7 @@ class WebUIProgressHandler(ProgressHandler):
metadata = {
"node_id": node_id,
"prompt_id": prompt_id,
"workflow_id": self._lookup_workflow_id(prompt_id),
"display_node_id": self.registry.dynprompt.get_display_node_id(
node_id
),

View File

@ -1296,18 +1296,23 @@ class PromptQueue:
def wipe_queue(self):
with self.mutex:
cancelled_ids = [item[1] for item in self.queue]
self.queue = []
self.server.queue_updated()
for prompt_id in cancelled_ids:
self.server.unregister_prompt_metadata(prompt_id)
def delete_queue_item(self, function):
with self.mutex:
for x in range(len(self.queue)):
if function(self.queue[x]):
cancelled_id = self.queue[x][1]
if len(self.queue) == 1:
self.wipe_queue()
else:
self.queue.pop(x)
heapq.heapify(self.queue)
self.server.unregister_prompt_metadata(cancelled_id)
self.server.queue_updated()
return True
return False

28
main.py
View File

@ -318,19 +318,25 @@ def prompt_worker(q, server_instance):
extra_data[k] = sensitive[k]
asset_seeder.pause()
e.execute(item[2], prompt_id, extra_data, item[4])
try:
e.execute(item[2], prompt_id, extra_data, item[4])
need_gc = True
need_gc = True
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
q.task_done(item_id,
e.history_result,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
q.task_done(item_id,
e.history_result,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
finally:
# Drop the per-prompt metadata only AFTER the terminal "executing"
# send so the registered workflow_id is merged onto that frame.
# This is what eliminates the #13684 finally-clear race.
server_instance.unregister_prompt_metadata(prompt_id)
current_time = time.perf_counter()
execution_time = current_time - execution_start_time

View File

@ -8,7 +8,9 @@ import time
import nodes
import folder_paths
import execution
import threading
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
from comfy_execution.metadata import PromptMetadata, build_prompt_metadata, merge_prompt_metadata
import uuid
import urllib
import json
@ -252,6 +254,9 @@ class PromptServer():
self.last_node_id = None
self.client_id = None
self.prompt_metadata: dict[str, PromptMetadata] = {}
self._prompt_metadata_lock = threading.Lock()
self.on_prompt_handlers = []
@routes.get('/ws')
@ -275,7 +280,12 @@ class PromptServer():
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
# On reconnect if we are the currently executing client send the current node
if self.client_id == sid and self.last_node_id is not None:
await self.send("executing", { "node": self.last_node_id }, sid)
last_prompt_id = getattr(self, "last_prompt_id", None)
payload: dict = {"node": self.last_node_id}
if last_prompt_id:
payload["prompt_id"] = last_prompt_id
payload.update(self.get_prompt_metadata(last_prompt_id))
await self.send("executing", payload, sid)
# Flag to track if we've received the first message
first_message = True
@ -955,6 +965,7 @@ class PromptServer():
if sensitive_val in extra_data:
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
self.register_prompt_metadata(prompt_id, extra_data)
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
@ -1216,7 +1227,30 @@ class PromptServer():
elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_json, message)
def register_prompt_metadata(self, prompt_id: str, extra_data: dict) -> None:
"""Capture per-prompt metadata at submission time.
Stored on the server (not the executor) so it survives independent of
the execution thread and can be merged onto outbound WebSocket payloads
in ``send_sync`` without coupling the execution layer to workflow-level
concepts.
"""
meta = build_prompt_metadata(extra_data)
if not meta:
return
with self._prompt_metadata_lock:
self.prompt_metadata[prompt_id] = meta
def unregister_prompt_metadata(self, prompt_id: str) -> None:
with self._prompt_metadata_lock:
self.prompt_metadata.pop(prompt_id, None)
def get_prompt_metadata(self, prompt_id: str) -> PromptMetadata:
with self._prompt_metadata_lock:
return dict(self.prompt_metadata.get(prompt_id, {}))
def send_sync(self, event, data, sid=None):
data = merge_prompt_metadata(self.prompt_metadata, self._prompt_metadata_lock, data)
self.loop.call_soon_threadsafe(
self.messages.put_nowait, (event, data, sid))

View File

@ -0,0 +1,110 @@
"""Tests for the per-prompt metadata registry used to propagate ``workflow_id``
through WebSocket events without coupling ``execution.py`` to workflow-level
concepts.
The registry is a dict on ``PromptServer`` (keyed by ``prompt_id``) registered
at submission time (``post_prompt``), merged onto outbound payloads in
``send_sync`` via ``merge_prompt_metadata``, and dropped in ``main.py`` *after*
the terminal ``executing: {node: None}`` send so the final frame carries the
same ``workflow_id`` as the rest of the prompt.
"""
import threading
import pytest
from comfy_execution.metadata import (
build_prompt_metadata,
merge_prompt_metadata,
)
@pytest.fixture
def registry():
return {}
@pytest.fixture
def lock():
return threading.Lock()
class TestBuildPromptMetadata:
def test_returns_workflow_id_when_present(self):
extra_data = {"extra_pnginfo": {"workflow": {"id": "wf-1"}}}
assert build_prompt_metadata(extra_data) == {"workflow_id": "wf-1"}
def test_empty_when_workflow_id_missing(self):
assert build_prompt_metadata({}) == {}
assert build_prompt_metadata({"extra_pnginfo": {}}) == {}
assert build_prompt_metadata({"extra_pnginfo": {"workflow": {}}}) == {}
def test_empty_when_workflow_id_not_a_non_empty_string(self):
assert build_prompt_metadata({"extra_pnginfo": {"workflow": {"id": ""}}}) == {}
assert build_prompt_metadata({"extra_pnginfo": {"workflow": {"id": 42}}}) == {}
assert build_prompt_metadata({"extra_pnginfo": {"workflow": {"id": None}}}) == {}
def test_empty_on_non_dict_input(self):
assert build_prompt_metadata(None) == {}
assert build_prompt_metadata("not a dict") == {}
class TestMergeMetadata:
"""``merge_prompt_metadata`` is the transparent injection point used by
``PromptServer.send_sync``. Event payload fields win on conflict so the
executor can never be overridden by stale registry data, and binary payloads
(the preview tuple) pass through untouched."""
def test_merges_workflow_id_when_prompt_id_known(self, registry, lock):
registry["p1"] = {"workflow_id": "wf-1"}
merged = merge_prompt_metadata(registry, lock, {"node": "n1", "prompt_id": "p1"})
assert merged == {"node": "n1", "prompt_id": "p1", "workflow_id": "wf-1"}
def test_passthrough_when_prompt_id_unknown(self, registry, lock):
merged = merge_prompt_metadata(registry, lock, {"node": "n1", "prompt_id": "missing"})
assert merged == {"node": "n1", "prompt_id": "missing"}
def test_passthrough_when_no_prompt_id(self, registry, lock):
registry["p1"] = {"workflow_id": "wf-1"}
merged = merge_prompt_metadata(registry, lock, {"status": {"queue_remaining": 0}})
assert merged == {"status": {"queue_remaining": 0}}
def test_passthrough_for_non_dict_payload(self, registry, lock):
registry["p1"] = {"workflow_id": "wf-1"}
binary = (b"image-bytes", {"prompt_id": "p1"})
assert merge_prompt_metadata(registry, lock, binary) is binary
def test_event_payload_wins_over_registered_metadata(self, registry, lock):
registry["p1"] = {"workflow_id": "wf-registered"}
merged = merge_prompt_metadata(
registry, lock, {"prompt_id": "p1", "workflow_id": "wf-caller"}
)
assert merged["workflow_id"] == "wf-caller"
class TestRaceRegressionForTerminalExecutingFrame:
"""Regression for the PR #13684 finally-clear race.
In the reverted PR, the executor's ``finally`` cleared ``last_workflow_id``
before ``main.py`` emitted the post-completion ``executing: {node: None}``
frame so that terminal frame shipped ``workflow_id=None``.
With the registry approach, ``main.py`` unregisters *after* the terminal
send, so the merge sees the registered metadata. This test simulates that
ordering to lock in the contract.
"""
def test_terminal_executing_frame_includes_workflow_id(self, registry, lock):
registry["p1"] = {"workflow_id": "wf-1"}
# main.py emits the terminal frame BEFORE unregistering.
terminal = merge_prompt_metadata(registry, lock, {"node": None, "prompt_id": "p1"})
registry.pop("p1", None) # main.py's finally: unregister_prompt_metadata
assert terminal == {"node": None, "prompt_id": "p1", "workflow_id": "wf-1"}
# After unregister, any straggler events emitted by extensions after
# completion are no longer decorated. Verifies the registry is actually
# released, not just shadowed.
straggler = merge_prompt_metadata(registry, lock, {"node": None, "prompt_id": "p1"})
assert "workflow_id" not in straggler