Better support for process pool executors

- --panics-when=torch.cuda.OutOfMemory will now correctly panic and
  exit the worker, giving it time to report that the execution failed,
  and handles irrecoverable out-of-memory errors better
- --executor-factory=ProcessPoolExecutor will use a process instead of
  a thread to execute ComfyUI workflows when using the worker. When
  this process panics and exits, it will be correctly replaced, making
  for a more robust worker (see the sketch below)
doctorpangloss 2025-02-18 14:37:20 -08:00
parent 684d180446
commit 3ddec8ae90
14 changed files with 458 additions and 83 deletions
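
Taken together, the two changes let a workflow run in a disposable child process. A minimal sketch of the pattern, modeled on the tests added in this commit (the workflow dict is a placeholder):

from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.distributed.process_pool_executor import ProcessPoolExecutor


async def run_isolated(workflow: dict) -> dict:
    # the prompt executes in a pebble-managed child process; if it panics
    # (e.g. on an irrecoverable CUDA OOM) and exits, the pool replaces the
    # worker instead of taking down the whole server
    executor = ProcessPoolExecutor(max_workers=1)
    async with EmbeddedComfyClient(executor=executor) as client:
        return await client.queue_prompt(workflow)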

View File

@@ -1,9 +1,9 @@
 import itertools
 from typing import Sequence, Mapping, Dict

-from .cmd.execution import nodes
 from .graph import DynamicPrompt
 from .graph_utils import is_link
+from .nodes_context import get_nodes

 NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
@@ -11,7 +11,7 @@ NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
 def include_unique_id_in_input(class_type: str) -> bool:
     if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
         return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
-    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+    class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
     NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
     return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
@@ -110,7 +110,7 @@ class CacheKeySetInputSignature(CacheKeySet):
             return [float("NaN")]
         node = dynprompt.get_node(node_id)
         class_type = node["class_type"]
-        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+        class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
        signature = [class_type, self.is_changed_cache.get(node_id)]
        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
            signature.append(node_id)

View File

@@ -14,13 +14,13 @@ from contextlib import nullcontext
 from os import PathLike
 from typing import List, Optional, Tuple

-import lazy_object_proxy
 import torch
 from opentelemetry.trace import get_current_span, StatusCode, Status

 from .main_pre import tracer
 from .. import interruption
 from .. import model_management
+from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
 from ..cli_args import args
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
@@ -28,22 +28,24 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
     RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
     HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
 from ..component_model.files import canonicalize_path
+from ..component_model.module_property import create_module_properties
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..execution_context import context_execute_node, context_execute_prompt
 from ..execution_ext import should_panic_on_exception
-from ..nodes.package import import_all_nodes_in_workspace
-from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
-
-# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
-# various functions that are declared here. It should have been a context in the first place.
-nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
-
 # order matters
 from ..graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
 from ..graph_utils import is_link, GraphBuilder
-from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
+from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
+from ..nodes_context import get_nodes
 from ..validation import validate_node_input

+_module_properties = create_module_properties()
+
+
+@_module_properties.getter
+def _nodes():
+    return get_nodes()
+

 class IsChangedCache:
     def __init__(self, dynprompt, outputs_cache):
@@ -57,7 +59,7 @@ class IsChangedCache:
         node = self.dynprompt.get_node(node_id)
         class_type = node["class_type"]
-        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+        class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
         if not hasattr(class_def, "IS_CHANGED"):
             self.is_changed[node_id] = False
             return self.is_changed[node_id]
@@ -325,7 +327,7 @@ def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data,
     parent_node_id = dynprompt.get_parent_node_id(unique_id)
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
-    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+    class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
     if caches.outputs.get(unique_id) is not None:
         if server.client_id is not None:
             cached_output = caches.ui.get(unique_id) or {}
@@ -433,7 +435,7 @@ def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data,
             dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
             # Figure out if the newly created node is an output node
             class_type = node_info["class_type"]
-            class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+            class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
             if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
                 new_output_ids.append(node_id)
         for i in range(len(node_outputs)):
@@ -646,7 +648,7 @@ class PromptExecutor:
 def iterate_obj_classes(prompt: dict[str, typing.Any]) -> typing.Generator[typing.Type[CustomNode], None, None]:
     for _, node in prompt.items():
-        yield nodes.NODE_CLASS_MAPPINGS[node['class_type']]
+        yield get_nodes().NODE_CLASS_MAPPINGS[node['class_type']]


 def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
@@ -658,7 +660,7 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
     inputs = prompt[unique_id]['inputs']
     class_type = prompt[unique_id]['class_type']
-    obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
+    obj_class = get_nodes().NODE_CLASS_MAPPINGS[class_type]
     class_inputs = obj_class.INPUT_TYPES()
     valid_inputs = set(class_inputs.get('required', {})).union(set(class_inputs.get('optional', {})))
@@ -714,7 +716,7 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
                 o_id = val[0]
                 o_class_type = prompt[o_id]['class_type']
-                r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
+                r = get_nodes().NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
                 received_type = r[val[1]]
                 received_types[x] = received_type
                 any_enum = received_type == [] and (isinstance(type_input, list) or isinstance(type_input, tuple))
@@ -929,7 +931,7 @@ def _validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple
             return ValidationTuple(False, error, [], [])

         class_type = prompt[x]['class_type']
-        class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
+        class_ = get_nodes().NODE_CLASS_MAPPINGS.get(class_type, None)
         if class_ is None:
             error = {
                 "type": "invalid_prompt",

View File

@@ -2,11 +2,10 @@ import asyncio
 import itertools
 import logging
 import os
-from concurrent.futures import ProcessPoolExecutor

-from .main_pre import args
 from .extra_model_paths import load_extra_path_config
-from ..distributed.executors import ContextVarExecutor
+from .main_pre import args
+from ..distributed.executors import ContextVarExecutor, ContextVarProcessPoolExecutor


 async def main():
@@ -42,9 +41,18 @@ async def main():
         load_extra_path_config(config_path)

     from ..distributed.distributed_prompt_worker import DistributedPromptWorker
+
+    if args.executor_factory in ("ThreadPoolExecutor", "ContextVarExecutor"):
+        executor = ContextVarExecutor()
+    elif args.executor_factory in ("ProcessPoolExecutor", "ContextVarProcessPoolExecutor"):
+        executor = ContextVarProcessPoolExecutor()
+    else:
+        # default executor
+        executor = ContextVarExecutor()
+
     async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
                                        queue_name=args.distributed_queue_name,
-                                       executor=ContextVarExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)):
+                                       executor=executor):
         stop = asyncio.Event()
         try:
             await stop.wait()

View File

@@ -0,0 +1,196 @@
# Pickling support for contextvars.Context objects
# Copyright (c) 2021 Anselm Kruis
#
# This library is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# This library is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Suite 500, Boston, MA 02110-1335 USA.

'''
:mod:`cvpickle` --- make :class:`contextvars.Context` picklable

Pickling of :class:`~contextvars.Context` objects is not possible by default for two reasons, given in
https://www.python.org/dev/peps/pep-0567/#making-context-objects-picklable:

1. ContextVar objects do not have __module__ and __qualname__ attributes,
   making straightforward pickling of Context objects impossible.
2. Not all context variables refer to picklable objects. Making a ContextVar
   picklable must be an opt-in.

The module :mod:`cvpickle` provides a reducer (class :class:`ContextReducer`) for context objects.
You have to register a ContextVar with the reducer to get it pickled.

For convenience, the module provides a global :class:`ContextReducer` object in
:data:`cvpickle.global_context_reducer` and ContextVar (un-)registration functions
:func:`cvpickle.register_contextvar` and :func:`cvpickle.deregister_contextvar`

A minimal example:

>>> import cvpickle
>>> import contextvars
>>>
>>> my_context_var = contextvars.ContextVar("my_context_var")
>>> cvpickle.register_contextvar(my_context_var, __name__)
'''

import contextvars
import copyreg
import importlib
import types
from pickle import _getattribute


class _ContextVarProxy:
    def __init__(self, module_name, qualname):
        self.module_name = module_name
        self.qualname = qualname


def _context_factory(cls, mapping):
    if cls is None:
        context = contextvars.Context()
    else:
        context = cls()
    for (modulename, qualname), value in mapping.items():
        module = importlib.import_module(modulename)
        cv = _getattribute(module, qualname)[0]
        context.run(cv.set, value)
    return context


class ContextReducer:
    """A *ContextReducer* object is a "reduction" function for a :class:`~contextvars.Context` object.

    A *ContextReducer* object knows which context variables can be pickled.
    """

    def __init__(self, *, auto_register=False, factory_is_copy_context=False):
        # contextvars.ContextVar is hashable, but it is not possible to create a weak reference
        # to a ContextVar (as of Python 3.7.1). Therefore we use a regular dictionary instead of
        # weakref.WeakKeyDictionary(). That's no problem, because deleting a ContextVar leaks
        # references anyway.
        self.picklable_contextvars = {}
        #: If set to :data:`True`, call :func:`copyreg.pickle` to declare this *ContextReducer* as
        #: "reduction" function for :class:`~contextvars.Context` objects, when
        #: :meth:`register_contextvar` is called for the first time.
        self.auto_register = auto_register
        #: If set to :data:`True`, use :func:`contextvars.copy_context` to create a new
        #: :class:`~contextvars.Context` object upon unpickling. This way the unpickled
        #: context variables are added to the existing context variables.
        self.factory_is_copy_context = factory_is_copy_context

    def __call__(self, context):
        """Reduce a contextvars.Context object"""
        if not isinstance(context, contextvars.Context):
            raise TypeError('Argument must be a Context object not {}'.format(type(context).__name__))
        cvars = {}
        for cv, value in context.items():
            mod_and_name = self.picklable_contextvars.get(cv)
            if mod_and_name is not None:
                cvars[mod_and_name] = value
        if self.factory_is_copy_context:
            cls = contextvars.copy_context
        else:
            cls = type(context)
        if cls is contextvars.Context:
            # class contextvars.Context can't be pickled, because its __module__ is 'builtins' (Python 3.7.5)
            cls = None
        return _context_factory, (cls, cvars)

    def register_contextvar(self, contextvar, module, qualname=None, *, validate=True):
        """Register *contextvar* with this :class:`ContextReducer`

        Declare that the context variable *contextvar* can be pickled.

        :param contextvar: a context variable
        :type contextvar: :class:`~contextvars.ContextVar`
        :param module: the module object or the module name, where *contextvar* is declared
        :type module: :class:`~types.ModuleType` or :class:`str`
        :param qualname: the qualified name of *contextvar* in *module*. If unset, *contextvar.name* is used.
        :type qualname: :class:`str`
        :param validate: if true, check that *contextvar* can be accessed as *module.qualname*.
        :type validate: :class:`bool`
        :raises TypeError: if *contextvar* is not an instance of :class:`~contextvars.ContextVar`
        :raises ValueError: if *contextvar* is not *module.qualname*.
        """
        if not isinstance(contextvar, contextvars.ContextVar):
            raise TypeError('Argument 1 must be a ContextVar object not {}'.format(type(contextvar).__name__))
        modulename = module
        is_module = isinstance(module, types.ModuleType)
        if is_module:
            modulename = module.__name__
        if qualname is None:
            qualname = contextvar.name
        if validate:
            if not is_module:
                module = importlib.import_module(modulename)
            v = _getattribute(module, qualname)[0]  # raises AttributeError
            if v is not contextvar:
                raise ValueError('Not the same object: ContextVar {} and global {}.{}'.format(contextvar.name, modulename, qualname))
        self.picklable_contextvars[contextvar] = (modulename, qualname)
        if self.auto_register:
            self.auto_register = False
            copyreg.pickle(contextvars.Context, self)
            # in case of stackless python enable context pickling
            try:
                from stackless import PICKLEFLAGS_PICKLE_CONTEXT, pickle_flags, pickle_flags_default
            except ImportError:
                pass
            else:
                pickle_flags(PICKLEFLAGS_PICKLE_CONTEXT, PICKLEFLAGS_PICKLE_CONTEXT)
                pickle_flags_default(PICKLEFLAGS_PICKLE_CONTEXT, PICKLEFLAGS_PICKLE_CONTEXT)

    def deregister_contextvar(self, contextvar):
        """Deregister *contextvar* from this :class:`ContextReducer`

        Declare that the context variable *contextvar* can't be pickled.

        :param contextvar: a context variable
        :type contextvar: :class:`~contextvars.ContextVar`
        :raises KeyError: if *contextvar* hasn't been registered.
        """
        del self.picklable_contextvars[contextvar]


#: A global :class:`ContextReducer` object.
#:
#: The attributes are set as follows
#:
#: * :attr:`~ContextReducer.auto_register`: :data:`True`
#: * :attr:`~ContextReducer.factory_is_copy_context`: :data:`True`
#:
#: :meta hide-value:
#:
global_context_reducer = ContextReducer(auto_register=True, factory_is_copy_context=True)


def register_contextvar(contextvar, module, qualname=None, *, validate=True):
    """Register *contextvar* with :data:`global_context_reducer`

    See :meth:`ContextReducer.register_contextvar`.
    """
    return global_context_reducer.register_contextvar(contextvar, module, qualname, validate=validate)


def deregister_contextvar(contextvar):
    """Deregister *contextvar* from :data:`global_context_reducer`

    See :meth:`ContextReducer.deregister_contextvar`.
    """
    return global_context_reducer.deregister_contextvar(contextvar)
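
A hedged round-trip sketch of what this reducer enables, assuming the module is importable as comfy.component_model.cvpickle as elsewhere in this commit (the variable names are illustrative):

import contextvars
import pickle

from comfy.component_model import cvpickle

flavor = contextvars.ContextVar("flavor", default="plain")
cvpickle.register_contextvar(flavor, __name__)  # opt the var in; auto-registers the reducer

flavor.set("chocolate")
payload = pickle.dumps(contextvars.copy_context())  # only registered vars are serialized

restored: contextvars.Context = pickle.loads(payload)
assert restored.run(flavor.get) == "chocolate"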

View File

@ -1,7 +1,4 @@
import asyncio
import concurrent
import contextvars import contextvars
import threading
import typing import typing
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from functools import partial from functools import partial
@@ -20,7 +17,4 @@ class ContextVarExecutor(ThreadPoolExecutor):


 class ContextVarProcessPoolExecutor(ProcessPoolExecutor):
-    def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
-        # TODO: serialize the "comfyui_execution_context"
-        pass
+    pass

View File

@@ -1,15 +1,28 @@
 import concurrent.futures
-from typing import Callable
+import contextvars
+import multiprocessing
+import pickle
+from functools import partial
+from typing import Callable, Any

 from pebble import ProcessPool, ProcessFuture

-from ..component_model.executor_types import Executor, ExecutePromptArgs
+from ..component_model.executor_types import Executor
+
+
+def _wrap_with_context(context_data: bytes, func: Callable, *args, **kwargs) -> Any:
+    new_ctx: contextvars.Context = pickle.loads(context_data)
+    return new_ctx.run(func, *args, **kwargs)


 class ProcessPoolExecutor(ProcessPool, Executor):
-    def __init__(self, max_workers: int = 1):
-        super().__init__(max_workers=1)
+    def __init__(self,
+                 max_workers: int = 1,
+                 max_tasks: int = 0,
+                 initializer: Callable = None,
+                 initargs: list | tuple = (),
+                 context: multiprocessing.context.BaseContext = multiprocessing):
+        super().__init__(max_workers=max_workers, max_tasks=max_tasks, initializer=initializer, initargs=initargs, context=context)

     def shutdown(self, wait=True, *, cancel_futures=False):
         if cancel_futures:
@@ -21,19 +34,16 @@ class ProcessPoolExecutor(ProcessPool, Executor):
             return

     def schedule(self, function: Callable,
-                 args: list = (),
+                 args: list | tuple = (),
                  kwargs=None,
                  timeout: float = None) -> ProcessFuture:
-        # todo: restart worker when there is insufficient VRAM or the workflows are sufficiently different
-        # try:
-        #     args: ExecutePromptArgs
-        #     prompt, prompt_id, client_id, span_context, progress_handler, configuration = args
-        #
-        # except ValueError:
-        #     pass
         if kwargs is None:
             kwargs = {}
-        return super().schedule(function, args, kwargs, timeout)
+
+        context_bin = pickle.dumps(contextvars.copy_context())
+        unpack_context_then_run_function = partial(_wrap_with_context, context_bin, function)
+        return super().schedule(unpack_context_then_run_function, args=args, kwargs=kwargs, timeout=timeout)

     def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
         return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None)

View File

@@ -5,36 +5,41 @@ from contextvars import ContextVar
 from dataclasses import dataclass, replace
 from typing import Optional, Final

+from .component_model import cvpickle
 from .component_model.executor_types import ExecutorToClientProgress
 from .component_model.folder_path_types import FolderNames
 from .distributed.server_stub import ServerStub
+from .nodes.package_typing import ExportedNodes, exported_nodes_view

-_current_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
+comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
+# enables context var propagation across process boundaries for process pool executors
+cvpickle.register_contextvar(comfyui_execution_context, __name__)


 @dataclass(frozen=True)
 class ExecutionContext:
     server: ExecutorToClientProgress
     folder_names_and_paths: FolderNames
+    custom_nodes: ExportedNodes
     node_id: Optional[str] = None
     task_id: Optional[str] = None
     inference_mode: bool = True


-_current_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True)))
+comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes()))


 def current_execution_context() -> ExecutionContext:
-    return _current_context.get()
+    return comfyui_execution_context.get()


 @contextmanager
 def _new_execution_context(ctx: ExecutionContext):
-    token = _current_context.set(ctx)
+    token = comfyui_execution_context.set(ctx)
     try:
         yield ctx
     finally:
-        _current_context.reset(token)
+        comfyui_execution_context.reset(token)


 @contextmanager
@@ -59,3 +64,24 @@ def context_execute_node(node_id: str):
         new_ctx = replace(current_ctx, node_id=node_id)
     with _new_execution_context(new_ctx):
         yield new_ctx
+
+
+@contextmanager
+def context_add_custom_nodes(exported_nodes: ExportedNodes):
+    """
+    Adds custom nodes to the execution context
+    :param exported_nodes: an object that represents a gathering of custom node export symbols
+    :return: a context
+    """
+    current_ctx = current_execution_context()
+    if len(exported_nodes) == 0:
+        yield current_ctx
+        return
+    if len(current_ctx.custom_nodes) == 0:
+        merged_custom_nodes = exported_nodes
+    else:
+        merged_custom_nodes = exported_nodes_view(current_ctx.custom_nodes, exported_nodes)
+    new_ctx = replace(current_ctx, custom_nodes=merged_custom_nodes)
+    with _new_execution_context(new_ctx):
+        yield new_ctx
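
For illustration, a hedged sketch of how context_add_custom_nodes is meant to be called (MyNode is a hypothetical stand-in for a real custom node class):

from comfy.execution_context import context_add_custom_nodes
from comfy.nodes.package_typing import ExportedNodes


class MyNode:  # hypothetical stand-in for a real custom node class
    ...


with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS={"MyNode": MyNode})):
    # within this block, get_nodes() resolves "MyNode" alongside the workspace nodes
    pass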

View File

@ -1,9 +1,9 @@
from typing import Optional from typing import Optional
from .cmd.execution import nodes
from .component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \ from .component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
DependencyExecutionErrorMessage DependencyExecutionErrorMessage
from .graph_utils import is_link from .graph_utils import is_link
from .nodes_context import get_nodes
class DynamicPrompt: class DynamicPrompt:
@@ -82,7 +82,7 @@ class TopologicalSort:
     def get_input_info(self, unique_id, input_name):
         class_type = self.dynprompt.get_node(unique_id)["class_type"]
-        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+        class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
         return get_input_info(class_def, input_name)

     def make_input_strong_link(self, to_node_id, to_input):
@@ -197,7 +197,7 @@ class ExecutionList(TopologicalSort):
         # Some other heuristics could probably be used here to improve the UX further.
         def is_output(node_id):
             class_type = self.dynprompt.get_node(node_id)["class_type"]
-            class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+            class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
             if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
                 return True
             return False

View File

@ -1,8 +1,9 @@
from __future__ import annotations from __future__ import annotations
from collections import ChainMap
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \ from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
Callable, List, Type, runtime_checkable Callable, List, Type, MutableMapping
from typing_extensions import TypedDict, NotRequired from typing_extensions import TypedDict, NotRequired
@@ -159,9 +160,9 @@ class CustomNode(Protocol):
 @dataclass
 class ExportedNodes:
-    NODE_CLASS_MAPPINGS: Dict[str, CustomNode] = field(default_factory=dict)
-    NODE_DISPLAY_NAME_MAPPINGS: Dict[str, str] = field(default_factory=dict)
-    EXTENSION_WEB_DIRS: Dict[str, str] = field(default_factory=dict)
+    NODE_CLASS_MAPPINGS: MutableMapping[str, CustomNode] = field(default_factory=dict)
+    NODE_DISPLAY_NAME_MAPPINGS: MutableMapping[str, str] = field(default_factory=dict)
+    EXTENSION_WEB_DIRS: MutableMapping[str, str] = field(default_factory=dict)

     def update(self, exported_nodes: ExportedNodes) -> ExportedNodes:
         self.NODE_CLASS_MAPPINGS.update(exported_nodes.NODE_CLASS_MAPPINGS)
@@ -187,3 +188,28 @@ class ExportedNodes:
     def __add__(self, other):
         exported_nodes = ExportedNodes().update(self)
         return exported_nodes.update(other)
+
+
+class _ExportedNodesAsChainMap(ExportedNodes):
+    NODE_CLASS_MAPPINGS: ChainMap[str, CustomNode] = field(default_factory=ChainMap)
+    NODE_DISPLAY_NAME_MAPPINGS: ChainMap[str, str] = field(default_factory=ChainMap)
+    EXTENSION_WEB_DIRS: ChainMap[str, str] = field(default_factory=ChainMap)
+
+    @classmethod
+    def from_iter(cls, *exported_nodes: ExportedNodes):
+        en = _ExportedNodesAsChainMap()
+        en.NODE_CLASS_MAPPINGS = ChainMap(*[ncm.NODE_CLASS_MAPPINGS for ncm in exported_nodes])
+        en.NODE_DISPLAY_NAME_MAPPINGS = ChainMap(*[ncm.NODE_DISPLAY_NAME_MAPPINGS for ncm in exported_nodes])
+        en.EXTENSION_WEB_DIRS = ChainMap(*[ncm.EXTENSION_WEB_DIRS for ncm in exported_nodes])
+        return en
+
+    def update(self, exported_nodes: ExportedNodes) -> ExportedNodes:
+        self.NODE_CLASS_MAPPINGS = self.NODE_CLASS_MAPPINGS.new_child(exported_nodes.NODE_CLASS_MAPPINGS)
+        self.NODE_DISPLAY_NAME_MAPPINGS = self.NODE_DISPLAY_NAME_MAPPINGS.new_child(exported_nodes.NODE_DISPLAY_NAME_MAPPINGS)
+        self.EXTENSION_WEB_DIRS = self.EXTENSION_WEB_DIRS.new_child(exported_nodes.EXTENSION_WEB_DIRS)
+        return self
+
+
+def exported_nodes_view(*exported_nodes: ExportedNodes) -> ExportedNodes:
+    """Gets a view of all the provided exported nodes, concatenating them together using a ChainMap internally"""
+    return _ExportedNodesAsChainMap.from_iter(*exported_nodes)
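
Because exported_nodes_view is backed by ChainMap, the view is cheap and non-destructive; a quick stdlib illustration of the semantics it relies on (the dict contents are illustrative):

from collections import ChainMap

base = {"KSampler": "base node class"}
extra = {"MyNode": "custom node class"}

view = ChainMap(base, extra)  # lookups consult `base` first, then `extra`
assert "KSampler" in view and "MyNode" in view

layered = view.new_child({"Another": "later registration"})  # what update() does
assert "Another" in layered
assert "Another" not in base  # the underlying mappings are untouched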

comfy/nodes_context.py Normal file
View File

@@ -0,0 +1,15 @@
# todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive
import lazy_object_proxy

from comfy.execution_context import current_execution_context
from comfy.nodes.package import import_all_nodes_in_workspace
from comfy.nodes.package_typing import ExportedNodes, exported_nodes_view

nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)


def get_nodes() -> ExportedNodes:
    current_ctx = current_execution_context()
    if len(current_ctx.custom_nodes) == 0:
        return nodes
    return exported_nodes_view(nodes, current_ctx.custom_nodes)

View File

@@ -13,6 +13,7 @@ from comfy.client.embedded_comfy_client import EmbeddedComfyClient
 from comfy.component_model.executor_types import SendSyncEvent, SendSyncData, ExecutingMessage, ExecutionErrorMessage, \
     DependencyCycleError
 from comfy.distributed.server_stub import ServerStub
+from comfy.execution_context import context_add_custom_nodes
 from comfy.graph_utils import GraphBuilder, Node
 from comfy.nodes.package_typing import ExportedNodes
@@ -108,16 +109,15 @@ class TestExecution:
         (100,),
     ])
     async def client(self, request) -> ComfyClient:
-        from comfy.cmd.execution import nodes
         from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

-        nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS))
         lru_size, = request.param
         configuration = default_configuration()
         configuration.cache_lru = lru_size
         progress_handler = _ProgressHandler()
-        async with EmbeddedComfyClient(configuration, progress_handler=progress_handler) as embedded_client:
-            yield ComfyClient(embedded_client, progress_handler)
+        with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS)):
+            async with EmbeddedComfyClient(configuration, progress_handler=progress_handler) as embedded_client:
+                yield ComfyClient(embedded_client, progress_handler)

     @fixture
     def builder(self, request):

View File

@@ -7,9 +7,11 @@ import torch
 from comfy.cli_args_types import Configuration
 from comfy.client.embedded_comfy_client import EmbeddedComfyClient
-from comfy.cmd.execution import nodes
 from comfy.component_model.make_mutable import make_mutable
 from comfy.component_model.tensor_types import RGBImageBatch
+from comfy.distributed.executors import ContextVarExecutor
+from comfy.distributed.process_pool_executor import ProcessPoolExecutor
+from comfy.execution_context import context_add_custom_nodes
 from comfy.nodes.package_typing import CustomNode, ExportedNodes
@@ -80,7 +82,7 @@ async def test_separate_thread_callback():

 # Custom test exception that we'll configure to panic on
-class TestUnrecoverableError(Exception):
+class UnrecoverableError(Exception):
     pass
@@ -102,7 +104,7 @@ class TestExceptionNode(CustomNode):
     def raise_exception(self, should_raise=True) -> tuple[RGBImageBatch]:
         if should_raise:
-            raise TestUnrecoverableError("Test exception from node")
+            raise UnrecoverableError("Test exception from node")
         else:
             # Return a dummy image if not raising
             return (torch.zeros([1, 64, 64, 3]),)
@@ -117,6 +119,11 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
     "TestExceptionNode": "Test Exception Node",
 }

+EXECUTOR_FACTORIES = [
+    (ContextVarExecutor, {"max_workers": 1}),
+    (ProcessPoolExecutor, {"max_workers": 1}),
+]
+

 def create_failing_workflow():
     """Create a workflow that uses our test node to raise an exception"""
@@ -131,52 +138,113 @@ def create_failing_workflow():
 @pytest.mark.asyncio
-async def test_panic_on_exception():
-    # Set up the test nodes
-    nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS))
-
+@pytest.mark.parametrize("executor_cls,executor_kwargs", EXECUTOR_FACTORIES)
+async def test_panic_on_exception_with_executor(executor_cls, executor_kwargs):
+    """Test panic behavior with different executor types"""
     # Create configuration with our test exception in panic_when
     config = Configuration()
-    config.panic_when = [f"{__name__}.TestUnrecoverableError"]
+    config.panic_when = [f"{__name__}.UnrecoverableError"]
+
+    # Initialize the specific executor
+    executor = executor_cls(**executor_kwargs)
+    sys_exit_called = False

     # Mock sys.exit to prevent actual exit and verify it's called
-    with patch('sys.exit') as mock_exit:
+    with (context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS,
+                                                 NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)),
+          patch('sys.exit') as mock_exit):
         try:
-            async with EmbeddedComfyClient(configuration=config) as client:
+            async with EmbeddedComfyClient(configuration=config, executor=executor) as client:
                 # Queue our failing workflow
                 await client.queue_prompt(create_failing_workflow())
-        except TestUnrecoverableError:
+        except SystemExit:
+            sys_exit_called = True
+        except UnrecoverableError:
             # We expect the exception to be raised here
-            pass
+            sys_exit_called = False

         # Give the event loop a chance to process the exit callback
         await asyncio.sleep(0)

         # Verify sys.exit was called with code 1
-        mock_exit.assert_called_once_with(1)
+        if executor_cls == ProcessPoolExecutor:
+            assert sys_exit_called
+        else:
+            mock_exit.assert_called_once_with(1)


 @pytest.mark.asyncio
-async def test_no_panic_when_disabled():
-    """Verify that the same exception doesn't trigger exit when not in panic_when"""
-    # Set up the test nodes
-    nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS))
-
+@pytest.mark.parametrize("executor_cls,executor_kwargs", EXECUTOR_FACTORIES)
+async def test_no_panic_when_disabled_with_executor(executor_cls, executor_kwargs):
+    """Test no-panic behavior with different executor types"""
     # Create configuration without the exception in panic_when
     config = Configuration()

+    # Initialize the specific executor
+    executor = executor_cls(**executor_kwargs)
+    sys_exit_called = False
+
     # Mock sys.exit to verify it's not called
-    with patch('sys.exit') as mock_exit:
+    with (context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS,
+                                                 NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)),
+          patch('sys.exit') as mock_exit):
         try:
-            async with EmbeddedComfyClient(configuration=config) as client:
+            async with EmbeddedComfyClient(configuration=config, executor=executor) as client:
                 # Queue our failing workflow
                 await client.queue_prompt(create_failing_workflow())
-        except TestUnrecoverableError:
+        except SystemExit:
+            sys_exit_called = True
+        except UnrecoverableError:
             # We expect the exception to be raised here
-            pass
+            sys_exit_called = False

         # Give the event loop a chance to process any callbacks
-        await asyncio.sleep(0.1)
+        await asyncio.sleep(0)

         # Verify sys.exit was not called
         mock_exit.assert_not_called()
+        assert not sys_exit_called
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("executor_cls,executor_kwargs", EXECUTOR_FACTORIES)
+async def test_executor_cleanup(executor_cls, executor_kwargs):
+    """Test that executors are properly cleaned up after use"""
+    executor = executor_cls(**executor_kwargs)
+
+    with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS,
+                                                NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)):
+        async with EmbeddedComfyClient(executor=executor) as client:
+            # Create a simple workflow that doesn't raise
+            workflow = create_failing_workflow()
+            workflow["1"]["inputs"]["should_raise"] = False
+
+            # Run it
+            result = await client.queue_prompt(workflow)
+            assert isinstance(result, dict), "Expected workflow to return results"
+
+
+# Add a test for parallel execution to verify multi-worker behavior
+@pytest.mark.asyncio
+@pytest.mark.parametrize("executor_cls,executor_kwargs", [
+    (ContextVarExecutor, {"max_workers": 2}),
+    (ProcessPoolExecutor, {"max_workers": 2}),
+])
+async def test_parallel_execution(executor_cls, executor_kwargs):
+    """Test that executors can handle multiple workflows in parallel"""
+    executor = executor_cls(**executor_kwargs)
+
+    with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS,
+                                                NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)):
+        async with EmbeddedComfyClient(executor=executor) as client:
+            # Create multiple non-failing workflows
+            workflow = create_failing_workflow()
+            workflow["1"]["inputs"]["should_raise"] = False
+
+            # Run multiple workflows concurrently
+            results = await asyncio.gather(*[
+                client.queue_prompt(workflow)
+                for _ in range(3)
+            ])
+
+            assert len(results) == 3, "Expected all workflows to complete"
+            assert all(isinstance(r, dict) for r in results), "Expected all workflows to return results"

View File

@@ -0,0 +1,29 @@
import contextvars

import pytest

from comfy.component_model import cvpickle
from comfy.distributed.process_pool_executor import ProcessPoolExecutor

# Example context variable
my_var = contextvars.ContextVar('my_var', default=None)
cvpickle.register_contextvar(my_var, module=__name__)


def worker_function():
    """Function that runs in worker process and accesses context"""
    return my_var.get()


@pytest.mark.asyncio
async def test_context_preservation():
    # Set context in parent
    my_var.set("test_value")

    # Create pool and submit work
    with ProcessPoolExecutor(max_workers=1) as executor:
        future = executor.submit(worker_function)
        result = future.result()

    # Verify context was preserved
    assert result == "test_value"

View File

@@ -5,7 +5,8 @@ import pytest
 from pytest_mock import MockerFixture

 from comfy.cli_args import args
-from comfy.cmd.execution import validate_prompt, nodes
+from comfy.cmd.execution import validate_prompt
+from comfy.nodes_context import nodes

 valid_prompt: Final[dict] = {
     "1": {