This commit is contained in:
Christian Byrne 2026-05-13 05:35:39 -10:00 committed by GitHub
commit 2ff0eb1d25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 398 additions and 21 deletions

View File

@ -93,6 +93,27 @@ def _create_text_preview(value: str) -> dict:
}
def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]:
    """Extract the workflow id from a prompt's ``extra_data``.

    The frontend stores the id at ``extra_data["extra_pnginfo"]["workflow"]["id"]``
    when a prompt is queued. Any value that is not a non-empty string is treated as
    missing so callers can rely on the return being either ``None`` or a string.
    """
    # Walk the nested mapping defensively; bail out as soon as any level
    # is not a dict.
    current = extra_data
    for key in ('extra_pnginfo', 'workflow'):
        if not isinstance(current, dict):
            return None
        current = current.get(key)
    if not isinstance(current, dict):
        return None
    candidate = current.get('id')
    # Only a non-empty string counts as a real workflow id.
    return candidate if isinstance(candidate, str) and candidate else None
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
"""Extract create_time and workflow_id from extra_data.
@ -100,8 +121,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str
tuple: (create_time, workflow_id)
"""
create_time = extra_data.get('create_time')
extra_pnginfo = extra_data.get('extra_pnginfo', {})
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
workflow_id = extract_workflow_id(extra_data)
return create_time, workflow_id

View File

