from __future__ import annotations from abc import ABC from enum import Enum from typing import TYPE_CHECKING from typing import TypedDict, Dict, Optional, Tuple from PIL import Image from tqdm import tqdm from typing_extensions import override from comfy.component_model.executor_types import ExecutorToClientProgress, PreviewImageWithMetadataMessage from comfy.component_model.module_property import create_module_properties from comfy.execution_context import current_execution_context from comfy.progress_types import AbstractProgressRegistry, PreviewImageMetadata if TYPE_CHECKING: from comfy_execution.graph import DynamicPrompt from comfy.cmd.protocol import BinaryEventTypes from comfy_api import feature_flags _module_properties = create_module_properties() PreviewImageTuple = Tuple[str, Image.Image, Optional[int]] class NodeState(Enum): Pending = "pending" Running = "running" Finished = "finished" Error = "error" class NodeProgressState(TypedDict): """ A class to represent the state of a node's progress. """ state: NodeState value: float max: float class ProgressHandler(ABC): """ Abstract base class for progress handlers. Progress handlers receive progress updates and display them in various ways. """ def __init__(self, name: str): self.name = name self.enabled = True def set_registry(self, registry: "ProgressRegistry"): pass def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): """Called when a node starts processing""" pass def update_handler( self, node_id: str, value: float, max_value: float, state: NodeProgressState, prompt_id: str, image: PreviewImageTuple | dict | None = None, ): """Called when a node's progress is updated""" pass def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): """Called when a node finishes processing""" pass def reset(self): """Called when the progress registry is reset""" pass def enable(self): """Enable this handler""" self.enabled = True def disable(self): """Disable this handler""" self.enabled = False class CLIProgressHandler(ProgressHandler): """ Handler that displays progress using tqdm progress bars in the CLI. """ def __init__(self): super().__init__("cli") self.progress_bars: Dict[str, tqdm] = {} @override def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): # Create a new tqdm progress bar if node_id not in self.progress_bars: self.progress_bars[node_id] = tqdm( total=state["max"], desc=f"Node {node_id}", unit="steps", leave=True, position=len(self.progress_bars), ) @override def update_handler( self, node_id: str, value: float, max_value: float, state: NodeProgressState, prompt_id: str, image: PreviewImageTuple | None = None, ): # Handle case where start_handler wasn't called if node_id not in self.progress_bars: self.progress_bars[node_id] = tqdm( total=max_value, desc=f"Node {node_id}", unit="steps", leave=True, position=len(self.progress_bars), ) self.progress_bars[node_id].update(value) else: # Update existing progress bar if max_value != self.progress_bars[node_id].total: self.progress_bars[node_id].total = max_value # Calculate the update amount (difference from current position) current_position = self.progress_bars[node_id].n update_amount = value - current_position if update_amount > 0: self.progress_bars[node_id].update(update_amount) @override def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): # Complete and close the progress bar if it exists if node_id in self.progress_bars: # Ensure the bar shows 100% completion remaining = state["max"] - self.progress_bars[node_id].n if remaining > 0: self.progress_bars[node_id].update(remaining) self.progress_bars[node_id].close() del self.progress_bars[node_id] @override def reset(self): # Close all progress bars for bar in self.progress_bars.values(): bar.close() self.progress_bars.clear() class WebUIProgressHandler(ProgressHandler): """ Handler that sends progress updates to the WebUI via WebSockets. """ def __init__(self, server_instance: ExecutorToClientProgress): super().__init__("webui") self.server_instance = server_instance def set_registry(self, registry: "ProgressRegistry"): self.registry = registry def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]): """Send the current progress state to the client""" if self.server_instance is None: return # Only send info for non-pending nodes active_nodes = { node_id: { "value": state["value"], "max": state["max"], "state": state["state"].value, "node_id": node_id, "prompt_id": prompt_id, "display_node_id": self.registry.dynprompt.get_display_node_id(node_id), "parent_node_id": self.registry.dynprompt.get_parent_node_id(node_id), "real_node_id": self.registry.dynprompt.get_real_node_id(node_id), } for node_id, state in nodes.items() if state["state"] != NodeState.Pending } # Send a combined progress_state message with all node states # Include client_id to ensure message is only sent to the initiating client self.server_instance.send_sync( "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id ) @override def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): # Send progress state of all nodes if self.registry: self._send_progress_state(prompt_id, self.registry.nodes) @override def update_handler( self, node_id: str, value: float, max_value: float, state: NodeProgressState, prompt_id: str, image: PreviewImageTuple | dict | None = None, ): # Send progress state of all nodes if self.registry: self._send_progress_state(prompt_id, self.registry.nodes) if image: # Only send new format if client supports it if feature_flags.supports_feature( self.server_instance.sockets_metadata, self.server_instance.client_id, "supports_preview_metadata", ): metadata: PreviewImageMetadata = { "node_id": node_id, "prompt_id": prompt_id, "display_node_id": self.registry.dynprompt.get_display_node_id( node_id ), "parent_node_id": self.registry.dynprompt.get_parent_node_id( node_id ), "real_node_id": self.registry.dynprompt.get_real_node_id(node_id), } # todo: image can be a dict because of transformers loader, do we have to deal with this specially? things just work message: PreviewImageWithMetadataMessage = (image, metadata) self.server_instance.send_sync( BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA, message, self.server_instance.client_id, ) @override def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str): # Send progress state of all nodes if self.registry: self._send_progress_state(prompt_id, self.registry.nodes) class ProgressRegistry(AbstractProgressRegistry): """ Registry that maintains node progress state and notifies registered handlers. """ def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"): self.prompt_id = prompt_id self.dynprompt = dynprompt self.nodes: Dict[str, NodeProgressState] = {} self.handlers: Dict[str, ProgressHandler] = {} def register_handler(self, handler: ProgressHandler) -> None: """Register a progress handler""" self.handlers[handler.name] = handler def unregister_handler(self, handler_name: str) -> None: """Unregister a progress handler""" if handler_name in self.handlers: # Allow handler to clean up resources self.handlers[handler_name].reset() del self.handlers[handler_name] def enable_handler(self, handler_name: str) -> None: """Enable a progress handler""" if handler_name in self.handlers: self.handlers[handler_name].enable() def disable_handler(self, handler_name: str) -> None: """Disable a progress handler""" if handler_name in self.handlers: self.handlers[handler_name].disable() def ensure_entry(self, node_id: str) -> NodeProgressState: """Ensure a node entry exists""" if node_id not in self.nodes: self.nodes[node_id] = NodeProgressState( state=NodeState.Pending, value=0, max=1 ) return self.nodes[node_id] def start_progress(self, node_id: str) -> None: """Start progress tracking for a node""" entry = self.ensure_entry(node_id) entry["state"] = NodeState.Running entry["value"] = 0.0 entry["max"] = 1.0 # Notify all enabled handlers for handler in self.handlers.values(): if handler.enabled: handler.start_handler(node_id, entry, self.prompt_id) def update_progress( self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | dict | None = None ) -> None: """Update progress for a node""" entry = self.ensure_entry(node_id) entry["state"] = NodeState.Running entry["value"] = value entry["max"] = max_value # Notify all enabled handlers for handler in self.handlers.values(): if handler.enabled: handler.update_handler( node_id, value, max_value, entry, self.prompt_id, image ) def finish_progress(self, node_id: str) -> None: """Finish progress tracking for a node""" entry = self.ensure_entry(node_id) entry["state"] = NodeState.Finished entry["value"] = entry["max"] # Notify all enabled handlers for handler in self.handlers.values(): if handler.enabled: handler.finish_handler(node_id, entry, self.prompt_id) def reset_handlers(self) -> None: """Reset all handlers""" for handler in self.handlers.values(): handler.reset() # Global registry instance @_module_properties.getter def _global_progress_registry() -> AbstractProgressRegistry | None: return current_execution_context().progress_registry def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None: """ the caller must create a new progress registry :param prompt_id: :param dynprompt: :return: None """ global_progress_registry = _global_progress_registry() # Reset existing handlers if registry exists if global_progress_registry is not None: global_progress_registry.reset_handlers() # XXX caller now creates new progress registry def add_progress_handler(handler: ProgressHandler) -> None: registry = get_progress_state() handler.set_registry(registry) registry.register_handler(handler) def get_progress_state() -> ProgressRegistry: return _global_progress_registry()