Add workflow_id to all websocket messages
This commit addresses BE-672 by ensuring all execution-related websocket messages include the workflow_id field when available.

Changes:
- Added extract_workflow_id() helper function in comfy_execution/jobs.py to extract workflow_id from extra_data
- Updated execution.py to include workflow_id in all websocket messages:
  - execution_start
  - execution_cached
  - execution_success
  - execution_error
  - execution_interrupted
  - executing
  - executed (including cached UI)
- Updated main.py to include workflow_id in:
  - progress messages (via hijack_progress hook)
  - final executing message (node=None)
- Updated comfy_execution/progress.py to include workflow_id in:
  - progress_state messages
  - preview image metadata

The workflow_id is extracted from extra_data['extra_pnginfo']['workflow']['id'] and is conditionally included in messages only when present, maintaining backward compatibility with workflows that don't have this field.

Fixes: BE-672

Co-authored-by: Luke Mino-Altherr <luke-mino-altherr@users.noreply.github.com>
This commit is contained in:
parent e9c311b245
commit 7578d1989f
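For orientation before the diff: the id travels inside extra_data and is copied into outgoing messages only when present. A minimal sketch of that flow, with invented example values:

# Hypothetical extra_data for a queued prompt; the nesting mirrors
# extra_data['extra_pnginfo']['workflow']['id'] described above.
extra_data = {
    "extra_pnginfo": {
        "workflow": {
            "id": "wf-1234",  # assumed example id; real ids come from the frontend
        }
    }
}

# The conditional-inclusion pattern this commit applies at every send site:
message = {"prompt_id": "prompt-5678"}  # assumed example prompt id
workflow_id = extra_data.get("extra_pnginfo", {}).get("workflow", {}).get("id")
if workflow_id is not None:
    message["workflow_id"] = workflow_id
# message -> {'prompt_id': 'prompt-5678', 'workflow_id': 'wf-1234'}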
comfy_execution/jobs.py

@@ -105,6 +105,21 @@ def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str
     return create_time, workflow_id


+def extract_workflow_id(extra_data: dict) -> Optional[str]:
+    """Extract workflow_id from extra_data.
+
+    Args:
+        extra_data: The extra_data dict containing workflow information
+
+    Returns:
+        The workflow_id if present, otherwise None
+    """
+    if not extra_data:
+        return None
+    extra_pnginfo = extra_data.get('extra_pnginfo', {})
+    return extra_pnginfo.get('workflow', {}).get('id')
+
+
 def is_previewable(media_type: str, item: dict) -> bool:
     """
     Check if an output item is previewable.
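A quick, self-contained illustration of how the new helper behaves on the three interesting inputs (the dicts are invented for the example):

from typing import Optional

def extract_workflow_id(extra_data: dict) -> Optional[str]:
    # Inlined copy of the helper added above so this snippet runs standalone.
    if not extra_data:
        return None
    extra_pnginfo = extra_data.get('extra_pnginfo', {})
    return extra_pnginfo.get('workflow', {}).get('id')

print(extract_workflow_id({}))                                                   # None
print(extract_workflow_id({'extra_pnginfo': {'workflow': {}}}))                  # None
print(extract_workflow_id({'extra_pnginfo': {'workflow': {'id': 'wf-1234'}}}))   # wf-1234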
comfy_execution/progress.py

@@ -182,8 +182,11 @@ class WebUIProgressHandler(ProgressHandler):

         # Send a combined progress_state message with all node states
         # Include client_id to ensure message is only sent to the initiating client
+        message = {"prompt_id": prompt_id, "nodes": active_nodes}
+        if self.registry.workflow_id is not None:
+            message["workflow_id"] = self.registry.workflow_id
         self.server_instance.send_sync(
-            "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
+            "progress_state", message, self.server_instance.client_id
         )

     @override
@@ -223,6 +226,8 @@ class WebUIProgressHandler(ProgressHandler):
                 ),
                 "real_node_id": self.registry.dynprompt.get_real_node_id(node_id),
             }
+            if self.registry.workflow_id is not None:
+                metadata["workflow_id"] = self.registry.workflow_id
             self.server_instance.send_sync(
                 BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
                 (image, metadata),
@@ -240,9 +245,10 @@ class ProgressRegistry:
     Registry that maintains node progress state and notifies registered handlers.
     """

-    def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
+    def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None):
         self.prompt_id = prompt_id
         self.dynprompt = dynprompt
+        self.workflow_id = workflow_id
         self.nodes: Dict[str, NodeProgressState] = {}
         self.handlers: Dict[str, ProgressHandler] = {}

@@ -322,7 +328,7 @@ class ProgressRegistry:
 # Global registry instance
 global_progress_registry: ProgressRegistry | None = None

-def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
+def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt", workflow_id: Optional[str] = None) -> None:
     global global_progress_registry

     # Reset existing handlers if registry exists
@@ -330,7 +336,7 @@ def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
         global_progress_registry.reset_handlers()

     # Create new registry
-    global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
+    global_progress_registry = ProgressRegistry(prompt_id, dynprompt, workflow_id)


 def add_progress_handler(handler: ProgressHandler) -> None:
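With the registry now carrying the id, every handler can attach it without re-parsing extra_data. A rough sketch of the resulting progress_state payload (the values and the per-node fields are illustrative; real node states come from NodeProgressState and carry more detail):

progress_state_data = {
    "prompt_id": "prompt-5678",                            # always present
    "nodes": {
        "3": {"value": 4, "max": 20, "state": "running"},  # abbreviated node state
    },
    "workflow_id": "wf-1234",                              # only when the workflow has an id
}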
execution.py (60 changed lines)

@@ -41,6 +41,7 @@ from comfy_execution.utils import CurrentNodeContext
 from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
 from comfy_api.latest import io, _io
 from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
+from comfy_execution.jobs import extract_workflow_id


 class ExecutionResult(Enum):
@@ -416,15 +417,18 @@ def _is_intermediate_output(dynprompt, node_id):
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
     return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)

-def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
+def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs, workflow_id=None):
     if server.client_id is None:
         return
     cached_ui = cached.ui or {}
-    server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
+    message = { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }
+    if workflow_id is not None:
+        message["workflow_id"] = workflow_id
+    server.send_sync("executed", message, server.client_id)
     if cached.ui is not None:
         ui_outputs[node_id] = cached.ui

-async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
+async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs, workflow_id=None):
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
     display_node_id = dynprompt.get_display_node_id(unique_id)
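Note that both new parameters default to None, so existing callers of _send_cached_ui and execute keep working unchanged. A toy stand-in for that pattern (names and values invented):

def send(node, prompt_id, workflow_id=None):
    # Old call sites that pass no workflow_id still work; the field is simply omitted.
    msg = {"node": node, "prompt_id": prompt_id}
    if workflow_id is not None:
        msg["workflow_id"] = workflow_id
    return msg

assert send("5", "p1") == {"node": "5", "prompt_id": "p1"}
assert send("5", "p1", "wf-1234")["workflow_id"] == "wf-1234"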
@@ -434,7 +438,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
     cached = await caches.outputs.get(unique_id)
     if cached is not None:
-        _send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs)
+        _send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs, workflow_id)
         get_progress_state().finish_progress(unique_id)
         execution_list.cache_update(unique_id, cached)
         return (ExecutionResult.SUCCESS, None, None)
@@ -482,7 +486,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
     if server.client_id is not None:
         server.last_node_id = display_node_id
-        server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
+        message = { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }
+        if workflow_id is not None:
+            message["workflow_id"] = workflow_id
+        server.send_sync("executing", message, server.client_id)

     obj = await caches.objects.get(unique_id)
     if obj is None:
@@ -522,6 +529,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 "current_inputs": [],
                 "current_outputs": [],
             }
+            if workflow_id is not None:
+                mes["workflow_id"] = workflow_id
             server.send_sync("execution_error", mes, server.client_id)
             return ExecutionBlocker(None)
         else:
@@ -559,7 +568,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             "output": output_ui
         }
         if server.client_id is not None:
-            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
+            message = { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }
+            if workflow_id is not None:
+                message["workflow_id"] = workflow_id
+            server.send_sync("executed", message, server.client_id)
     if has_subgraph:
         cached_outputs = []
         new_node_ids = []
@@ -666,7 +678,7 @@ class PromptExecutor:
         if self.server.client_id is not None or broadcast:
             self.server.send_sync(event, data, self.server.client_id)

-    def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
+    def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex, workflow_id=None):
         node_id = error["node_id"]
         class_type = prompt[node_id]["class_type"]

@@ -679,6 +691,8 @@ class PromptExecutor:
                 "node_type": class_type,
                 "executed": list(executed),
             }
+            if workflow_id is not None:
+                mes["workflow_id"] = workflow_id
             self.add_message("execution_interrupted", mes, broadcast=True)
         else:
             mes = {
@@ -692,6 +706,8 @@ class PromptExecutor:
                 "current_inputs": error["current_inputs"],
                 "current_outputs": list(current_outputs),
             }
+            if workflow_id is not None:
+                mes["workflow_id"] = workflow_id
             self.add_message("execution_error", mes, broadcast=False)

     def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
@@ -720,8 +736,14 @@ class PromptExecutor:
         else:
             self.server.client_id = None

+        # Extract workflow_id from extra_data
+        workflow_id = extract_workflow_id(extra_data)
+
         self.status_messages = []
-        self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
+        execution_start_msg = { "prompt_id": prompt_id }
+        if workflow_id is not None:
+            execution_start_msg["workflow_id"] = workflow_id
+        self.add_message("execution_start", execution_start_msg, broadcast=False)

         self._notify_prompt_lifecycle("start", prompt_id)
         ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
@@ -731,7 +753,7 @@ class PromptExecutor:
         try:
             with torch.inference_mode():
                 dynamic_prompt = DynamicPrompt(prompt)
-                reset_progress_state(prompt_id, dynamic_prompt)
+                reset_progress_state(prompt_id, dynamic_prompt, workflow_id)
                 add_progress_handler(WebUIProgressHandler(self.server))
                 is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
                 for cache in self.caches.all:
@@ -748,9 +770,10 @@ class PromptExecutor:
                 ]

                 comfy.model_management.cleanup_models_gc()
-                self.add_message("execution_cached",
-                          { "nodes": cached_nodes, "prompt_id": prompt_id},
-                          broadcast=False)
+                execution_cached_msg = { "nodes": cached_nodes, "prompt_id": prompt_id }
+                if workflow_id is not None:
+                    execution_cached_msg["workflow_id"] = workflow_id
+                self.add_message("execution_cached", execution_cached_msg, broadcast=False)
                 pending_subgraph_results = {}
                 pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
                 ui_node_outputs = {}
@@ -763,14 +786,14 @@ class PromptExecutor:
                 while not execution_list.is_empty():
                     node_id, error, ex = await execution_list.stage_node_execution()
                     if error is not None:
-                        self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+                        self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex, workflow_id)
                         break

                     assert node_id is not None, "Node ID should not be None at this point"
-                    result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
+                    result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs, workflow_id)
                     self.success = result != ExecutionResult.FAILURE
                     if result == ExecutionResult.FAILURE:
-                        self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+                        self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex, workflow_id)
                         break
                     elif result == ExecutionResult.PENDING:
                         execution_list.unstage_node_execution()
@@ -791,8 +814,11 @@ class PromptExecutor:
                     cached = await self.caches.outputs.get(node_id)
                     if cached is not None:
                         display_node_id = dynamic_prompt.get_display_node_id(node_id)
-                        _send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
-                self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+                        _send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs, workflow_id)
+                execution_success_msg = { "prompt_id": prompt_id }
+                if workflow_id is not None:
+                    execution_success_msg["workflow_id"] = workflow_id
+                self.add_message("execution_success", execution_success_msg, broadcast=False)

                 ui_outputs = {}
                 meta_outputs = {}
main.py (13 changed lines)

@@ -21,6 +21,7 @@ import logging
 import sys
 from comfy_execution.progress import get_progress_state
 from comfy_execution.utils import get_executing_context
+from comfy_execution.jobs import extract_workflow_id
 from comfy_api import feature_flags
 from app.database.db import init_db, dependencies_available

@@ -322,7 +323,11 @@ def prompt_worker(q, server_instance):
                                                 completed=e.success,
                                                 messages=e.status_messages), process_item=remove_sensitive)
             if server_instance.client_id is not None:
-                server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
+                workflow_id = extract_workflow_id(extra_data)
+                executing_msg = {"node": None, "prompt_id": prompt_id}
+                if workflow_id is not None:
+                    executing_msg["workflow_id"] = workflow_id
+                server_instance.send_sync("executing", executing_msg, server_instance.client_id)

             current_time = time.perf_counter()
             execution_time = current_time - execution_start_time
@@ -386,6 +391,12 @@ def hijack_progress(server_instance):
         if node_id is None:
             node_id = server_instance.last_node_id
         progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
+
+        # Add workflow_id if available from progress state
+        progress_state = get_progress_state()
+        if hasattr(progress_state, 'workflow_id') and progress_state.workflow_id is not None:
+            progress["workflow_id"] = progress_state.workflow_id
+
         get_progress_state().update_progress(node_id, value, total, preview_image)

         server_instance.send_sync("progress", progress, server_instance.client_id)
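On the consuming side, clients can now correlate messages with a specific workflow. A minimal sketch using the third-party websocket-client package; the {"type": ..., "data": ...} envelope and the /ws?clientId= endpoint follow ComfyUI's existing websocket API, while the target id and the filtering logic are invented for the example:

import json
import websocket  # pip install websocket-client

TARGET_WORKFLOW_ID = "wf-1234"  # assumed example id

ws = websocket.WebSocket()
ws.connect("ws://127.0.0.1:8188/ws?clientId=example-client")
while True:
    frame = ws.recv()
    if not isinstance(frame, str):
        continue  # binary frames carry preview images; skip them here
    msg = json.loads(frame)
    data = msg.get("data", {})
    # workflow_id is present only when the workflow has an id (backward compatible)
    if data.get("workflow_id") == TARGET_WORKFLOW_ID:
        print(msg["type"], data.get("node"), data.get("prompt_id"))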