mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-15 11:37:32 +08:00
Merge 137de03c7b into 8505abf52e
This commit is contained in:
commit
2ff0eb1d25
@ -93,6 +93,27 @@ def _create_text_preview(value: str) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def extract_workflow_id(extra_data: Optional[dict]) -> Optional[str]:
|
||||||
|
"""Extract the workflow id from a prompt's ``extra_data``.
|
||||||
|
|
||||||
|
The frontend stores the id at ``extra_data["extra_pnginfo"]["workflow"]["id"]``
|
||||||
|
when a prompt is queued. Any value that is not a non-empty string is treated as
|
||||||
|
missing so callers can rely on the return being either ``None`` or a string.
|
||||||
|
"""
|
||||||
|
if not isinstance(extra_data, dict):
|
||||||
|
return None
|
||||||
|
extra_pnginfo = extra_data.get('extra_pnginfo')
|
||||||
|
if not isinstance(extra_pnginfo, dict):
|
||||||
|
return None
|
||||||
|
workflow = extra_pnginfo.get('workflow')
|
||||||
|
if not isinstance(workflow, dict):
|
||||||
|
return None
|
||||||
|
workflow_id = workflow.get('id')
|
||||||
|
if isinstance(workflow_id, str) and workflow_id:
|
||||||
|
return workflow_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||||
"""Extract create_time and workflow_id from extra_data.
|
"""Extract create_time and workflow_id from extra_data.
|
||||||
|
|
||||||
@ -100,8 +121,7 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str
|
|||||||
tuple: (create_time, workflow_id)
|
tuple: (create_time, workflow_id)
|
||||||
"""
|
"""
|
||||||
create_time = extra_data.get('create_time')
|
create_time = extra_data.get('create_time')
|
||||||
extra_pnginfo = extra_data.get('extra_pnginfo', {})
|
workflow_id = extract_workflow_id(extra_data)
|
||||||
workflow_id = extra_pnginfo.get('workflow', {}).get('id')
|
|
||||||
return create_time, workflow_id
|
return create_time, workflow_id
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -164,6 +164,8 @@ class WebUIProgressHandler(ProgressHandler):
|
|||||||
if self.server_instance is None:
|
if self.server_instance is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
workflow_id = self.registry.workflow_id if self.registry else None
|
||||||
|
|
||||||
# Only send info for non-pending nodes
|
# Only send info for non-pending nodes
|
||||||
active_nodes = {
|
active_nodes = {
|
||||||
node_id: {
|
node_id: {
|
||||||
@ -172,6 +174,7 @@ class WebUIProgressHandler(ProgressHandler):
|
|||||||
"state": state["state"].value,
|
"state": state["state"].value,
|
||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
|
"workflow_id": workflow_id,
|
||||||
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
"display_node_id": self.registry.dynprompt.get_display_node_id(node_id),
|
||||||
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
"parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id),
|
||||||
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
|
"real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
|
||||||
@ -183,7 +186,7 @@ class WebUIProgressHandler(ProgressHandler):
|
|||||||
# Send a combined progress_state message with all node states
|
# Send a combined progress_state message with all node states
|
||||||
# Include client_id to ensure message is only sent to the initiating client
|
# Include client_id to ensure message is only sent to the initiating client
|
||||||
self.server_instance.send_sync(
|
self.server_instance.send_sync(
|
||||||
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
|
"progress_state", {"prompt_id": prompt_id, "workflow_id": workflow_id, "nodes": active_nodes}, self.server_instance.client_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -215,6 +218,7 @@ class WebUIProgressHandler(ProgressHandler):
|
|||||||
metadata = {
|
metadata = {
|
||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
|
"workflow_id": self.registry.workflow_id if self.registry else None,
|
||||||
"display_node_id": self.registry.dynprompt.get_display_node_id(
|
"display_node_id": self.registry.dynprompt.get_display_node_id(
|
||||||
node_id
|
node_id
|
||||||
),
|
),
|
||||||
@ -240,9 +244,10 @@ class ProgressRegistry:
|
|||||||
Registry that maintains node progress state and notifies registered handlers.
|
Registry that maintains node progress state and notifies registered handlers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
|
def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None):
|
||||||
self.prompt_id = prompt_id
|
self.prompt_id = prompt_id
|
||||||
self.dynprompt = dynprompt
|
self.dynprompt = dynprompt
|
||||||
|
self.workflow_id = workflow_id
|
||||||
self.nodes: Dict[str, NodeProgressState] = {}
|
self.nodes: Dict[str, NodeProgressState] = {}
|
||||||
self.handlers: Dict[str, ProgressHandler] = {}
|
self.handlers: Dict[str, ProgressHandler] = {}
|
||||||
|
|
||||||
@ -322,7 +327,7 @@ class ProgressRegistry:
|
|||||||
# Global registry instance
|
# Global registry instance
|
||||||
global_progress_registry: ProgressRegistry | None = None
|
global_progress_registry: ProgressRegistry | None = None
|
||||||
|
|
||||||
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
|
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None) -> None:
|
||||||
global global_progress_registry
|
global global_progress_registry
|
||||||
|
|
||||||
# Reset existing handlers if registry exists
|
# Reset existing handlers if registry exists
|
||||||
@ -330,7 +335,7 @@ def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
|
|||||||
global_progress_registry.reset_handlers()
|
global_progress_registry.reset_handlers()
|
||||||
|
|
||||||
# Create new registry
|
# Create new registry
|
||||||
global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
|
global_progress_registry = ProgressRegistry(prompt_id, dynprompt, workflow_id)
|
||||||
|
|
||||||
|
|
||||||
def add_progress_handler(handler: ProgressHandler) -> None:
|
def add_progress_handler(handler: ProgressHandler) -> None:
|
||||||
|
|||||||
33
execution.py
33
execution.py
@ -38,6 +38,7 @@ from comfy_execution.graph import (
|
|||||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||||
from comfy_execution.validation import validate_node_input
|
from comfy_execution.validation import validate_node_input
|
||||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||||
|
from comfy_execution.jobs import extract_workflow_id
|
||||||
from comfy_execution.utils import CurrentNodeContext
|
from comfy_execution.utils import CurrentNodeContext
|
||||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||||
from comfy_api.latest import io, _io
|
from comfy_api.latest import io, _io
|
||||||
@ -417,15 +418,15 @@ def _is_intermediate_output(dynprompt, node_id):
|
|||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
|
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
|
||||||
|
|
||||||
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
|
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs):
|
||||||
if server.client_id is None:
|
if server.client_id is None:
|
||||||
return
|
return
|
||||||
cached_ui = cached.ui or {}
|
cached_ui = cached.ui or {}
|
||||||
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
|
||||||
if cached.ui is not None:
|
if cached.ui is not None:
|
||||||
ui_outputs[node_id] = cached.ui
|
ui_outputs[node_id] = cached.ui
|
||||||
|
|
||||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
|
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
|
||||||
unique_id = current_item
|
unique_id = current_item
|
||||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||||
@ -435,7 +436,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
cached = await caches.outputs.get(unique_id)
|
cached = await caches.outputs.get(unique_id)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs)
|
_send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, workflow_id, ui_outputs)
|
||||||
get_progress_state().finish_progress(unique_id)
|
get_progress_state().finish_progress(unique_id)
|
||||||
execution_list.cache_update(unique_id, cached)
|
execution_list.cache_update(unique_id, cached)
|
||||||
return (ExecutionResult.SUCCESS, None, None)
|
return (ExecutionResult.SUCCESS, None, None)
|
||||||
@ -483,7 +484,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.last_node_id = display_node_id
|
server.last_node_id = display_node_id
|
||||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
|
||||||
|
|
||||||
obj = await caches.objects.get(unique_id)
|
obj = await caches.objects.get(unique_id)
|
||||||
if obj is None:
|
if obj is None:
|
||||||
@ -513,6 +514,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
if block.message is not None:
|
if block.message is not None:
|
||||||
mes = {
|
mes = {
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
|
"workflow_id": workflow_id,
|
||||||
"node_id": unique_id,
|
"node_id": unique_id,
|
||||||
"node_type": class_type,
|
"node_type": class_type,
|
||||||
"executed": list(executed),
|
"executed": list(executed),
|
||||||
@ -561,7 +563,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
"output": output_ui
|
"output": output_ui
|
||||||
}
|
}
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id, "workflow_id": workflow_id }, server.client_id)
|
||||||
if has_subgraph:
|
if has_subgraph:
|
||||||
cached_outputs = []
|
cached_outputs = []
|
||||||
new_node_ids = []
|
new_node_ids = []
|
||||||
@ -658,6 +660,7 @@ class PromptExecutor:
|
|||||||
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
||||||
self.status_messages = []
|
self.status_messages = []
|
||||||
self.success = True
|
self.success = True
|
||||||
|
self.workflow_id = None
|
||||||
|
|
||||||
def add_message(self, event, data: dict, broadcast: bool):
|
def add_message(self, event, data: dict, broadcast: bool):
|
||||||
data = {
|
data = {
|
||||||
@ -677,6 +680,7 @@ class PromptExecutor:
|
|||||||
if isinstance(ex, comfy.model_management.InterruptProcessingException):
|
if isinstance(ex, comfy.model_management.InterruptProcessingException):
|
||||||
mes = {
|
mes = {
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
|
"workflow_id": self.workflow_id,
|
||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
"node_type": class_type,
|
"node_type": class_type,
|
||||||
"executed": list(executed),
|
"executed": list(executed),
|
||||||
@ -685,6 +689,7 @@ class PromptExecutor:
|
|||||||
else:
|
else:
|
||||||
mes = {
|
mes = {
|
||||||
"prompt_id": prompt_id,
|
"prompt_id": prompt_id,
|
||||||
|
"workflow_id": self.workflow_id,
|
||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
"node_type": class_type,
|
"node_type": class_type,
|
||||||
"executed": list(executed),
|
"executed": list(executed),
|
||||||
@ -723,7 +728,9 @@ class PromptExecutor:
|
|||||||
self.server.client_id = None
|
self.server.client_id = None
|
||||||
|
|
||||||
self.status_messages = []
|
self.status_messages = []
|
||||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
self.workflow_id = extract_workflow_id(extra_data)
|
||||||
|
self.server.last_workflow_id = self.workflow_id
|
||||||
|
self.add_message("execution_start", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False)
|
||||||
|
|
||||||
self._notify_prompt_lifecycle("start", prompt_id)
|
self._notify_prompt_lifecycle("start", prompt_id)
|
||||||
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
|
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
|
||||||
@ -733,7 +740,7 @@ class PromptExecutor:
|
|||||||
try:
|
try:
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
dynamic_prompt = DynamicPrompt(prompt)
|
dynamic_prompt = DynamicPrompt(prompt)
|
||||||
reset_progress_state(prompt_id, dynamic_prompt)
|
reset_progress_state(prompt_id, dynamic_prompt, self.workflow_id)
|
||||||
add_progress_handler(WebUIProgressHandler(self.server))
|
add_progress_handler(WebUIProgressHandler(self.server))
|
||||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||||
for cache in self.caches.all:
|
for cache in self.caches.all:
|
||||||
@ -751,7 +758,7 @@ class PromptExecutor:
|
|||||||
|
|
||||||
comfy.model_management.cleanup_models_gc()
|
comfy.model_management.cleanup_models_gc()
|
||||||
self.add_message("execution_cached",
|
self.add_message("execution_cached",
|
||||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
{ "nodes": cached_nodes, "prompt_id": prompt_id, "workflow_id": self.workflow_id },
|
||||||
broadcast=False)
|
broadcast=False)
|
||||||
pending_subgraph_results = {}
|
pending_subgraph_results = {}
|
||||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||||
@ -769,7 +776,7 @@ class PromptExecutor:
|
|||||||
break
|
break
|
||||||
|
|
||||||
assert node_id is not None, "Node ID should not be None at this point"
|
assert node_id is not None, "Node ID should not be None at this point"
|
||||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, self.workflow_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||||
self.success = result != ExecutionResult.FAILURE
|
self.success = result != ExecutionResult.FAILURE
|
||||||
if result == ExecutionResult.FAILURE:
|
if result == ExecutionResult.FAILURE:
|
||||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||||
@ -793,8 +800,8 @@ class PromptExecutor:
|
|||||||
cached = await self.caches.outputs.get(node_id)
|
cached = await self.caches.outputs.get(node_id)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
display_node_id = dynamic_prompt.get_display_node_id(node_id)
|
display_node_id = dynamic_prompt.get_display_node_id(node_id)
|
||||||
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
|
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, self.workflow_id, ui_node_outputs)
|
||||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
self.add_message("execution_success", { "prompt_id": prompt_id, "workflow_id": self.workflow_id }, broadcast=False)
|
||||||
|
|
||||||
ui_outputs = {}
|
ui_outputs = {}
|
||||||
meta_outputs = {}
|
meta_outputs = {}
|
||||||
@ -811,6 +818,8 @@ class PromptExecutor:
|
|||||||
finally:
|
finally:
|
||||||
comfy.memory_management.set_ram_cache_release_state(None, 0)
|
comfy.memory_management.set_ram_cache_release_state(None, 0)
|
||||||
self._notify_prompt_lifecycle("end", prompt_id)
|
self._notify_prompt_lifecycle("end", prompt_id)
|
||||||
|
self.server.last_workflow_id = None
|
||||||
|
self.workflow_id = None
|
||||||
|
|
||||||
|
|
||||||
async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
||||||
|
|||||||
11
main.py
11
main.py
@ -29,6 +29,7 @@ import logging
|
|||||||
import sys
|
import sys
|
||||||
from comfy_execution.progress import get_progress_state
|
from comfy_execution.progress import get_progress_state
|
||||||
from comfy_execution.utils import get_executing_context
|
from comfy_execution.utils import get_executing_context
|
||||||
|
from comfy_execution.jobs import extract_workflow_id
|
||||||
from comfy_api import feature_flags
|
from comfy_api import feature_flags
|
||||||
from app.database.db import init_db, dependencies_available
|
from app.database.db import init_db, dependencies_available
|
||||||
|
|
||||||
@ -317,6 +318,12 @@ def prompt_worker(q, server_instance):
|
|||||||
for k in sensitive:
|
for k in sensitive:
|
||||||
extra_data[k] = sensitive[k]
|
extra_data[k] = sensitive[k]
|
||||||
|
|
||||||
|
# Capture the workflow id for this prompt before execution: the
|
||||||
|
# executor clears server.last_workflow_id in its finally block, so
|
||||||
|
# reading it after e.execute() returns would emit workflow_id=None
|
||||||
|
# on the terminal "executing" reset below.
|
||||||
|
workflow_id = extract_workflow_id(extra_data)
|
||||||
|
|
||||||
asset_seeder.pause()
|
asset_seeder.pause()
|
||||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||||
|
|
||||||
@ -330,7 +337,7 @@ def prompt_worker(q, server_instance):
|
|||||||
completed=e.success,
|
completed=e.success,
|
||||||
messages=e.status_messages), process_item=remove_sensitive)
|
messages=e.status_messages), process_item=remove_sensitive)
|
||||||
if server_instance.client_id is not None:
|
if server_instance.client_id is not None:
|
||||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id, "workflow_id": workflow_id}, server_instance.client_id)
|
||||||
|
|
||||||
current_time = time.perf_counter()
|
current_time = time.perf_counter()
|
||||||
execution_time = current_time - execution_start_time
|
execution_time = current_time - execution_start_time
|
||||||
@ -393,7 +400,7 @@ def hijack_progress(server_instance):
|
|||||||
prompt_id = server_instance.last_prompt_id
|
prompt_id = server_instance.last_prompt_id
|
||||||
if node_id is None:
|
if node_id is None:
|
||||||
node_id = server_instance.last_node_id
|
node_id = server_instance.last_node_id
|
||||||
progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
|
progress = {"value": value, "max": total, "prompt_id": prompt_id, "workflow_id": getattr(server_instance, 'last_workflow_id', None), "node": node_id}
|
||||||
get_progress_state().update_progress(node_id, value, total, preview_image)
|
get_progress_state().update_progress(node_id, value, total, preview_image)
|
||||||
|
|
||||||
server_instance.send_sync("progress", progress, server_instance.client_id)
|
server_instance.send_sync("progress", progress, server_instance.client_id)
|
||||||
|
|||||||
@ -275,7 +275,11 @@ class PromptServer():
|
|||||||
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
await self.send("status", {"status": self.get_queue_info(), "sid": sid}, sid)
|
||||||
# On reconnect if we are the currently executing client send the current node
|
# On reconnect if we are the currently executing client send the current node
|
||||||
if self.client_id == sid and self.last_node_id is not None:
|
if self.client_id == sid and self.last_node_id is not None:
|
||||||
await self.send("executing", { "node": self.last_node_id }, sid)
|
await self.send("executing", {
|
||||||
|
"node": self.last_node_id,
|
||||||
|
"prompt_id": getattr(self, "last_prompt_id", None),
|
||||||
|
"workflow_id": getattr(self, "last_workflow_id", None),
|
||||||
|
}, sid)
|
||||||
|
|
||||||
# Flag to track if we've received the first message
|
# Flag to track if we've received the first message
|
||||||
first_message = True
|
first_message = True
|
||||||
|
|||||||
297
tests-unit/execution_test/test_workflow_id_in_ws_messages.py
Normal file
297
tests-unit/execution_test/test_workflow_id_in_ws_messages.py
Normal file
@ -0,0 +1,297 @@
|
|||||||
|
"""Tests that workflow_id is included alongside prompt_id in WebSocket payloads
|
||||||
|
emitted by the progress handler and the prompt executor.
|
||||||
|
|
||||||
|
Frontend stores extra_data["extra_pnginfo"]["workflow"]["id"] when queueing a
|
||||||
|
prompt; we propagate that as `workflow_id` on every execution event so a
|
||||||
|
multi-tab UI can scope progress state by workflow even when terminal
|
||||||
|
WebSocket frames are dropped.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from comfy_execution.progress import (
|
||||||
|
NodeState,
|
||||||
|
ProgressRegistry,
|
||||||
|
WebUIProgressHandler,
|
||||||
|
reset_progress_state,
|
||||||
|
get_progress_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyDynPrompt:
|
||||||
|
def get_display_node_id(self, node_id):
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
def get_parent_node_id(self, node_id):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_real_node_id(self, node_id):
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def server():
|
||||||
|
s = MagicMock()
|
||||||
|
s.client_id = "client-1"
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def _registry(workflow_id):
|
||||||
|
return ProgressRegistry(
|
||||||
|
prompt_id="prompt-1",
|
||||||
|
dynprompt=_DummyDynPrompt(),
|
||||||
|
workflow_id=workflow_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProgressStatePayload:
|
||||||
|
def test_progress_state_includes_workflow_id(self, server):
|
||||||
|
registry = _registry("wf-abc")
|
||||||
|
registry.nodes["n1"] = {
|
||||||
|
"state": NodeState.Running,
|
||||||
|
"value": 1.0,
|
||||||
|
"max": 5.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
handler = WebUIProgressHandler(server)
|
||||||
|
handler.set_registry(registry)
|
||||||
|
handler._send_progress_state("prompt-1", registry.nodes)
|
||||||
|
|
||||||
|
server.send_sync.assert_called_once()
|
||||||
|
event, payload, sid = server.send_sync.call_args.args
|
||||||
|
assert event == "progress_state"
|
||||||
|
assert payload["prompt_id"] == "prompt-1"
|
||||||
|
assert payload["workflow_id"] == "wf-abc"
|
||||||
|
assert payload["nodes"]["n1"]["workflow_id"] == "wf-abc"
|
||||||
|
assert payload["nodes"]["n1"]["prompt_id"] == "prompt-1"
|
||||||
|
assert sid == "client-1"
|
||||||
|
|
||||||
|
def test_progress_state_workflow_id_none_when_missing(self, server):
|
||||||
|
registry = _registry(None)
|
||||||
|
registry.nodes["n1"] = {
|
||||||
|
"state": NodeState.Running,
|
||||||
|
"value": 0.5,
|
||||||
|
"max": 1.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
handler = WebUIProgressHandler(server)
|
||||||
|
handler.set_registry(registry)
|
||||||
|
handler._send_progress_state("prompt-1", registry.nodes)
|
||||||
|
|
||||||
|
_, payload, _ = server.send_sync.call_args.args
|
||||||
|
assert payload["workflow_id"] is None
|
||||||
|
assert payload["nodes"]["n1"]["workflow_id"] is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestProgressRegistryConstruction:
|
||||||
|
def test_workflow_id_default_is_none(self):
|
||||||
|
registry = ProgressRegistry(
|
||||||
|
prompt_id="prompt-1", dynprompt=_DummyDynPrompt()
|
||||||
|
)
|
||||||
|
assert registry.workflow_id is None
|
||||||
|
|
||||||
|
def test_workflow_id_stored_on_registry(self):
|
||||||
|
registry = ProgressRegistry(
|
||||||
|
prompt_id="prompt-1",
|
||||||
|
dynprompt=_DummyDynPrompt(),
|
||||||
|
workflow_id="wf-xyz",
|
||||||
|
)
|
||||||
|
assert registry.workflow_id == "wf-xyz"
|
||||||
|
|
||||||
|
|
||||||
|
class TestResetProgressState:
|
||||||
|
def test_reset_threads_workflow_id(self):
|
||||||
|
reset_progress_state("prompt-1", _DummyDynPrompt(), "wf-456")
|
||||||
|
assert get_progress_state().workflow_id == "wf-456"
|
||||||
|
|
||||||
|
def test_reset_default_workflow_id_none(self):
|
||||||
|
reset_progress_state("prompt-2", _DummyDynPrompt())
|
||||||
|
assert get_progress_state().workflow_id is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecutionMessagePayloadsContainWorkflowId:
|
||||||
|
"""Static-analysis guard ensuring every WebSocket message payload that
|
||||||
|
carries `prompt_id` also carries `workflow_id`. This is a regression net
|
||||||
|
for future refactors of execution.py / main.py / progress.py and avoids
|
||||||
|
the GPU/torch dependency of importing `execution.py` directly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _emitting_dicts(source: str):
|
||||||
|
"""Yield every dict literal in `source` that contains a 'prompt_id' key."""
|
||||||
|
import ast
|
||||||
|
|
||||||
|
tree = ast.parse(source)
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if not isinstance(node, ast.Dict):
|
||||||
|
continue
|
||||||
|
keys = [
|
||||||
|
k.value
|
||||||
|
for k in node.keys
|
||||||
|
if isinstance(k, ast.Constant) and isinstance(k.value, str)
|
||||||
|
]
|
||||||
|
if "prompt_id" in keys:
|
||||||
|
yield node, keys
|
||||||
|
|
||||||
|
def _assert_workflow_id_in_every_prompt_id_dict(self, file_path: str):
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
source = (repo_root / file_path).read_text()
|
||||||
|
offenders = []
|
||||||
|
for node, keys in self._emitting_dicts(source):
|
||||||
|
if "workflow_id" not in keys:
|
||||||
|
offenders.append((node.lineno, keys))
|
||||||
|
assert not offenders, (
|
||||||
|
f"{file_path}: dict literals with 'prompt_id' but no 'workflow_id': {offenders}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_execution_py_payloads_include_workflow_id(self):
|
||||||
|
self._assert_workflow_id_in_every_prompt_id_dict("execution.py")
|
||||||
|
|
||||||
|
def test_main_py_payloads_include_workflow_id(self):
|
||||||
|
self._assert_workflow_id_in_every_prompt_id_dict("main.py")
|
||||||
|
|
||||||
|
def test_progress_py_payloads_include_workflow_id(self):
|
||||||
|
self._assert_workflow_id_in_every_prompt_id_dict("comfy_execution/progress.py")
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreviewImageMetadataPayload:
|
||||||
|
"""Verify PREVIEW_IMAGE_WITH_METADATA metadata carries workflow_id."""
|
||||||
|
|
||||||
|
def test_preview_metadata_includes_workflow_id(self):
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from comfy_execution.progress import (
|
||||||
|
NodeState,
|
||||||
|
ProgressRegistry,
|
||||||
|
WebUIProgressHandler,
|
||||||
|
)
|
||||||
|
|
||||||
|
class _DynPrompt:
|
||||||
|
def get_display_node_id(self, n):
|
||||||
|
return n
|
||||||
|
|
||||||
|
def get_parent_node_id(self, n):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_real_node_id(self, n):
|
||||||
|
return n
|
||||||
|
|
||||||
|
server = MagicMock()
|
||||||
|
server.client_id = "cid"
|
||||||
|
server.sockets_metadata = {}
|
||||||
|
|
||||||
|
registry = ProgressRegistry(
|
||||||
|
prompt_id="p1", dynprompt=_DynPrompt(), workflow_id="wf-1"
|
||||||
|
)
|
||||||
|
handler = WebUIProgressHandler(server)
|
||||||
|
handler.set_registry(registry)
|
||||||
|
|
||||||
|
image = ("PNG", Image.new("RGB", (1, 1)), None)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"comfy_execution.progress.feature_flags.supports_feature",
|
||||||
|
return_value=True,
|
||||||
|
):
|
||||||
|
handler.update_handler(
|
||||||
|
node_id="n1",
|
||||||
|
value=1.0,
|
||||||
|
max_value=1.0,
|
||||||
|
state={
|
||||||
|
"state": NodeState.Running,
|
||||||
|
"value": 1.0,
|
||||||
|
"max": 1.0,
|
||||||
|
},
|
||||||
|
prompt_id="p1",
|
||||||
|
image=image,
|
||||||
|
)
|
||||||
|
|
||||||
|
preview_calls = [
|
||||||
|
c
|
||||||
|
for c in server.send_sync.call_args_list
|
||||||
|
if c.args[0] != "progress_state"
|
||||||
|
]
|
||||||
|
assert len(preview_calls) == 1
|
||||||
|
_, payload, _ = preview_calls[0].args
|
||||||
|
_, metadata = payload
|
||||||
|
assert metadata["prompt_id"] == "p1"
|
||||||
|
assert metadata["workflow_id"] == "wf-1"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TestTerminalExecutingResetInMainPy:
|
||||||
|
"""Regression test for the main.py prompt_worker terminal 'executing' reset.
|
||||||
|
|
||||||
|
The executor clears server.last_workflow_id in its finally block, so
|
||||||
|
main.py must capture the workflow id *before* calling e.execute() and use
|
||||||
|
that local value, not read server.last_workflow_id afterwards.
|
||||||
|
|
||||||
|
Rather than importing main.py (which triggers torch CUDA init in this
|
||||||
|
environment), we statically assert the contract via AST: somewhere
|
||||||
|
between the `extra_data = item[3].copy()` line and the
|
||||||
|
`e.execute(item[2], ...)` call, the function must extract workflow_id
|
||||||
|
from extra_data into a local, and the subsequent send_sync("executing",
|
||||||
|
...) must reference that local rather than server.last_workflow_id.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_terminal_executing_uses_locally_captured_workflow_id(self):
|
||||||
|
import ast
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
source = (repo_root / "main.py").read_text()
|
||||||
|
tree = ast.parse(source)
|
||||||
|
|
||||||
|
worker = next(
|
||||||
|
(
|
||||||
|
n
|
||||||
|
for n in ast.walk(tree)
|
||||||
|
if isinstance(n, ast.FunctionDef) and n.name == "prompt_worker"
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert worker is not None, "prompt_worker function not found in main.py"
|
||||||
|
|
||||||
|
worker_src = ast.get_source_segment(source, worker) or ""
|
||||||
|
|
||||||
|
assert "extract_workflow_id(extra_data)" in worker_src, (
|
||||||
|
"main.py:prompt_worker must capture workflow_id locally from extra_data "
|
||||||
|
"before calling e.execute() (the executor clears server.last_workflow_id "
|
||||||
|
"in finally)."
|
||||||
|
)
|
||||||
|
|
||||||
|
matched_terminal_executing_send = False
|
||||||
|
for node in ast.walk(worker):
|
||||||
|
if not isinstance(node, ast.Call):
|
||||||
|
continue
|
||||||
|
func = node.func
|
||||||
|
if not (
|
||||||
|
isinstance(func, ast.Attribute)
|
||||||
|
and func.attr == "send_sync"
|
||||||
|
and node.args
|
||||||
|
and isinstance(node.args[0], ast.Constant)
|
||||||
|
and node.args[0].value == "executing"
|
||||||
|
and len(node.args) >= 2
|
||||||
|
and isinstance(node.args[1], ast.Dict)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
matched_terminal_executing_send = True
|
||||||
|
payload = node.args[1]
|
||||||
|
for key, value in zip(payload.keys, payload.values):
|
||||||
|
if isinstance(key, ast.Constant) and key.value == "workflow_id":
|
||||||
|
rendered = ast.unparse(value)
|
||||||
|
assert "last_workflow_id" not in rendered, (
|
||||||
|
"main.py terminal 'executing' must not read "
|
||||||
|
"server.last_workflow_id; the executor clears it in its "
|
||||||
|
"finally block. Use a locally captured workflow_id instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
assert matched_terminal_executing_send, (
|
||||||
|
"main.py:prompt_worker no longer has an inline "
|
||||||
|
'send_sync("executing", {...}) payload; update this regression test '
|
||||||
|
"so it still verifies the terminal workflow_id source."
|
||||||
|
)
|
||||||
@ -10,9 +10,44 @@ from comfy_execution.jobs import (
|
|||||||
get_outputs_summary,
|
get_outputs_summary,
|
||||||
apply_sorting,
|
apply_sorting,
|
||||||
has_3d_extension,
|
has_3d_extension,
|
||||||
|
extract_workflow_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractWorkflowId:
|
||||||
|
"""Unit tests for extract_workflow_id()."""
|
||||||
|
|
||||||
|
def test_returns_id_from_extra_pnginfo(self):
|
||||||
|
assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 'wf-123'}}}) == 'wf-123'
|
||||||
|
|
||||||
|
def test_missing_extra_data_returns_none(self):
|
||||||
|
assert extract_workflow_id(None) is None
|
||||||
|
|
||||||
|
def test_non_dict_extra_data_returns_none(self):
|
||||||
|
assert extract_workflow_id('not-a-dict') is None
|
||||||
|
|
||||||
|
def test_missing_extra_pnginfo_returns_none(self):
|
||||||
|
assert extract_workflow_id({}) is None
|
||||||
|
|
||||||
|
def test_missing_workflow_returns_none(self):
|
||||||
|
assert extract_workflow_id({'extra_pnginfo': {}}) is None
|
||||||
|
|
||||||
|
def test_missing_id_returns_none(self):
|
||||||
|
assert extract_workflow_id({'extra_pnginfo': {'workflow': {}}}) is None
|
||||||
|
|
||||||
|
def test_empty_string_id_returns_none(self):
|
||||||
|
assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': ''}}}) is None
|
||||||
|
|
||||||
|
def test_non_string_id_returns_none(self):
|
||||||
|
assert extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 42}}}) is None
|
||||||
|
|
||||||
|
def test_non_dict_workflow_returns_none(self):
|
||||||
|
assert extract_workflow_id({'extra_pnginfo': {'workflow': 'not-a-dict'}}) is None
|
||||||
|
|
||||||
|
def test_non_dict_extra_pnginfo_returns_none(self):
|
||||||
|
assert extract_workflow_id({'extra_pnginfo': 'not-a-dict'}) is None
|
||||||
|
|
||||||
|
|
||||||
class TestJobStatus:
|
class TestJobStatus:
|
||||||
"""Test JobStatus constants."""
|
"""Test JobStatus constants."""
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user