diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index ede682a3d..27c852b39 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -24,7 +24,7 @@ from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
     ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
     RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
-    HistoryResultDict
+    HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
 from ..component_model.files import canonicalize_path
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..execution_context import new_execution_context, context_execute_node, ExecutionContext
@@ -311,6 +311,7 @@ def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches,
     with context_execute_node(_node_id, prompt_id):
         return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)

+
 def _execute(server, dynprompt, caches, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
@@ -368,11 +369,11 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,
                 if len(required_inputs) > 0:
                     for i in required_inputs:
                         execution_list.make_input_strong_link(unique_id, i)
-                    return (ExecutionResult.PENDING, None, None)
+                    return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)

             def execution_block_cb(block):
                 if block.message is not None:
-                    mes = {
+                    mes: ExecutionErrorMessage = {
                         "prompt_id": prompt_id,
                         "node_id": unique_id,
                         "node_type": class_type,
@@ -442,7 +443,7 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,
                 for link in new_output_links:
                     execution_list.add_strong_link(link[0], link[1], unique_id)
                 pending_subgraph_results[unique_id] = cached_outputs
-                return (ExecutionResult.PENDING, None, None)
+                return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
         caches.outputs.set(unique_id, output_data)
     except interruption.InterruptProcessingException as iex:
         logging.info("Processing interrupted")
@@ -481,7 +482,7 @@ def _execute(server, dynprompt, caches, current_item: str, extra_data, executed,

     executed.add(unique_id)

-    return ExecutionResult.SUCCESS, None, None
+    return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None)


 class PromptExecutor:
@@ -519,7 +520,7 @@ class PromptExecutor:
         # First, send back the status to the frontend depending
         # on the exception type
         if isinstance(ex, interruption.InterruptProcessingException):
-            mes = {
+            mes: ExecutionInterruptedMessage = {
                 "prompt_id": prompt_id,
                 "node_id": node_id,
                 "node_type": class_type,
@@ -527,7 +528,7 @@ class PromptExecutor:
             }
             self.add_message("execution_interrupted", mes, broadcast=True)
         else:
-            mes = {
+            mes: ExecutionErrorMessage = {
                 "prompt_id": prompt_id,
                 "node_id": node_id,
                 "node_type": class_type,
@@ -544,10 +545,13 @@ class PromptExecutor:
             raise ex

     def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
-        with new_execution_context(ExecutionContext(self.server, task_id=prompt_id)):
-            self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)
+        # torchao and potentially other optimization approaches break when the models are created in inference mode
+        # todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
+        inference_mode = all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt))
+        with new_execution_context(ExecutionContext(self.server, task_id=prompt_id, inference_mode=inference_mode)):
+            self._execute_inner(prompt, prompt_id, extra_data, execute_outputs, inference_mode=inference_mode)

-    def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
+    def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None, inference_mode: bool = True):
         if execute_outputs is None:
             execute_outputs = []
         if extra_data is None:
@@ -562,7 +566,7 @@ class PromptExecutor:
         self.status_messages = []
         self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)

-        with torch.inference_mode() if all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt)) else nullcontext():
+        with torch.inference_mode() if inference_mode else nullcontext():
             dynamic_prompt = DynamicPrompt(prompt)
             is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
             for cache in self.caches.all:
diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py
index 113dd9cff..32bc1d22d 100644
--- a/comfy/component_model/executor_types.py
+++ b/comfy/component_model/executor_types.py
@@ -50,6 +50,13 @@ class UnencodedPreviewImageMessage(NamedTuple):
     task_id: str = ""


+class ExecutionInterruptedMessage(TypedDict):
+    prompt_id: str
+    node_id: str
+    node_type: str
+    executed: list[str]
+
+
 class ExecutionErrorMessage(TypedDict):
     prompt_id: str
     node_id: str
@@ -74,7 +81,7 @@ ExecutedMessage = ExecutingMessage

 SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None]

-SendSyncData = Union[StatusMessage, ExecutingMessage, ExecutionErrorMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]
+SendSyncData = Union[StatusMessage, ExecutingMessage, DependencyExecutionErrorMessage, ExecutionErrorMessage, ExecutionInterruptedMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]


 class ExecutorToClientProgress(Protocol):
diff --git a/comfy/distributed/process_pool_executor.py b/comfy/distributed/process_pool_executor.py
index 9c1b7776a..504650bb2 100644
--- a/comfy/distributed/process_pool_executor.py
+++ b/comfy/distributed/process_pool_executor.py
@@ -22,7 +22,7 @@ class ProcessPoolExecutor(ProcessPool, Executor):

     def schedule(self, function: Callable,
                  args: list = (),
-                 kwargs: dict = {},
+                 kwargs=None,
                  timeout: float = None) -> ProcessFuture:
         # todo: restart worker when there is insufficient VRAM or the workflows are sufficiently different
         # try:
@@ -31,6 +31,8 @@ class ProcessPoolExecutor(ProcessPool, Executor):
         #
         # except ValueError:
         #     pass
+        if kwargs is None:
+            kwargs = {}
         return super().schedule(function, args, kwargs, timeout)

     def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
diff --git a/comfy/execution_context.py b/comfy/execution_context.py
index f8f631deb..bec5c5b27 100644
--- a/comfy/execution_context.py
+++ b/comfy/execution_context.py
@@ -2,21 +2,24 @@ from __future__ import annotations

 from contextlib import contextmanager
 from contextvars import ContextVar
-from typing import NamedTuple, Optional
+from dataclasses import dataclass, replace
+from typing import Optional, Final

 from .component_model.executor_types import ExecutorToClientProgress
 from .distributed.server_stub import ServerStub

-_current_context = ContextVar("comfyui_execution_context")
+_current_context: Final[ContextVar] = ContextVar("comfyui_execution_context")


-class ExecutionContext(NamedTuple):
+@dataclass(frozen=True)
+class ExecutionContext:
     server: ExecutorToClientProgress
     node_id: Optional[str] = None
     task_id: Optional[str] = None
+    inference_mode: bool = True


-_empty_execution_context = ExecutionContext(ServerStub())
+_empty_execution_context: Final[ExecutionContext] = ExecutionContext(server=ServerStub())


 def current_execution_context() -> ExecutionContext:
@@ -30,7 +33,7 @@ def current_execution_context() -> ExecutionContext:
 def new_execution_context(ctx: ExecutionContext):
     token = _current_context.set(ctx)
     try:
-        yield
+        yield ctx
     finally:
         _current_context.reset(token)

@@ -38,6 +41,6 @@ def new_execution_context(ctx: ExecutionContext):
 @contextmanager
 def context_execute_node(node_id: str, prompt_id: str):
     current_ctx = current_execution_context()
-    new_ctx = ExecutionContext(current_ctx.server, node_id, prompt_id)
+    new_ctx = replace(current_ctx, node_id=node_id, task_id=prompt_id)
     with new_execution_context(new_ctx):
-        yield
+        yield new_ctx
diff --git a/comfy/ops.py b/comfy/ops.py
index ae6225077..789527a5c 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -21,6 +21,7 @@ import torch

 from . import model_management
 from .cli_args import args
+from .execution_context import current_execution_context


 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
@@ -352,7 +353,7 @@ class fp8_ops(manual_cast):
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, inference_mode: Optional[bool] = None):
     if inference_mode is None:
         # todo: check a context here, since this isn't being used by any callers yet
-        inference_mode = False
+        inference_mode = current_execution_context().inference_mode
     if compute_dtype is None or weight_dtype == compute_dtype:
         # disable_weight_init seems to interact poorly with some other optimization code
         return disable_weight_init if inference_mode else skip_init
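
The sketch below is a minimal, self-contained illustration of the context-propagation pattern this diff adopts, not the shipped code. It mirrors the names in comfy/execution_context.py (ExecutionContext, new_execution_context, current_execution_context) but omits the server field for brevity; pick_operations_stub, the "prompt-123" task id, and the "7" node id are hypothetical stand-ins used only to show how comfy.ops.pick_operations can now fall back to the ambient inference_mode.

from __future__ import annotations

from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, replace
from typing import Final, Optional

# Frozen context stored in a ContextVar, so deeply nested code can read
# inference_mode without threading it through every call signature.
_current_context: Final[ContextVar] = ContextVar("execution_context")


@dataclass(frozen=True)
class ExecutionContext:
    task_id: Optional[str] = None
    node_id: Optional[str] = None
    inference_mode: bool = True


def current_execution_context() -> ExecutionContext:
    try:
        return _current_context.get()
    except LookupError:
        return ExecutionContext()  # no executor running: default to inference mode


@contextmanager
def new_execution_context(ctx: ExecutionContext):
    token = _current_context.set(ctx)
    try:
        yield ctx
    finally:
        _current_context.reset(token)


def pick_operations_stub(inference_mode: Optional[bool] = None) -> str:
    # Hypothetical stand-in for comfy.ops.pick_operations: when the caller does
    # not pass inference_mode explicitly, consult the ambient execution context.
    if inference_mode is None:
        inference_mode = current_execution_context().inference_mode
    return "disable_weight_init" if inference_mode else "skip_init"


if __name__ == "__main__":
    # A prompt containing a node class with INFERENCE_MODE = False (e.g. torchao)
    # would lead the executor to build the context with inference_mode=False.
    ctx = ExecutionContext(task_id="prompt-123", inference_mode=False)
    with new_execution_context(ctx):
        per_node = replace(ctx, node_id="7")  # what context_execute_node does per node
        print(pick_operations_stub())         # -> skip_init
    print(pick_operations_stub())             # outside any executor -> disable_weight_init

Deriving the per-node context with dataclasses.replace, as context_execute_node now does, preserves fields the old positional ExecutionContext(server, node_id, prompt_id) construction would have silently reset, which matters once the context carries extra state such as inference_mode.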