mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 19:17:32 +08:00
Merge 137de03c7b into 8505abf52e
This commit is contained in:
commit
2ff0eb1d25
@ -93,6 +93,27 @@ def _create_text_preview(value: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]:
|
||||
"""Extract the workflow id from a prompt's ``extra_data``.
|
||||
|
||||
The frontend stores the id at ``extra_data["extra_pnginfo"]["workflow"]["id"]``
|
||||
when a prompt is queued. Any value that is not a non-empty string is treated as
|
||||
missing so callers can rely on the return being either ``None`` or a string.
|
||||
"""
|
||||
if not isinstance(extra_data, dict):
|
||||
return None
|
||||
extra_pnginfo = extra_data.get('extra_pnginfo')
|
||||
if not isinstance(extra_pnginfo, dict):
|
||||
return None
|
||||
workflow = extra_pnginfo.get('workflow')
|
||||
if not isinstance(workflow, dict):
|
||||
return None
|
||||
workflow_id = workflow.get('id')
|
||||
if isinstance(workflow_id, str) and workflow_id:
|
||||
return workflow_id
|
||||
return None
|
||||
|
||||
|
||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||
"""Extract create_time and workflow_id from extra_data.
|
||||
|
||||
@ -100,8 +121,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str
|
||||
tuple: (create_time, workflow_id)
|
||||
"""
|
||||
create_time = extra_data.get('create_time')
|
||||
extra_pnginfo = extra_data.get('extra_pnginfo', {})
|
||||
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
|
||||
workflow_id = extract_workflow_id(extra_data)
|
||||
return create_time, workflow_id
|
||||
|
||||
|
||||
|
||||
@ -164,6 +164,8 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
if self.server_instance is None:
|
||||
return
|
||||
|
||||
workflow_id = self.registry.workflow_id if self.registry else None
|
||||
|
||||
# Only send info for non-pending nodes
|
||||
active_nodes = {
|
||||
node_id: {
|
||||
@ -172,6 +174,7 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
"state": state["state"].value,
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"workflow_id": workflow_id,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
|
||||
@ -183,7 +186,7 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
# Send a combined progress_state message with all node states
|
||||
# Include client_id to ensure message is only sent to the initiating client
|
||||
self.server_instance.send_sync(
|
||||
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
|
||||
"progress_state", {"prompt_id": prompt_id, "workflow_id": workflow_id, "nodes": active_nodes}, self.server_instance.client_id
|
||||
)
|
||||
|
||||
@override
|
||||
@ -215,6 +218,7 @@ class WebUIProgressHandler(ProgressHandler):
|
||||
metadata = {
|
||||
"node_id": node_id,
|
||||
"prompt_id": prompt_id,
|
||||
"workflow_id": self.registry.workflow_id if self.registry else None,
|
||||
"display_node_id": self.registry.dynprompt.get_display_node_id(
|
||||
node_id
|
||||
),
|
||||
@ -240,9 +244,10 @@ class ProgressRegistry:
|
||||
Registry that maintains node progress state and notifies registered handlers.
|
||||
"""
|
||||
|
||||
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
|
||||
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None):
|
||||
self.prompt_id = prompt_id
|
||||
self.dynprompt = dynprompt
|
||||
self.workflow_id = workflow_id
|
||||
self.nodes: Dict[str, NodeProgressState] = {}
|
||||
self.handlers: Dict[str, ProgressHandler] = {}
|
||||
|
||||
@ -322,7 +327,7 @@ class ProgressRegistry:
|
||||
# Global registry instance
|
||||
global_progress_registry: ProgressRegistry | None = None
|
||||
|
||||
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
|
||||
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None) -> None:
|
||||
global global_progress_registry
|
||||
|
||||
# Reset existing handlers if registry exists
|
||||
@ -330,7 +335,7 @@ def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
|
||||
global_progress_registry.reset_handlers()
|
||||
|
||||
# Create new registry
|
||||
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
|
||||
global_progress_registry = ProgressRegistry(prompt_id, dynprompt, workflow_id)
|
||||
|
||||
|
||||
def add_progress_handler(handler: ProgressHandler) -> None:
|
||||
|
||||
33
execution.py
33
execution.py
@ -38,6 +38,7 @@ from comfy_execution.graph import (
|
||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||
from comfy_execution.jobs import extract_workflow_id
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.latest import io, _io
|
||||
@ -417,15 +418,15 @@ def _is_intermediate_output(dynprompt, node_id):
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
|
||||
|
||||
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
|
||||
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs):
|
||||
if server.client_id is None:
|
||||
return
|
||||
cached_ui = cached.ui or {}
|
||||
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
|
||||
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
|
||||
if cached.ui is not None:
|
||||
ui_outputs[node_id] = cached.ui
|
||||
|
||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
|
||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
|
||||
unique_id = current_item
|
||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||
@ -435,7 +436,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
cached = await caches.outputs.get(unique_id)
|
||||
if cached is not None:
|
||||
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs)
|
||||
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs)
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
execution_list.cache_update(unique_id, cached)
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
@ -483,7 +484,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
|
||||
|
||||
obj = await caches.objects.get(unique_id)
|
||||
if obj is None:
|
||||
@ -513,6 +514,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
if block.message is not None:
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"workflow_id": workflow_id,
|
||||
"node_id": unique_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
@ -561,7 +563,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
"output": output_ui
|
||||
}
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
|
||||
if has_subgraph:
|
||||
cached_outputs = []
|
||||
new_node_ids = []
|
||||
@ -658,6 +660,7 @@ class PromptExecutor:
|
||||
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
||||
self.status_messages = []
|
||||
self.success = True
|
||||
self.workflow_id = None
|
||||
|
||||
def add_message(self, event, data: dict, broadcast: bool):
|
||||
data = {
|
||||
@ -677,6 +680,7 @@ class PromptExecutor:
|
||||
if isinstance(ex, comfy.model_management.InterruptProcessingException):
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"node_id": node_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
@ -685,6 +689,7 @@ class PromptExecutor:
|
||||
else:
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"node_id": node_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
@ -723,7 +728,9 @@ class PromptExecutor:
|
||||
self.server.client_id = None
|
||||
|
||||
self.status_messages = []
|
||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||
self.workflow_id = extract_workflow_id(extra_data)
|
||||
self.server.last_workflow_id = self.workflow_id
|
||||
self.add_message("execution_start", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False)
|
||||
|
||||
self._notify_prompt_lifecycle("start", prompt_id)
|
||||
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
|
||||
@ -733,7 +740,7 @@ class PromptExecutor:
|
||||
try:
|
||||
with torch.inference_mode():
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
reset_progress_state(prompt_id, dynamic_prompt, self.workflow_id)
|
||||
add_progress_handler(WebUIProgressHandler(self.server))
|
||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
@ -751,7 +758,7 @@ class PromptExecutor:
|
||||
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
self.add_message("execution_cached",
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id, "workflow_id": self.workflow_id },
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
@ -769,7 +776,7 @@ class PromptExecutor:
|
||||
break
|
||||
|
||||
assert node_id is not None, "Node ID should not be None at this point"
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, self.workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
@ -793,8 +800,8 @@ class PromptExecutor:
|
||||
cached = await self.caches.outputs.get(node_id)
|
||||
if cached is not None:
|
||||
display_node_id = dynamic_prompt.get_display_node_id(node_id)
|
||||
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, self.workflow_id, ui_node_outputs)
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False)
|
||||
|
||||
ui_outputs = {}
|
||||
meta_outputs = {}
|
||||
@ -811,6 +818,8 @@ class PromptExecutor:
|
||||
finally:
|
||||
comfy.memory_management.set_ram_cache_release_state(None, 0)
|
||||
self._notify_prompt_lifecycle("end", prompt_id)
|
||||
self.server.last_workflow_id = None
|
||||
self.workflow_id = None
|
||||
|
||||
|
||||
async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
||||
|
||||
11
main.py
11
main.py
@ -29,6 +29,7 @@ import logging
|
||||
import sys
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_execution.jobs import extract_workflow_id
|
||||
from comfy_api import feature_flags
|
||||
from app.database.db import init_db, dependencies_available
|
||||
|
||||
@ -317,6 +318,12 @@ def prompt_worker(q, server_instance):
|
||||
for k in sensitive:
|
||||
extra_data[k] = sensitive[k]
|
||||
|
||||
# Capture the workflow id for this prompt before execution: the
|
||||
# executor clears server.last_workflow_id in its finally block, so
|
||||
# reading it after e.execute() returns would emit workflow_id=None
|
||||
# on the terminal "executing" reset below.
|
||||
workflow_id = extract_workflow_id(extra_data)
|
||||
|
||||
asset_seeder.pause()
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
|
||||
@ -330,7 +337,7 @@ def prompt_worker(q, server_instance):
|
||||
completed=e.success,
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id, "workflow_id": workflow_id}, server_instance.client_id)
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
@ -393,7 +400,7 @@ def hijack_progress(server_instance):
|
||||
prompt_id = server_instance.last_prompt_id
|
||||
if node_id is None:
|
||||
node_id = server_instance.last_node_id
|
||||
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
|
||||
progress = {"value": value, "max": total, "prompt_id": prompt_id, "workflow_id": getattr(server_instance, 'last_workflow_id', None), "node": node_id}
|
||||
get_progress_state().update_progress(node_id, value, total, preview_image)
|
||||
|
||||
server_instance.send_sync("progress", progress, server_instance.client_id)
|
||||
|
||||
@ -275,7 +275,11 @@ class PromptServer():
|
||||
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
||||
# On reconnect if we are the currently executing client send the current node
|
||||
if self.client_id == sid and self.last_node_id is not None:
|
||||
await self.send("executing", { "node": self.last_node_id }, sid)
|
||||
await self.send("executing", {
|
||||
"node": self.last_node_id,
|
||||
"prompt_id": getattr(self, "last_prompt_id", None),
|
||||
"workflow_id": getattr(self, "last_workflow_id", None),
|
||||
}, sid)
|
||||
|
||||
# Flag to track if we've received the first message
|
||||
first_message = True
|
||||
|
||||
297
tests-unit/execution_test/test_workflow_id_in_ws_messages.py
Normal file
297
tests-unit/execution_test/test_workflow_id_in_ws_messages.py
Normal file
@ -0,0 +1,297 @@
|
||||
"""Tests that workflow_id is included alongside prompt_id in WebSocket payloads
|
||||
emitted by the progress handler and the prompt executor.
|
||||
|
||||
Frontend stores extra_data["extra_pnginfo"]["workflow"]["id"] when queueing a
|
||||
prompt; we propagate that as `workflow_id` on every execution event so a
|
||||
multi-tab UI can scope progress state by workflow even when terminal
|
||||
WebSocket frames are dropped.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy_execution.progress import (
|
||||
NodeState,
|
||||
ProgressRegistry,
|
||||
WebUIProgressHandler,
|
||||
reset_progress_state,
|
||||
get_progress_state,
|
||||
)
|
||||
|
||||
|
||||
class _DummyDynPrompt:
|
||||
def get_display_node_id(self, node_id):
|
||||
return node_id
|
||||
|
||||
def get_parent_node_id(self, node_id):
|
||||
return None
|
||||
|
||||
def get_real_node_id(self, node_id):
|
||||
return node_id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server():
|
||||
s = MagicMock()
|
||||
s.client_id = "client-1"
|
||||
return s
|
||||
|
||||
|
||||
def _registry(workflow_id):
|
||||
return ProgressRegistry(
|
||||
prompt_id="prompt-1",
|
||||
dynprompt=_DummyDynPrompt(),
|
||||
workflow_id=workflow_id,
|
||||
)
|
||||
|
||||
|
||||
class TestProgressStatePayload:
|
||||
def test_progress_state_includes_workflow_id(self, server):
|
||||
registry = _registry("wf-abc")
|
||||
registry.nodes["n1"] = {
|
||||
"state": NodeState.Running,
|
||||
"value": 1.0,
|
||||
"max": 5.0,
|
||||
}
|
||||
|
||||
handler = WebUIProgressHandler(server)
|
||||
handler.set_registry(registry)
|
||||
handler._send_progress_state("prompt-1", registry.nodes)
|
||||
|
||||
server.send_sync.assert_called_once()
|
||||
event, payload, sid = server.send_sync.call_args.args
|
||||
assert event == "progress_state"
|
||||
assert payload["prompt_id"] == "prompt-1"
|
||||
assert payload["workflow_id"] == "wf-abc"
|
||||
assert payload["nodes"]["n1"]["workflow_id"] == "wf-abc"
|
||||
assert payload["nodes"]["n1"]["prompt_id"] == "prompt-1"
|
||||
assert sid == "client-1"
|
||||
|
||||
def test_progress_state_workflow_id_none_when_missing(self, server):
|
||||
registry = _registry(None)
|
||||
registry.nodes["n1"] = {
|
||||
"state": NodeState.Running,
|
||||
"value": 0.5,
|
||||
"max": 1.0,
|
||||
}
|
||||
|
||||
handler = WebUIProgressHandler(server)
|
||||
handler.set_registry(registry)
|
||||
handler._send_progress_state("prompt-1", registry.nodes)
|
||||
|
||||
_, payload, _ = server.send_sync.call_args.args
|
||||
assert payload["workflow_id"] is None
|
||||
assert payload["nodes"]["n1"]["workflow_id"] is None
|
||||
|
||||
|
||||
class TestProgressRegistryConstruction:
|
||||
def test_workflow_id_default_is_none(self):
|
||||
registry = ProgressRegistry(
|
||||
prompt_id="prompt-1", dynprompt=_DummyDynPrompt()
|
||||
)
|
||||
assert registry.workflow_id is None
|
||||
|
||||
def test_workflow_id_stored_on_registry(self):
|
||||
registry = ProgressRegistry(
|
||||
prompt_id="prompt-1",
|
||||
dynprompt=_DummyDynPrompt(),
|
||||
workflow_id="wf-xyz",
|
||||
)
|
||||
assert registry.workflow_id == "wf-xyz"
|
||||
|
||||
|
||||
class TestResetProgressState:
|
||||
def test_reset_threads_workflow_id(self):
|
||||
reset_progress_state("prompt-1", _DummyDynPrompt(), "wf-456")
|
||||
assert get_progress_state().workflow_id == "wf-456"
|
||||
|
||||
def test_reset_default_workflow_id_none(self):
|
||||
reset_progress_state("prompt-2", _DummyDynPrompt())
|
||||
assert get_progress_state().workflow_id is None
|
||||
|
||||
|
||||
class TestExecutionMessagePayloadsContainWorkflowId:
|
||||
"""Static-analysis guard ensuring every WebSocket message payload that
|
||||
carries `prompt_id` also carries `workflow_id`. This is a regression net
|
||||
for future refactors of execution.py / main.py / progress.py and avoids
|
||||
the GPU/torch dependency of importing `execution.py` directly.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _emitting_dicts(source: str):
|
||||
"""Yield every dict literal in `source` that contains a 'prompt_id' key."""
|
||||
import ast
|
||||
|
||||
tree = ast.parse(source)
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.Dict):
|
||||
continue
|
||||
keys = [
|
||||
k.value
|
||||
for k in node.keys
|
||||
if isinstance(k, ast.Constant) and isinstance(k.value, str)
|
||||
]
|
||||
if "prompt_id" in keys:
|
||||
yield node, keys
|
||||
|
||||
def _assert_workflow_id_in_every_prompt_id_dict(self, file_path: str):
|
||||
from pathlib import Path
|
||||
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
source = (repo_root / file_path).read_text()
|
||||
offenders = []
|
||||
for node, keys in self._emitting_dicts(source):
|
||||
if "workflow_id" not in keys:
|
||||
offenders.append((node.lineno, keys))
|
||||
assert not offenders, (
|
||||
f"{file_path}: dict literals with 'prompt_id' but no 'workflow_id': {offenders}"
|
||||
)
|
||||
|
||||
def test_execution_py_payloads_include_workflow_id(self):
|
||||
self._assert_workflow_id_in_every_prompt_id_dict("execution.py")
|
||||
|
||||
def test_main_py_payloads_include_workflow_id(self):
|
||||
self._assert_workflow_id_in_every_prompt_id_dict("main.py")
|
||||
|
||||
def test_progress_py_payloads_include_workflow_id(self):
|
||||
self._assert_workflow_id_in_every_prompt_id_dict("comfy_execution/progress.py")
|
||||
|
||||
|
||||
class TestPreviewImageMetadataPayload:
|
||||
"""Verify PREVIEW_IMAGE_WITH_METADATA metadata carries workflow_id."""
|
||||
|
||||
def test_preview_metadata_includes_workflow_id(self):
|
||||
from unittest.mock import MagicMock, patch
|
||||
from PIL import Image
|
||||
|
||||
from comfy_execution.progress import (
|
||||
NodeState,
|
||||
ProgressRegistry,
|
||||
WebUIProgressHandler,
|
||||
)
|
||||
|
||||
class _DynPrompt:
|
||||
def get_display_node_id(self, n):
|
||||
return n
|
||||
|
||||
def get_parent_node_id(self, n):
|
||||
return None
|
||||
|
||||
def get_real_node_id(self, n):
|
||||
return n
|
||||
|
||||
server = MagicMock()
|
||||
server.client_id = "cid"
|
||||
server.sockets_metadata = {}
|
||||
|
||||
registry = ProgressRegistry(
|
||||
prompt_id="p1", dynprompt=_DynPrompt(), workflow_id="wf-1"
|
||||
)
|
||||
handler = WebUIProgressHandler(server)
|
||||
handler.set_registry(registry)
|
||||
|
||||
image = ("PNG", Image.new("RGB", (1, 1)), None)
|
||||
|
||||
with patch(
|
||||
"comfy_execution.progress.feature_flags.supports_feature",
|
||||
return_value=True,
|
||||
):
|
||||
handler.update_handler(
|
||||
node_id="n1",
|
||||
value=1.0,
|
||||
max_value=1.0,
|
||||
state={
|
||||
"state": NodeState.Running,
|
||||
"value": 1.0,
|
||||
"max": 1.0,
|
||||
},
|
||||
prompt_id="p1",
|
||||
image=image,
|
||||
)
|
||||
|
||||
preview_calls = [
|
||||
c
|
||||
for c in server.send_sync.call_args_list
|
||||
if c.args[0] != "progress_state"
|
||||
]
|
||||
assert len(preview_calls) == 1
|
||||
_, payload, _ = preview_calls[0].args
|
||||
_, metadata = payload
|
||||
assert metadata["prompt_id"] == "p1"
|
||||
assert metadata["workflow_id"] == "wf-1"
|
||||
|
||||
|
||||
|
||||
class TestTerminalExecutingResetInMainPy:
|
||||
"""Regression test for the main.py prompt_worker terminal 'executing' reset.
|
||||
|
||||
The executor clears server.last_workflow_id in its finally block, so
|
||||
main.py must capture the workflow id *before* calling e.execute() and use
|
||||
that local value, not read server.last_workflow_id afterwards.
|
||||
|
||||
Rather than importing main.py (which triggers torch CUDA init in this
|
||||
environment), we statically assert the contract via AST: somewhere
|
||||
between the `extra_data = item[3].copy()` line and the
|
||||
`e.execute(item[2], ...)` call, the function must extract workflow_id
|
||||
from extra_data into a local, and the subsequent send_sync("executing",
|
||||
...) must reference that local rather than server.last_workflow_id.
|
||||
"""
|
||||
|
||||
def test_terminal_executing_uses_locally_captured_workflow_id(self):
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
source = (repo_root / "main.py").read_text()
|
||||
tree = ast.parse(source)
|
||||
|
||||
worker = next(
|
||||
(
|
||||
n
|
||||
for n in ast.walk(tree)
|
||||
if isinstance(n, ast.FunctionDef) and n.name == "prompt_worker"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert worker is not None, "prompt_worker function not found in main.py"
|
||||
|
||||
worker_src = ast.get_source_segment(source, worker) or ""
|
||||
|
||||
assert "extract_workflow_id(extra_data)" in worker_src, (
|
||||
"main.py:prompt_worker must capture workflow_id locally from extra_data "
|
||||
"before calling e.execute() (the executor clears server.last_workflow_id "
|
||||
"in finally)."
|
||||
)
|
||||
|
||||
matched_terminal_executing_send = False
|
||||
for node in ast.walk(worker):
|
||||
if not isinstance(node, ast.Call):
|
||||
continue
|
||||
func = node.func
|
||||
if not (
|
||||
isinstance(func, ast.Attribute)
|
||||
and func.attr == "send_sync"
|
||||
and node.args
|
||||
and isinstance(node.args[0], ast.Constant)
|
||||
and node.args[0].value == "executing"
|
||||
and len(node.args) >= 2
|
||||
and isinstance(node.args[1], ast.Dict)
|
||||
):
|
||||
continue
|
||||
matched_terminal_executing_send = True
|
||||
payload = node.args[1]
|
||||
for key, value in zip(payload.keys, payload.values):
|
||||
if isinstance(key, ast.Constant) and key.value == "workflow_id":
|
||||
rendered = ast.unparse(value)
|
||||
assert "last_workflow_id" not in rendered, (
|
||||
"main.py terminal 'executing' must not read "
|
||||
"server.last_workflow_id; the executor clears it in its "
|
||||
"finally block. Use a locally captured workflow_id instead."
|
||||
)
|
||||
|
||||
assert matched_terminal_executing_send, (
|
||||
"main.py:prompt_worker no longer has an inline "
|
||||
'send_sync("executing", {...}) payload; update this regression test '
|
||||
"so it still verifies the terminal workflow_id source."
|
||||
)
|
||||
@ -10,9 +10,44 @@ from comfy_execution.jobs import (
|
||||
get_outputs_summary,
|
||||
apply_sorting,
|
||||
has_3d_extension,
|
||||
extract_workflow_id,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractWorkflowId:
|
||||
"""Unit tests for extract_workflow_id()."""
|
||||
|
||||
def test_returns_id_from_extra_pnginfo(self):
|
||||
assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 'wf-123'}}}) == 'wf-123'
|
||||
|
||||
def test_missing_extra_data_returns_none(self):
|
||||
assert extract_workflow_id(None) is None
|
||||
|
||||
def test_non_dict_extra_data_returns_none(self):
|
||||
assert extract_workflow_id('not-a-dict') is None
|
||||
|
||||
def test_missing_extra_pnginfo_returns_none(self):
|
||||
assert extract_workflow_id({}) is None
|
||||
|
||||
def test_missing_workflow_returns_none(self):
|
||||
assert extract_workflow_id({'extra_pnginfo': {}}) is None
|
||||
|
||||
def test_missing_id_returns_none(self):
|
||||
assert extract_workflow_id({'extra_pnginfo': {'workflow': {}}}) is None
|
||||
|
||||
def test_empty_string_id_returns_none(self):
|
||||
assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': ''}}}) is None
|
||||
|
||||
def test_non_string_id_returns_none(self):
|
||||
assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 42}}}) is None
|
||||
|
||||
def test_non_dict_workflow_returns_none(self):
|
||||
assert extract_workflow_id({'extra_pnginfo': {'workflow': 'not-a-dict'}}) is None
|
||||
|
||||
def test_non_dict_extra_pnginfo_returns_none(self):
|
||||
assert extract_workflow_id({'extra_pnginfo': 'not-a-dict'}) is None
|
||||
|
||||
|
||||
class TestJobStatus:
|
||||
"""Test JobStatus constants."""
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user