mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-06 23:32:30 +08:00
Include workflow_id in all execution WebSocket messages
The frontend already stores extra_data['extra_pnginfo']['workflow']['id'] when queueing a prompt and exposes it via the /api/jobs REST endpoint, but none of the WebSocket events emitted during execution carry it. That makes it impossible to scope progress state by workflow on the client without maintaining a job_id -> workflow_id mapping that races with execution_start. This adds workflow_id alongside prompt_id on every execution event: - execution_start, execution_success, execution_error, execution_interrupted, execution_cached, executing, executed - progress and progress_state - the metadata block on PREVIEW_IMAGE_WITH_METADATA A new public extract_workflow_id helper in comfy_execution/jobs.py is the single source of truth for the lookup; the existing _extract_job_metadata delegates to it. The id is plumbed through PromptExecutor (stored as self.workflow_id and on server.last_workflow_id), the module-level execute() coroutine, the _send_cached_ui helper, and ProgressRegistry / reset_progress_state so WebUIProgressHandler can include it in progress_state and preview-image metadata. The progress hook in main.py reads server.last_workflow_id to populate the legacy 'progress' event. Tests cover the helper's edge cases (missing/non-string ids, non-dict levels) and that the WebUIProgressHandler emits workflow_id on every progress_state payload via mocked PromptServer.
This commit is contained in:
parent
e9c311b245
commit
2205341279
@ -93,6 +93,27 @@ def _create_text_preview(value: str) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]:
    """Return the workflow id recorded in a prompt's ``extra_data``, if any.

    The frontend stores the id at
    ``extra_data["extra_pnginfo"]["workflow"]["id"]`` when a prompt is queued.
    Anything other than a non-empty string at that exact path is treated as
    missing, so the result is guaranteed to be either ``None`` or a non-empty
    ``str`` — callers never need to re-validate the type.
    """
    # Walk the nested path defensively; any non-dict level means "no id".
    node = extra_data
    for key in ('extra_pnginfo', 'workflow'):
        if not isinstance(node, dict):
            return None
        node = node.get(key)
    if not isinstance(node, dict):
        return None
    candidate = node.get('id')
    # Reject non-strings and the empty string alike.
    return candidate if isinstance(candidate, str) and candidate else None
||||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
    """Extract create_time and workflow_id from extra_data.

    ``extract_workflow_id`` is the single source of truth for the workflow-id
    lookup; this helper only pairs it with the queue-time timestamp.

    Returns:
        tuple: (create_time, workflow_id)
    """
    return extra_data.get('create_time'), extract_workflow_id(extra_data)