@ -164,6 +164,8 @@ class WebUIProgressHandler(ProgressHandler):
if self.server_instance is None:
return
workflow_id = self.registry.workflow_id if self.registry else None
# Only send info for non-pending nodes
active_nodes = {
node_id: {
@ -172,6 +174,7 @@ class WebUIProgressHandler(ProgressHandler):
"state": state["state"].value,
"node_id": node_id,
"prompt_id": prompt_id,
"workflow_id": workflow_id,
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
@ -183,7 +186,7 @@ class WebUIProgressHandler(ProgressHandler):
# Send a combined progress_state message with all node states
# Include client_id to ensure message is only sent to the initiating client
self.server_instance.send_sync(
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
"progress_state", {"prompt_id": prompt_id, "workflow_id": workflow_id, "nodes": active_nodes}, self.server_instance.client_id
)
@override
@ -215,6 +218,7 @@ class WebUIProgressHandler(ProgressHandler):
metadata = {
"node_id": node_id,
"prompt_id": prompt_id,
"workflow_id": self.registry.workflow_id if self.registry else None,
"display_node_id": self.registry.dynprompt.get_display_node_id(
node_id
),
@ -240,9 +244,10 @@ class ProgressRegistry:
Registry that maintains node progress state and notifies registered handlers.
"""
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None):
self.prompt_id = prompt_id
self.dynprompt = dynprompt
self.workflow_id = workflow_id
self.nodes: Dict[str, NodeProgressState] = {}
self.handlers: Dict[str, ProgressHandler] = {}
@ -322,7 +327,7 @@ class ProgressRegistry:
# Global registry instance
global_progress_registry: ProgressRegistry | None = None
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None) -> None:
global global_progress_registry
# Reset existing handlers if registry exists
@ -330,7 +335,7 @@ def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
global_progress_registry.reset_handlers()
# Create new registry
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
global_progress_registry = ProgressRegistry(prompt_id, dynprompt, workflow_id)
def add_progress_handler(handler: ProgressHandler) -> None:

View File

@ -38,6 +38,7 @@ from comfy_execution.graph import (
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.jobs import extract_workflow_id
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io
@ -417,15 +418,15 @@ def _is_intermediate_output(dynprompt, node_id):
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs):
if server.client_id is None:
return
cached_ui = cached.ui or {}
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
if cached.ui is not None:
ui_outputs[node_id] = cached.ui
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
@ -435,7 +436,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
cached = await caches.outputs.get(unique_id)
if cached is not None:
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs)
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs)
get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, cached)
return (ExecutionResult.SUCCESS, None, None)
@ -483,7 +484,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
obj = await caches.objects.get(unique_id)
if obj is None:
@ -513,6 +514,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if block.message is not None:
mes = {
"prompt_id": prompt_id,
"workflow_id": workflow_id,
"node_id": unique_id,
"node_type": class_type,
"executed": list(executed),
@ -561,7 +563,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
"output": output_ui
}
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
if has_subgraph:
cached_outputs = []
new_node_ids = []
@ -658,6 +660,7 @@ class PromptExecutor:
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
self.status_messages = []
self.success = True
self.workflow_id = None
def add_message(self, event, data: dict, broadcast: bool):
data = {
@ -677,6 +680,7 @@ class PromptExecutor:
if isinstance(ex, comfy.model_management.InterruptProcessingException):
mes = {
"prompt_id": prompt_id,
"workflow_id": self.workflow_id,
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
@ -685,6 +689,7 @@ class PromptExecutor:
else:
mes = {
"prompt_id": prompt_id,
"workflow_id": self.workflow_id,
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
@ -723,7 +728,9 @@ class PromptExecutor:
self.server.client_id = None
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
self.workflow_id = extract_workflow_id(extra_data)
self.server.last_workflow_id = self.workflow_id
self.add_message("execution_start", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False)
self._notify_prompt_lifecycle("start", prompt_id)
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
@ -733,7 +740,7 @@ class PromptExecutor:
try:
with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
reset_progress_state(prompt_id, dynamic_prompt, self.workflow_id)
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
@ -751,7 +758,7 @@ class PromptExecutor:
comfy.model_management.cleanup_models_gc()
self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id},
{ "nodes": cached_nodes, "prompt_id": prompt_id, "workflow_id": self.workflow_id },
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
@ -769,7 +776,7 @@ class PromptExecutor:
break
assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, self.workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@ -793,8 +800,8 @@ class PromptExecutor:
cached = await self.caches.outputs.get(node_id)
if cached is not None:
display_node_id = dynamic_prompt.get_display_node_id(node_id)
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, self.workflow_id, ui_node_outputs)
self.add_message("execution_success", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False)
ui_outputs = {}
meta_outputs = {}
@ -811,6 +818,8 @@ class PromptExecutor:
finally:
comfy.memory_management.set_ram_cache_release_state(None, 0)
self._notify_prompt_lifecycle("end", prompt_id)
self.server.last_workflow_id = None
self.workflow_id = None
async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):

11
main.py
View File

@ -29,6 +29,7 @@ import logging
import sys
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_execution.jobs import extract_workflow_id
from comfy_api import feature_flags
from app.database.db import init_db, dependencies_available
@ -317,6 +318,12 @@ def prompt_worker(q, server_instance):
for k in sensitive:
extra_data[k] = sensitive[k]
# Capture the workflow id for this prompt before execution: the
# executor clears server.last_workflow_id in its finally block, so
# reading it after e.execute() returns would emit workflow_id=None
# on the terminal "executing" reset below.
workflow_id = extract_workflow_id(extra_data)
asset_seeder.pause()
e.execute(item[2], prompt_id, extra_data, item[4])
@ -330,7 +337,7 @@ def prompt_worker(q, server_instance):
completed=e.success,
messages=e.status_messages), process_item=remove_sensitive)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id, "workflow_id": workflow_id}, server_instance.client_id)
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
@ -393,7 +400,7 @@ def hijack_progress(server_instance):
prompt_id = server_instance.last_prompt_id
if node_id is None:
node_id = server_instance.last_node_id
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
progress = {"value": value, "max": total, "prompt_id": prompt_id, "workflow_id": getattr(server_instance, 'last_workflow_id', None), "node": node_id}
get_progress_state().update_progress(node_id, value, total, preview_image)
server_instance.send_sync("progress", progress, server_instance.client_id)

View File

@ -275,7 +275,11 @@ class PromptServer():
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
# On reconnect if we are the currently executing client send the current node
if self.client_id == sid and self.last_node_id is not None:
await self.send("executing", { "node": self.last_node_id }, sid)
await self.send("executing", {
"node": self.last_node_id,
"prompt_id": getattr(self, "last_prompt_id", None),
"workflow_id": getattr(self, "last_workflow_id", None),
}, sid)
# Flag to track if we've received the first message
first_message = True

View File

@ -0,0 +1,297 @@
"""Tests that workflow_id is included alongside prompt_id in WebSocket payloads
emitted by the progress handler and the prompt executor.
Frontend stores extra_data["extra_pnginfo"]["workflow"]["id"] when queueing a
prompt; we propagate that as `workflow_id` on every execution event so a
multi-tab UI can scope progress state by workflow even when terminal
WebSocket frames are dropped.
"""
from unittest.mock import MagicMock
import pytest
from comfy_execution.progress import (
NodeState,
ProgressRegistry,
WebUIProgressHandler,
reset_progress_state,
get_progress_state,
)
class _DummyDynPrompt:
def get_display_node_id(self, node_id):
return node_id
def get_parent_node_id(self, node_id):
return None
def get_real_node_id(self, node_id):
return node_id
@pytest.fixture
def server():
    """Mocked PromptServer with a connected client id for send_sync inspection."""
    mock_server = MagicMock()
    mock_server.client_id = "client-1"
    return mock_server
def _registry(workflow_id):
    """Build a ProgressRegistry for prompt-1 carrying the given workflow id."""
    dynprompt = _DummyDynPrompt()
    return ProgressRegistry("prompt-1", dynprompt, workflow_id=workflow_id)
class TestProgressStatePayload:
    """The combined progress_state message must carry workflow_id at the top
    level and on every per-node entry."""

    @staticmethod
    def _emit(server, workflow_id, value, max_value):
        # Build a registry with one Running node and push a progress_state
        # frame through the handler.
        registry = _registry(workflow_id)
        registry.nodes["n1"] = {
            "state": NodeState.Running,
            "value": value,
            "max": max_value,
        }
        handler = WebUIProgressHandler(server)
        handler.set_registry(registry)
        handler._send_progress_state("prompt-1", registry.nodes)

    def test_progress_state_includes_workflow_id(self, server):
        self._emit(server, "wf-abc", 1.0, 5.0)
        server.send_sync.assert_called_once()
        event, payload, sid = server.send_sync.call_args.args
        assert event == "progress_state"
        assert payload["prompt_id"] == "prompt-1"
        assert payload["workflow_id"] == "wf-abc"
        assert payload["nodes"]["n1"]["workflow_id"] == "wf-abc"
        assert payload["nodes"]["n1"]["prompt_id"] == "prompt-1"
        assert sid == "client-1"

    def test_progress_state_workflow_id_none_when_missing(self, server):
        self._emit(server, None, 0.5, 1.0)
        _, payload, _ = server.send_sync.call_args.args
        assert payload["workflow_id"] is None
        assert payload["nodes"]["n1"]["workflow_id"] is None
class TestProgressRegistryConstruction:
    """workflow_id is an optional constructor argument on ProgressRegistry."""

    def test_workflow_id_default_is_none(self):
        registry = ProgressRegistry("prompt-1", _DummyDynPrompt())
        assert registry.workflow_id is None

    def test_workflow_id_stored_on_registry(self):
        registry = ProgressRegistry(
            "prompt-1", _DummyDynPrompt(), workflow_id="wf-xyz"
        )
        assert registry.workflow_id == "wf-xyz"
class TestResetProgressState:
    """reset_progress_state must thread workflow_id into the global registry."""

    def test_reset_threads_workflow_id(self):
        reset_progress_state("prompt-1", _DummyDynPrompt(), "wf-456")
        assert get_progress_state().workflow_id == "wf-456"

    def test_reset_default_workflow_id_none(self):
        reset_progress_state("prompt-2", _DummyDynPrompt())
        assert get_progress_state().workflow_id is None
class TestExecutionMessagePayloadsContainWorkflowId:
"""Static-analysis guard ensuring every WebSocket message payload that
carries `prompt_id` also carries `workflow_id`. This is a regression net
for future refactors of execution.py / main.py / progress.py and avoids
the GPU/torch dependency of importing `execution.py` directly.
"""
@staticmethod
def _emitting_dicts(source: str):
"""Yield every dict literal in `source` that contains a 'prompt_id' key."""
import ast
tree = ast.parse(source)
for node in ast.walk(tree):
if not isinstance(node, ast.Dict):
continue
keys = [
k.value
for k in node.keys
if isinstance(k, ast.Constant) and isinstance(k.value, str)
]
if "prompt_id" in keys:
yield node, keys
def _assert_workflow_id_in_every_prompt_id_dict(self, file_path: str):
from pathlib import Path
repo_root = Path(__file__).resolve().parents[2]
source = (repo_root / file_path).read_text()
offenders = []
for node, keys in self._emitting_dicts(source):
if "workflow_id" not in keys:
offenders.append((node.lineno, keys))
assert not offenders, (
f"{file_path}: dict literals with 'prompt_id' but no 'workflow_id': {offenders}"
)
def test_execution_py_payloads_include_workflow_id(self):
self._assert_workflow_id_in_every_prompt_id_dict("execution.py")
def test_main_py_payloads_include_workflow_id(self):
self._assert_workflow_id_in_every_prompt_id_dict("main.py")
def test_progress_py_payloads_include_workflow_id(self):
self._assert_workflow_id_in_every_prompt_id_dict("comfy_execution/progress.py")
class TestPreviewImageMetadataPayload:
    """Verify PREVIEW_IMAGE_WITH_METADATA metadata carries workflow_id."""

    def test_preview_metadata_includes_workflow_id(self):
        # Local imports keep the optional PIL dependency and `patch` out of
        # module scope; MagicMock, NodeState, ProgressRegistry and
        # WebUIProgressHandler are already imported at module top, so the
        # previous shadowing re-imports and the duplicated _DynPrompt stub
        # are gone — reuse _DummyDynPrompt instead.
        from unittest.mock import patch

        from PIL import Image

        server = MagicMock()
        server.client_id = "cid"
        server.sockets_metadata = {}

        registry = ProgressRegistry(
            prompt_id="p1", dynprompt=_DummyDynPrompt(), workflow_id="wf-1"
        )
        handler = WebUIProgressHandler(server)
        handler.set_registry(registry)

        image = ("PNG", Image.new("RGB", (1, 1)), None)
        # Force the metadata-capable preview path regardless of the client's
        # advertised feature flags.
        with patch(
            "comfy_execution.progress.feature_flags.supports_feature",
            return_value=True,
        ):
            handler.update_handler(
                node_id="n1",
                value=1.0,
                max_value=1.0,
                state={
                    "state": NodeState.Running,
                    "value": 1.0,
                    "max": 1.0,
                },
                prompt_id="p1",
                image=image,
            )

        # update_handler also emits progress_state frames; the remaining call
        # is the binary preview whose payload is (image_bytes, metadata).
        preview_calls = [
            c
            for c in server.send_sync.call_args_list
            if c.args[0] != "progress_state"
        ]
        assert len(preview_calls) == 1
        _, payload, _ = preview_calls[0].args
        _, metadata = payload
        assert metadata["prompt_id"] == "p1"
        assert metadata["workflow_id"] == "wf-1"
class TestTerminalExecutingResetInMainPy:
    """Regression test for the main.py prompt_worker terminal 'executing' reset.

    The executor clears server.last_workflow_id in its finally block, so
    main.py must capture the workflow id *before* calling e.execute() and use
    that local value, not read server.last_workflow_id afterwards.

    Rather than importing main.py (which triggers torch CUDA init in this
    environment), we statically assert the contract via AST: somewhere
    between the `extra_data = item[3].copy()` line and the
    `e.execute(item[2], ...)` call, the function must extract workflow_id
    from extra_data into a local, and the subsequent send_sync("executing",
    ...) must reference that local rather than server.last_workflow_id.
    """

    def test_terminal_executing_uses_locally_captured_workflow_id(self):
        import ast
        from pathlib import Path

        repo_root = Path(__file__).resolve().parents[2]
        source = (repo_root / "main.py").read_text()
        tree = ast.parse(source)

        worker = next(
            (
                n
                for n in ast.walk(tree)
                if isinstance(n, ast.FunctionDef) and n.name == "prompt_worker"
            ),
            None,
        )
        assert worker is not None, "prompt_worker function not found in main.py"

        worker_src = ast.get_source_segment(source, worker) or ""
        assert "extract_workflow_id(extra_data)" in worker_src, (
            "main.py:prompt_worker must capture workflow_id locally from extra_data "
            "before calling e.execute() (the executor clears server.last_workflow_id "
            "in finally)."
        )

        matched_terminal_executing_send = False
        # Bug fix: previously the test passed silently when the payload
        # dropped the 'workflow_id' key entirely — the inner assert only ran
        # when the key existed. Track the key explicitly.
        saw_workflow_id_key = False
        for node in ast.walk(worker):
            if not isinstance(node, ast.Call):
                continue
            func = node.func
            if not (
                isinstance(func, ast.Attribute)
                and func.attr == "send_sync"
                and node.args
                and isinstance(node.args[0], ast.Constant)
                and node.args[0].value == "executing"
                and len(node.args) >= 2
                and isinstance(node.args[1], ast.Dict)
            ):
                continue
            matched_terminal_executing_send = True
            payload = node.args[1]
            for key, value in zip(payload.keys, payload.values):
                if isinstance(key, ast.Constant) and key.value == "workflow_id":
                    saw_workflow_id_key = True
                    rendered = ast.unparse(value)
                    assert "last_workflow_id" not in rendered, (
                        "main.py terminal 'executing' must not read "
                        "server.last_workflow_id; the executor clears it in its "
                        "finally block. Use a locally captured workflow_id instead."
                    )
        assert matched_terminal_executing_send, (
            "main.py:prompt_worker no longer has an inline "
            'send_sync("executing", {...}) payload; update this regression test '
            "so it still verifies the terminal workflow_id source."
        )
        assert saw_workflow_id_key, (
            "main.py:prompt_worker terminal 'executing' payload must include a "
            "'workflow_id' key."
        )

View File

@ -10,9 +10,44 @@ from comfy_execution.jobs import (
get_outputs_summary,
apply_sorting,
has_3d_extension,
extract_workflow_id,
)
class TestExtractWorkflowId:
    """Unit tests for extract_workflow_id()."""

    @staticmethod
    def _wrap(workflow):
        # Build the extra_data envelope the frontend uses.
        return {'extra_pnginfo': {'workflow': workflow}}

    def test_returns_id_from_extra_pnginfo(self):
        assert extract_workflow_id(self._wrap({'id': 'wf-123'})) == 'wf-123'

    def test_missing_extra_data_returns_none(self):
        assert extract_workflow_id(None) is None

    def test_non_dict_extra_data_returns_none(self):
        assert extract_workflow_id('not-a-dict') is None

    def test_missing_extra_pnginfo_returns_none(self):
        assert extract_workflow_id({}) is None

    def test_missing_workflow_returns_none(self):
        assert extract_workflow_id({'extra_pnginfo': {}}) is None

    def test_missing_id_returns_none(self):
        assert extract_workflow_id(self._wrap({})) is None

    def test_empty_string_id_returns_none(self):
        assert extract_workflow_id(self._wrap({'id': ''})) is None

    def test_non_string_id_returns_none(self):
        assert extract_workflow_id(self._wrap({'id': 42})) is None

    def test_non_dict_workflow_returns_none(self):
        assert extract_workflow_id(self._wrap('not-a-dict')) is None

    def test_non_dict_extra_pnginfo_returns_none(self):
        assert extract_workflow_id({'extra_pnginfo': 'not-a-dict'}) is None
class TestJobStatus:
"""Test JobStatus constants."""