mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-17 21:39:45 +08:00
Propagate workflow_id via per-prompt metadata registry (FE-745)
PR #13684 added workflow_id directly to ~9 dict literals across execution.py, progress.py and main.py, along with executor.workflow_id and server.last_workflow_id state. It was reverted because the execution layer should not know about workflow concepts and because a finally-clear race emitted workflow_id=None on the terminal "executing" frame. Instead, register per-prompt metadata on PromptServer at submission time and merge it onto outbound WebSocket payloads inside send_sync. The merge keys off prompt_id (already present on every execution event), so execution.py stays workflow-agnostic. Metadata is unregistered in main.py's queue loop AFTER the terminal executing send, which structurally removes the race. - New comfy_execution/metadata.py: PromptMetadata TypedDict + build_prompt_metadata + merge_prompt_metadata helpers. - PromptServer: prompt_metadata registry (lock-protected), register on post_prompt, merge in send_sync, expose get_prompt_metadata. - jobs.py: extracted extract_workflow_id with strict isinstance guards; _extract_job_metadata delegates. - main.py: try/finally around the queue iteration; unregister after the terminal "executing: {node: None}" send. - execution.py PromptQueue: drop registry entries on wipe_queue / delete_queue_item so cancellations don't leak. - progress.py: look up workflow_id from the server registry for the per-node nested copies and the binary preview metadata, matching #13684's wire shape so the frontend needs no changes. - Tests: tests-unit/server_test/test_prompt_metadata.py covers the merge, the passthrough cases (no prompt_id, unknown prompt_id, binary payloads), and the terminal-frame race regression.
This commit is contained in:
parent
6b61918a16
commit
5396b4fe67
@ -93,6 +93,27 @@ def _create_text_preview(value: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]:
|
||||
"""Extract the workflow id from a prompt's ``extra_data``.
|
||||
|
||||
The frontend stores the id at ``extra_data["extra_pnginfo"]["workflow"]["id"]``
|
||||
when a prompt is queued. Any value that is not a non-empty string is treated as
|
||||
missing so callers can rely on the return being either ``None`` or a string.
|
||||
"""
|
||||
if not isinstance(extra_data, dict):
|
||||
return None
|
||||
extra_pnginfo = extra_data.get('extra_pnginfo')
|
||||
if not isinstance(extra_pnginfo, dict):
|
||||
return None
|
||||
workflow = extra_pnginfo.get('workflow')
|
||||
if not isinstance(workflow, dict):
|
||||
return None
|
||||
workflow_id = workflow.get('id')
|
||||
if isinstance(workflow_id, str) and workflow_id:
|
||||
return workflow_id
|
||||
return None
|
||||
|
||||
|
||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||
"""Extract create_time and workflow_id from extra_data.
|
||||
|
||||
@ -100,8 +121,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str
|
||||
tuple: (create_time, workflow_id)
|
||||
"""
|
||||
create_time = extra_data.get('create_time')
|
||||
extra_pnginfo = extra_data.get('extra_pnginfo', {})
|
||||
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
|
||||
workflow_id = extract_workflow_id(extra_data)
|
||||
return create_time, workflow_id
|
||||
|
||||
|
||||
|
||||
51
comfy_execution/metadata.py
Normal file
51
comfy_execution/metadata.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""Per-prompt metadata propagated alongside execution WebSocket events.
|
||||
|
||||
The execution layer (``execution.py``) is intentionally kept agnostic of
|
||||
workflow-level concepts. Instead, ``PromptServer`` registers per-``prompt_id``
|
||||
metadata at submission time and merges it onto outgoing WebSocket payloads in
|
||||
``send_sync``. Today only ``workflow_id`` is propagated; the structure is a
|
||||
``TypedDict`` so additional keys can be added without churn at the call sites.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
from comfy_execution.jobs import extract_workflow_id
|
||||
|
||||
|
||||
class PromptMetadata(TypedDict, total=False):
|
||||
workflow_id: Optional[str]
|
||||
|
||||
|
||||
def build_prompt_metadata(extra_data: Optional[dict]) -> PromptMetadata:
|
||||
"""Build a ``PromptMetadata`` snapshot from a prompt's ``extra_data``.
|
||||
|
||||
Returns an empty dict when no recognized metadata is present so callers can
|
||||
skip registering anything in that case.
|
||||
"""
|
||||
meta: PromptMetadata = {}
|
||||
workflow_id = extract_workflow_id(extra_data)
|
||||
if workflow_id is not None:
|
||||
meta["workflow_id"] = workflow_id
|
||||
return meta
|
||||
|
||||
|
||||
def merge_prompt_metadata(
|
||||
registry: dict,
|
||||
lock: threading.Lock,
|
||||
data,
|
||||
):
|
||||
"""Return ``data`` with the registered metadata for its ``prompt_id`` merged
|
||||
top-level. The event payload wins on conflict, and non-dict payloads (e.g.
|
||||
the binary preview tuple) pass through untouched.
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
prompt_id = data.get("prompt_id")
|
||||
if not prompt_id:
|
||||
return data
|
||||
with lock:
|
||||
meta = registry.get(prompt_id)
|
||||
if not meta:
|
||||
return data
|
||||
return {**meta, **data}
|
||||
@ -159,11 +159,19 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
def set_registry(self, registry: "ProgressRegistry"):
|
||||
self.registry = registry
|
||||
|
||||
def _lookup_workflow_id(self, prompt_id: str) -> Optional[str]:
|
||||
get_meta = getattr(self.server_instance, "get_prompt_metadata", None)
|
||||
if get_meta is None:
|
||||
return None
|
||||
return get_meta(prompt_id).get("workflow_id")
|
||||
|
||||
def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
|
||||
"""Send the current progress state to the client"""
|
||||
if self.server_instance is None:
|
||||
return
|
||||
|
||||
workflow_id = self._lookup_workflow_id(prompt_id)
|
||||
|
||||
# Only send info for non-pending nodes
|
||||
active_nodes = {
|
||||
node_id: {
|
||||
@ -172,6 +180,7 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
"state": state["state"].value,
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"workflow_id": workflow_id,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
|
||||
@ -181,7 +190,10 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
}
|
||||
|
||||
# Send a combined progress_state message with all node states
|
||||
# Include client_id to ensure message is only sent to the initiating client
|
||||
# Include client_id to ensure message is only sent to the initiating client.
|
||||
# The outer ``workflow_id`` is merged in by ``PromptServer.send_sync`` via
|
||||
# the per-prompt metadata registry; the nested copy on each node entry
|
||||
# mirrors the wire shape consumed by the frontend.
|
||||
self.server_instance.send_sync(
|
||||
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
|
||||
)
|
||||
@ -215,6 +227,7 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
metadata = {
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"workflow_id": self._lookup_workflow_id(prompt_id),
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(
|
||||
node_id
|
||||
),
|
||||
|
||||
@ -1296,18 +1296,23 @@ class PromptQueue:
|
||||
|
||||
def wipe_queue(self):
|
||||
with self.mutex:
|
||||
cancelled_ids = [item[1] for item in self.queue]
|
||||
self.queue = []
|
||||
self.server.queue_updated()
|
||||
for prompt_id in cancelled_ids:
|
||||
self.server.unregister_prompt_metadata(prompt_id)
|
||||
|
||||
def delete_queue_item(self, function):
|
||||
with self.mutex:
|
||||
for x in range(len(self.queue)):
|
||||
if function(self.queue[x]):
|
||||
cancelled_id = self.queue[x][1]
|
||||
if len(self.queue) == 1:
|
||||
self.wipe_queue()
|
||||
else:
|
||||
self.queue.pop(x)
|
||||
heapq.heapify(self.queue)
|
||||
self.server.unregister_prompt_metadata(cancelled_id)
|
||||
self.server.queue_updated()
|
||||
return True
|
||||
return False
|
||||
|
||||
28
main.py
28
main.py
@ -318,19 +318,25 @@ def prompt_worker(q, server_instance):
|
||||
extra_data[k] = sensitive[k]
|
||||
|
||||
asset_seeder.pause()
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
try:
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
|
||||
need_gc = True
|
||||
need_gc = True
|
||||
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
finally:
|
||||
# Drop the per-prompt metadata only AFTER the terminal "executing"
|
||||
# send so the registered workflow_id is merged onto that frame.
|
||||
# This is what eliminates the #13684 finally-clear race.
|
||||
server_instance.unregister_prompt_metadata(prompt_id)
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
|
||||
36
server.py
36
server.py
@ -8,7 +8,9 @@ import time
|
||||
import nodes
|
||||
import folder_paths
|
||||
import execution
|
||||
import threading
|
||||
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
|
||||
from comfy_execution.metadata import PromptMetadata, build_prompt_metadata, merge_prompt_metadata
|
||||
import uuid
|
||||
import urllib
|
||||
import json
|
||||
@ -252,6 +254,9 @@ class PromptServer():
|
||||
self.last_node_id = None
|
||||
self.client_id = None
|
||||
|
||||
self.prompt_metadata: dict[str, PromptMetadata] = {}
|
||||
self._prompt_metadata_lock = threading.Lock()
|
||||
|
||||
self.on_prompt_handlers = []
|
||||
|
||||
@routes.get('/ws')
|
||||
@ -275,7 +280,12 @@ class PromptServer():
|
||||
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
||||
# On reconnect if we are the currently executing client send the current node
|
||||
if self.client_id == sid and self.last_node_id is not None:
|
||||
await self.send("executing", { "node": self.last_node_id }, sid)
|
||||
last_prompt_id = getattr(self, "last_prompt_id", None)
|
||||
payload: dict = {"node": self.last_node_id}
|
||||
if last_prompt_id:
|
||||
payload["prompt_id"] = last_prompt_id
|
||||
payload.update(self.get_prompt_metadata(last_prompt_id))
|
||||
await self.send("executing", payload, sid)
|
||||
|
||||
# Flag to track if we've received the first message
|
||||
first_message = True
|
||||
@ -955,6 +965,7 @@ class PromptServer():
|
||||
if sensitive_val in extra_data:
|
||||
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
|
||||
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
|
||||
self.register_prompt_metadata(prompt_id, extra_data)
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
|
||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||||
return web.json_response(response)
|
||||
@ -1216,7 +1227,30 @@ class PromptServer():
|
||||
elif sid in self.sockets:
|
||||
await send_socket_catch_exception(self.sockets[sid].send_json, message)
|
||||
|
||||
def register_prompt_metadata(self, prompt_id: str, extra_data: dict) -> None:
|
||||
"""Capture per-prompt metadata at submission time.
|
||||
|
||||
Stored on the server (not the executor) so it survives independent of
|
||||
the execution thread and can be merged onto outbound WebSocket payloads
|
||||
in ``send_sync`` without coupling the execution layer to workflow-level
|
||||
concepts.
|
||||
"""
|
||||
meta = build_prompt_metadata(extra_data)
|
||||
if not meta:
|
||||
return
|
||||
with self._prompt_metadata_lock:
|
||||
self.prompt_metadata[prompt_id] = meta
|
||||
|
||||
def unregister_prompt_metadata(self, prompt_id: str) -> None:
|
||||
with self._prompt_metadata_lock:
|
||||
self.prompt_metadata.pop(prompt_id, None)
|
||||
|
||||
def get_prompt_metadata(self, prompt_id: str) -> PromptMetadata:
|
||||
with self._prompt_metadata_lock:
|
||||
return dict(self.prompt_metadata.get(prompt_id, {}))
|
||||
|
||||
def send_sync(self, event, data, sid=None):
|
||||
data = merge_prompt_metadata(self.prompt_metadata, self._prompt_metadata_lock, data)
|
||||
self.loop.call_soon_threadsafe(
|
||||
self.messages.put_nowait, (event, data, sid))
|
||||
|
||||
|
||||
110
tests-unit/server_test/test_prompt_metadata.py
Normal file
110
tests-unit/server_test/test_prompt_metadata.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""Tests for the per-prompt metadata registry used to propagate ``workflow_id``
|
||||
through WebSocket events without coupling ``execution.py`` to workflow-level
|
||||
concepts.
|
||||
|
||||
The registry is a dict on ``PromptServer`` (keyed by ``prompt_id``) registered
|
||||
at submission time (``post_prompt``), merged onto outbound payloads in
|
||||
``send_sync`` via ``merge_prompt_metadata``, and dropped in ``main.py`` *after*
|
||||
the terminal ``executing: {node: None}`` send so the final frame carries the
|
||||
same ``workflow_id`` as the rest of the prompt.
|
||||
"""
|
||||
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy_execution.metadata import (
|
||||
build_prompt_metadata,
|
||||
merge_prompt_metadata,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def registry():
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lock():
|
||||
return threading.Lock()
|
||||
|
||||
|
||||
class TestBuildPromptMetadata:
|
||||
def test_returns_workflow_id_when_present(self):
|
||||
extra_data = {"extra_pnginfo": {"workflow": {"id": "wf-1"}}}
|
||||
assert build_prompt_metadata(extra_data) == {"workflow_id": "wf-1"}
|
||||
|
||||
def test_empty_when_workflow_id_missing(self):
|
||||
assert build_prompt_metadata({}) == {}
|
||||
assert build_prompt_metadata({"extra_pnginfo": {}}) == {}
|
||||
assert build_prompt_metadata({"extra_pnginfo": {"workflow": {}}}) == {}
|
||||
|
||||
def test_empty_when_workflow_id_not_a_non_empty_string(self):
|
||||
assert build_prompt_metadata({"extra_pnginfo": {"workflow": {"id": ""}}}) == {}
|
||||
assert build_prompt_metadata({"extra_pnginfo": {"workflow": {"id": 42}}}) == {}
|
||||
assert build_prompt_metadata({"extra_pnginfo": {"workflow": {"id": None}}}) == {}
|
||||
|
||||
def test_empty_on_non_dict_input(self):
|
||||
assert build_prompt_metadata(None) == {}
|
||||
assert build_prompt_metadata("not a dict") == {}
|
||||
|
||||
|
||||
class TestMergeMetadata:
|
||||
"""``merge_prompt_metadata`` is the transparent injection point used by
|
||||
``PromptServer.send_sync``. Event payload fields win on conflict so the
|
||||
executor can never be overridden by stale registry data, and binary payloads
|
||||
(the preview tuple) pass through untouched."""
|
||||
|
||||
def test_merges_workflow_id_when_prompt_id_known(self, registry, lock):
|
||||
registry["p1"] = {"workflow_id": "wf-1"}
|
||||
merged = merge_prompt_metadata(registry, lock, {"node": "n1", "prompt_id": "p1"})
|
||||
assert merged == {"node": "n1", "prompt_id": "p1", "workflow_id": "wf-1"}
|
||||
|
||||
def test_passthrough_when_prompt_id_unknown(self, registry, lock):
|
||||
merged = merge_prompt_metadata(registry, lock, {"node": "n1", "prompt_id": "missing"})
|
||||
assert merged == {"node": "n1", "prompt_id": "missing"}
|
||||
|
||||
def test_passthrough_when_no_prompt_id(self, registry, lock):
|
||||
registry["p1"] = {"workflow_id": "wf-1"}
|
||||
merged = merge_prompt_metadata(registry, lock, {"status": {"queue_remaining": 0}})
|
||||
assert merged == {"status": {"queue_remaining": 0}}
|
||||
|
||||
def test_passthrough_for_non_dict_payload(self, registry, lock):
|
||||
registry["p1"] = {"workflow_id": "wf-1"}
|
||||
binary = (b"image-bytes", {"prompt_id": "p1"})
|
||||
assert merge_prompt_metadata(registry, lock, binary) is binary
|
||||
|
||||
def test_event_payload_wins_over_registered_metadata(self, registry, lock):
|
||||
registry["p1"] = {"workflow_id": "wf-registered"}
|
||||
merged = merge_prompt_metadata(
|
||||
registry, lock, {"prompt_id": "p1", "workflow_id": "wf-caller"}
|
||||
)
|
||||
assert merged["workflow_id"] == "wf-caller"
|
||||
|
||||
|
||||
class TestRaceRegressionForTerminalExecutingFrame:
|
||||
"""Regression for the PR #13684 finally-clear race.
|
||||
|
||||
In the reverted PR, the executor's ``finally`` cleared ``last_workflow_id``
|
||||
before ``main.py`` emitted the post-completion ``executing: {node: None}``
|
||||
frame — so that terminal frame shipped ``workflow_id=None``.
|
||||
|
||||
With the registry approach, ``main.py`` unregisters *after* the terminal
|
||||
send, so the merge sees the registered metadata. This test simulates that
|
||||
ordering to lock in the contract.
|
||||
"""
|
||||
|
||||
def test_terminal_executing_frame_includes_workflow_id(self, registry, lock):
|
||||
registry["p1"] = {"workflow_id": "wf-1"}
|
||||
|
||||
# main.py emits the terminal frame BEFORE unregistering.
|
||||
terminal = merge_prompt_metadata(registry, lock, {"node": None, "prompt_id": "p1"})
|
||||
registry.pop("p1", None) # main.py's finally: unregister_prompt_metadata
|
||||
|
||||
assert terminal == {"node": None, "prompt_id": "p1", "workflow_id": "wf-1"}
|
||||
|
||||
# After unregister, any straggler events emitted by extensions after
|
||||
# completion are no longer decorated. Verifies the registry is actually
|
||||
# released, not just shadowed.
|
||||
straggler = merge_prompt_metadata(registry, lock, {"node": None, "prompt_id": "p1"})
|
||||
assert "workflow_id" not in straggler
|
||||
Loading…
Reference in New Issue
Block a user