|
|||||||
@ -164,6 +164,8 @@ class WebUIProgressHandler(ProgressHandler):
|
|||||||
if self.server_instance is None:
|
if self.server_instance is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
workflow_id = self.registry.workflow_id if self.registry else None
|
||||||
|
|
||||||
# Only send info for non-pending nodes
|
# Only send info for non-pending nodes
|
||||||
active_nodes = {
|
active_nodes = {
|
||||||
node_id: {
|
node_id: {
|
||||||
@ -172,6 +174,7 @@ class WebUIProgressHandler(ProgressHandler):
|
|||||||
"state": state["state"].value,
|
"state": state["state"].value,
|
||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
|
"workflow_id": workflow_id,
|
||||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
|
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
|
||||||
@ -183,7 +186,7 @@ class WebUIProgressHandler(ProgressHandler):
|
|||||||
# Send a combined progress_state message with all node states
|
# Send a combined progress_state message with all node states
|
||||||
# Include client_id to ensure message is only sent to the initiating client
|
# Include client_id to ensure message is only sent to the initiating client
|
||||||
self.server_instance.send_sync(
|
self.server_instance.send_sync(
|
||||||
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
|
"progress_state", {"prompt_id": prompt_id, "workflow_id": workflow_id, "nodes": active_nodes}, self.server_instance.client_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -215,6 +218,7 @@ class WebUIProgressHandler(ProgressHandler):
|
|||||||
metadata = {
|
metadata = {
|
||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
|
"workflow_id": self.registry.workflow_id if self.registry else None,
|
||||||
"display_node_id": self.registry.dynprompt.get_display_node_id(
|
"display_node_id": self.registry.dynprompt.get_display_node_id(
|
||||||
node_id
|
node_id
|
||||||
),
|
),
|
||||||
@ -240,9 +244,10 @@ class ProgressRegistry:
|
|||||||
Registry that maintains node progress state and notifies registered handlers.
|
Registry that maintains node progress state and notifies registered handlers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
|
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None):
|
||||||
self.prompt_id = prompt_id
|
self.prompt_id = prompt_id
|
||||||
self.dynprompt = dynprompt
|
self.dynprompt = dynprompt
|
||||||
|
self.workflow_id = workflow_id
|
||||||
self.nodes: Dict[str, NodeProgressState] = {}
|
self.nodes: Dict[str, NodeProgressState] = {}
|
||||||
self.handlers: Dict[str, ProgressHandler] = {}
|
self.handlers: Dict[str, ProgressHandler] = {}
|
||||||
|
|
||||||
@ -322,7 +327,7 @@ class ProgressRegistry:
|
|||||||
# Global registry instance
|
# Global registry instance
|
||||||
global_progress_registry: ProgressRegistry | None = None
|
global_progress_registry: ProgressRegistry | None = None
|
||||||
|
|
||||||
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
|
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None) -> None:
|
||||||
global global_progress_registry
|
global global_progress_registry
|
||||||
|
|
||||||
# Reset existing handlers if registry exists
|
# Reset existing handlers if registry exists
|
||||||
@ -330,7 +335,7 @@ def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
|
|||||||
global_progress_registry.reset_handlers()
|
global_progress_registry.reset_handlers()
|
||||||
|
|
||||||
# Create new registry
|
# Create new registry
|
||||||
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
|
global_progress_registry = ProgressRegistry(prompt_id, dynprompt, workflow_id)
|
||||||
|
|
||||||
|
|
||||||
def add_progress_handler(handler: ProgressHandler) -> None:
|
def add_progress_handler(handler: ProgressHandler) -> None:
|
||||||
|
|||||||
31
execution.py
31
execution.py
@ -37,6 +37,7 @@ from comfy_execution.graph import (
|
|||||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||||
from comfy_execution.validation import validate_node_input
|
from comfy_execution.validation import validate_node_input
|
||||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||||
|
from comfy_execution.jobs import extract_workflow_id
|
||||||
from comfy_execution.utils import CurrentNodeContext
|
from comfy_execution.utils import CurrentNodeContext
|
||||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||||
from comfy_api.latest import io, _io
|
from comfy_api.latest import io, _io
|
||||||
@ -416,15 +417,15 @@ def _is_intermediate_output(dynprompt, node_id):
|
|||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
|
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
|
||||||
|
|
||||||
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
|
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs):
|
||||||
if server.client_id is None:
|
if server.client_id is None:
|
||||||
return
|
return
|
||||||
cached_ui = cached.ui or {}
|
cached_ui = cached.ui or {}
|
||||||
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
|
||||||
if cached.ui is not None:
|
if cached.ui is not None:
|
||||||
ui_outputs[node_id] = cached.ui
|
ui_outputs[node_id] = cached.ui
|
||||||
|
|
||||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
|
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
|
||||||
unique_id = current_item
|
unique_id = current_item
|
||||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||||
@ -434,7 +435,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
cached = await caches.outputs.get(unique_id)
|
cached = await caches.outputs.get(unique_id)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs)
|
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs)
|
||||||
get_progress_state().finish_progress(unique_id)
|
get_progress_state().finish_progress(unique_id)
|
||||||
execution_list.cache_update(unique_id, cached)
|
execution_list.cache_update(unique_id, cached)
|
||||||
return (ExecutionResult.SUCCESS, None, None)
|
return (ExecutionResult.SUCCESS, None, None)
|
||||||
@ -482,7 +483,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.last_node_id = display_node_id
|
server.last_node_id = display_node_id
|
||||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
|
||||||
|
|
||||||
obj = await caches.objects.get(unique_id)
|
obj = await caches.objects.get(unique_id)
|
||||||
if obj is None:
|
if obj is None:
|
||||||
@ -512,6 +513,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
if block.message is not None:
|
if block.message is not None:
|
||||||
mes = {
|
mes = {
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
|
"workflow_id": workflow_id,
|
||||||
"node_id": unique_id,
|
"node_id": unique_id,
|
||||||
"node_type": class_type,
|
"node_type": class_type,
|
||||||
"executed": list(executed),
|
"executed": list(executed),
|
||||||
@ -559,7 +561,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
"output": output_ui
|
"output": output_ui
|
||||||
}
|
}
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
|
||||||
if has_subgraph:
|
if has_subgraph:
|
||||||
cached_outputs = []
|
cached_outputs = []
|
||||||
new_node_ids = []
|
new_node_ids = []
|
||||||
@ -656,6 +658,7 @@ class PromptExecutor:
|
|||||||
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
||||||
self.status_messages = []
|
self.status_messages = []
|
||||||
self.success = True
|
self.success = True
|
||||||
|
self.workflow_id = None
|
||||||
|
|
||||||
def add_message(self, event, data: dict, broadcast: bool):
|
def add_message(self, event, data: dict, broadcast: bool):
|
||||||
data = {
|
data = {
|
||||||
@ -675,6 +678,7 @@ class PromptExecutor:
|
|||||||
if isinstance(ex, comfy.model_management.InterruptProcessingException):
|
if isinstance(ex, comfy.model_management.InterruptProcessingException):
|
||||||
mes = {
|
mes = {
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
|
"workflow_id": self.workflow_id,
|
||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
"node_type": class_type,
|
"node_type": class_type,
|
||||||
"executed": list(executed),
|
"executed": list(executed),
|
||||||
@ -683,6 +687,7 @@ class PromptExecutor:
|
|||||||
else:
|
else:
|
||||||
mes = {
|
mes = {
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
|
"workflow_id": self.workflow_id,
|
||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
"node_type": class_type,
|
"node_type": class_type,
|
||||||
"executed": list(executed),
|
"executed": list(executed),
|
||||||
@ -721,7 +726,9 @@ class PromptExecutor:
|
|||||||
self.server.client_id = None
|
self.server.client_id = None
|
||||||
|
|
||||||
self.status_messages = []
|
self.status_messages = []
|
||||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
self.workflow_id = extract_workflow_id(extra_data)
|
||||||
|
self.server.last_workflow_id = self.workflow_id
|
||||||
|
self.add_message("execution_start", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False)
|
||||||
|
|
||||||
self._notify_prompt_lifecycle("start", prompt_id)
|
self._notify_prompt_lifecycle("start", prompt_id)
|
||||||
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
|
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
|
||||||
@ -731,7 +738,7 @@ class PromptExecutor:
|
|||||||
try:
|
try:
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
dynamic_prompt = DynamicPrompt(prompt)
|
dynamic_prompt = DynamicPrompt(prompt)
|
||||||
reset_progress_state(prompt_id, dynamic_prompt)
|
reset_progress_state(prompt_id, dynamic_prompt, self.workflow_id)
|
||||||
add_progress_handler(WebUIProgressHandler(self.server))
|
add_progress_handler(WebUIProgressHandler(self.server))
|
||||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||||
for cache in self.caches.all:
|
for cache in self.caches.all:
|
||||||
@ -749,7 +756,7 @@ class PromptExecutor:
|
|||||||
|
|
||||||
comfy.model_management.cleanup_models_gc()
|
comfy.model_management.cleanup_models_gc()
|
||||||
self.add_message("execution_cached",
|
self.add_message("execution_cached",
|
||||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
{ "nodes": cached_nodes, "prompt_id": prompt_id, "workflow_id": self.workflow_id },
|
||||||
broadcast=False)
|
broadcast=False)
|
||||||
pending_subgraph_results = {}
|
pending_subgraph_results = {}
|
||||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||||
@ -767,7 +774,7 @@ class PromptExecutor:
|
|||||||
break
|
break
|
||||||
|
|
||||||
assert node_id is not None, "Node ID should not be None at this point"
|
assert node_id is not None, "Node ID should not be None at this point"
|
||||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, self.workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||||
self.success = result != ExecutionResult.FAILURE
|
self.success = result != ExecutionResult.FAILURE
|
||||||
if result == ExecutionResult.FAILURE:
|
if result == ExecutionResult.FAILURE:
|
||||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||||
@ -791,8 +798,8 @@ class PromptExecutor:
|
|||||||
cached = await self.caches.outputs.get(node_id)
|
cached = await self.caches.outputs.get(node_id)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
display_node_id = dynamic_prompt.get_display_node_id(node_id)
|
display_node_id = dynamic_prompt.get_display_node_id(node_id)
|
||||||
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
|
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, self.workflow_id, ui_node_outputs)
|
||||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
self.add_message("execution_success", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False)
|
||||||
|
|
||||||
ui_outputs = {}
|
ui_outputs = {}
|
||||||
meta_outputs = {}
|
meta_outputs = {}
|
||||||
|
|||||||
4
main.py
4
main.py
@ -322,7 +322,7 @@ def prompt_worker(q, server_instance):
|
|||||||
completed=e.success,
|
completed=e.success,
|
||||||
messages=e.status_messages), process_item=remove_sensitive)
|
messages=e.status_messages), process_item=remove_sensitive)
|
||||||
if server_instance.client_id is not None:
|
if server_instance.client_id is not None:
|
||||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id, "workflow_id": getattr(server_instance, 'last_workflow_id', None)}, server_instance.client_id)
|
||||||
|
|
||||||
current_time = time.perf_counter()
|
current_time = time.perf_counter()
|
||||||
execution_time = current_time - execution_start_time
|
execution_time = current_time - execution_start_time
|
||||||
@ -385,7 +385,7 @@ def hijack_progress(server_instance):
|
|||||||
prompt_id = server_instance.last_prompt_id
|
prompt_id = server_instance.last_prompt_id
|
||||||
if node_id is None:
|
if node_id is None:
|
||||||
node_id = server_instance.last_node_id
|
node_id = server_instance.last_node_id
|
||||||
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
|
progress = {"value": value, "max": total, "prompt_id": prompt_id, "workflow_id": getattr(server_instance, 'last_workflow_id', None), "node": node_id}
|
||||||
get_progress_state().update_progress(node_id, value, total, preview_image)
|
get_progress_state().update_progress(node_id, value, total, preview_image)
|
||||||
|
|
||||||
server_instance.send_sync("progress", progress, server_instance.client_id)
|
server_instance.send_sync("progress", progress, server_instance.client_id)
|
||||||
|
|||||||
111
tests-unit/execution_test/test_workflow_id_in_ws_messages.py
Normal file
111
tests-unit/execution_test/test_workflow_id_in_ws_messages.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
"""Tests that workflow_id is included alongside prompt_id in WebSocket payloads
|
||||||
|
emitted by the progress handler and the prompt executor.
|
||||||
|
|
||||||
|
Frontend stores extra_data["extra_pnginfo"]["workflow"]["id"] when queueing a
|
||||||
|
prompt; we propagate that as `workflow_id` on every execution event so a
|
||||||
|
multi-tab UI can scope progress state by workflow even when terminal
|
||||||
|
WebSocket frames are dropped.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from comfy_execution.progress import (
|
||||||
|
NodeState,
|
||||||
|
ProgressRegistry,
|
||||||
|
WebUIProgressHandler,
|
||||||
|
reset_progress_state,
|
||||||
|
get_progress_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyDynPrompt:
|
||||||
|
def get_display_node_id(self, node_id):
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
def get_parent_node_id(self, node_id):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_real_node_id(self, node_id):
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def server():
    """Mocked PromptServer: records send_sync calls, with a connected client id."""
    s = MagicMock()
    # progress handlers bail out when client_id is None, so give it one
    s.client_id = "client-1"
    return s
|
||||||
|
|
||||||
|
def _registry(workflow_id):
    # Shared factory: every test uses the same prompt id and dummy dynprompt,
    # varying only the workflow_id under test.
    return ProgressRegistry(
        prompt_id="prompt-1",
        dynprompt=_DummyDynPrompt(),
        workflow_id=workflow_id,
    )
|
||||||
|
|
||||||
|
class TestProgressStatePayload:
    """progress_state payloads carry workflow_id at top level and per node."""

    def test_progress_state_includes_workflow_id(self, server):
        registry = _registry("wf-abc")
        # One non-pending node is enough; the handler only reports active nodes.
        registry.nodes["n1"] = {
            "state": NodeState.Running,
            "value": 1.0,
            "max": 5.0,
        }

        handler = WebUIProgressHandler(server)
        handler.set_registry(registry)
        handler._send_progress_state("prompt-1", registry.nodes)

        server.send_sync.assert_called_once()
        event, payload, sid = server.send_sync.call_args.args
        assert event == "progress_state"
        assert payload["prompt_id"] == "prompt-1"
        assert payload["workflow_id"] == "wf-abc"
        # The id is duplicated onto each node entry as well.
        assert payload["nodes"]["n1"]["workflow_id"] == "wf-abc"
        assert payload["nodes"]["n1"]["prompt_id"] == "prompt-1"
        # Message must be scoped to the initiating client.
        assert sid == "client-1"

    def test_progress_state_workflow_id_none_when_missing(self, server):
        registry = _registry(None)
        registry.nodes["n1"] = {
            "state": NodeState.Running,
            "value": 0.5,
            "max": 1.0,
        }

        handler = WebUIProgressHandler(server)
        handler.set_registry(registry)
        handler._send_progress_state("prompt-1", registry.nodes)

        # The key is always present, even when no workflow id was queued.
        _, payload, _ = server.send_sync.call_args.args
        assert payload["workflow_id"] is None
        assert payload["nodes"]["n1"]["workflow_id"] is None
||||||
|
|
||||||
|
class TestProgressRegistryConstruction:
    """workflow_id is an optional, backward-compatible constructor argument."""

    def test_workflow_id_default_is_none(self):
        # Existing two-argument call sites keep working unchanged.
        registry = ProgressRegistry(
            prompt_id="prompt-1", dynprompt=_DummyDynPrompt()
        )
        assert registry.workflow_id is None

    def test_workflow_id_stored_on_registry(self):
        registry = ProgressRegistry(
            prompt_id="prompt-1",
            dynprompt=_DummyDynPrompt(),
            workflow_id="wf-xyz",
        )
        assert registry.workflow_id == "wf-xyz"
||||||
|
|
||||||
|
|
||||||
|
class TestResetProgressState:
    """reset_progress_state threads workflow_id into the global registry."""

    def test_reset_threads_workflow_id(self):
        reset_progress_state("prompt-1", _DummyDynPrompt(), "wf-456")
        assert get_progress_state().workflow_id == "wf-456"

    def test_reset_default_workflow_id_none(self):
        # Omitting the argument must behave like the pre-change signature.
        reset_progress_state("prompt-2", _DummyDynPrompt())
        assert get_progress_state().workflow_id is None
|
||||||
@ -10,9 +10,44 @@ from comfy_execution.jobs import (
|
|||||||
get_outputs_summary,
|
get_outputs_summary,
|
||||||
apply_sorting,
|
apply_sorting,
|
||||||
has_3d_extension,
|
has_3d_extension,
|
||||||
|
extract_workflow_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractWorkflowId:
    """Unit tests for extract_workflow_id().

    The helper must return either None or a non-empty string, never raise,
    for any shape of extra_data.
    """

    def test_returns_id_from_extra_pnginfo(self):
        # Happy path: the id lives at extra_pnginfo.workflow.id.
        assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 'wf-123'}}}) == 'wf-123'

    def test_missing_extra_data_returns_none(self):
        assert extract_workflow_id(None) is None

    def test_non_dict_extra_data_returns_none(self):
        assert extract_workflow_id('not-a-dict') is None

    def test_missing_extra_pnginfo_returns_none(self):
        assert extract_workflow_id({}) is None

    def test_missing_workflow_returns_none(self):
        assert extract_workflow_id({'extra_pnginfo': {}}) is None

    def test_missing_id_returns_none(self):
        assert extract_workflow_id({'extra_pnginfo': {'workflow': {}}}) is None

    def test_empty_string_id_returns_none(self):
        # Empty strings are treated as "no id", not returned verbatim.
        assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': ''}}}) is None

    def test_non_string_id_returns_none(self):
        assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 42}}}) is None

    def test_non_dict_workflow_returns_none(self):
        assert extract_workflow_id({'extra_pnginfo': {'workflow': 'not-a-dict'}}) is None

    def test_non_dict_extra_pnginfo_returns_none(self):
        assert extract_workflow_id({'extra_pnginfo': 'not-a-dict'}) is None
|
|
||||||
class TestJobStatus:
|
class TestJobStatus:
|
||||||
"""Test JobStatus constants."""
|
"""Test JobStatus constants."""
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user