Better support for process pool executors

- --panics-when=torch.cuda.OutOfMemoryError now correctly panics and
  exits the worker, giving it time to reply that the execution failed,
  and deals better with otherwise irrecoverable out-of-memory errors
- --executor-factory=ProcessPoolExecutor uses a process instead of a
  thread to execute ComfyUI workflows in the worker. When this process
  panics and exits, it is correctly replaced, making the worker more
  robust
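
For the embedded client, the same pair of behaviors can be exercised
directly. A minimal sketch based on the APIs visible in the diff below
(Configuration.panic_when, ProcessPoolExecutor, EmbeddedComfyClient);
the exact exception string is an assumption:

    from comfy.cli_args_types import Configuration
    from comfy.client.embedded_comfy_client import EmbeddedComfyClient
    from comfy.distributed.process_pool_executor import ProcessPoolExecutor

    config = Configuration()
    # fully qualified names of exceptions that should kill the worker process
    config.panic_when = ["torch.cuda.OutOfMemoryError"]

    async def run(workflow: dict) -> dict:
        executor = ProcessPoolExecutor(max_workers=1)
        async with EmbeddedComfyClient(configuration=config, executor=executor) as client:
            return await client.queue_prompt(workflow)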
doctorpangloss 2025-02-18 14:37:20 -08:00
parent 684d180446
commit 3ddec8ae90
14 changed files with 458 additions and 83 deletions

View File

@@ -1,9 +1,9 @@
import itertools
from typing import Sequence, Mapping, Dict
from .cmd.execution import nodes
from .graph import DynamicPrompt
from .graph_utils import is_link
from .nodes_context import get_nodes
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
@@ -11,7 +11,7 @@ NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
def include_unique_id_in_input(class_type: str) -> bool:
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
@@ -110,7 +110,7 @@ class CacheKeySetInputSignature(CacheKeySet):
return [float("NaN")]
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)

View File

@@ -14,13 +14,13 @@ from contextlib import nullcontext
from os import PathLike
from typing import List, Optional, Tuple
import lazy_object_proxy
import torch
from opentelemetry.trace import get_current_span, StatusCode, Status
from .main_pre import tracer
from .. import interruption
from .. import model_management
from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from ..cli_args import args
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
@@ -28,22 +28,24 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
from ..component_model.files import canonicalize_path
from ..component_model.module_property import create_module_properties
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..execution_context import context_execute_node, context_execute_prompt
from ..execution_ext import should_panic_on_exception
from ..nodes.package import import_all_nodes_in_workspace
from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
# various functions that are declared here. It should have been a context in the first place.
nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
# order matters
from ..graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from ..graph_utils import is_link, GraphBuilder
from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
from ..nodes_context import get_nodes
from ..validation import validate_node_input
_module_properties = create_module_properties()
@_module_properties.getter
def _nodes():
return get_nodes()
class IsChangedCache:
def __init__(self, dynprompt, outputs_cache):
@@ -57,7 +59,7 @@ class IsChangedCache:
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
if not hasattr(class_def, "IS_CHANGED"):
self.is_changed[node_id] = False
return self.is_changed[node_id]
@@ -325,7 +327,7 @@ def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data,
parent_node_id = dynprompt.get_parent_node_id(unique_id)
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
if caches.outputs.get(unique_id) is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
@@ -433,7 +435,7 @@ def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data,
dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
# Figure out if the newly created node is an output node
class_type = node_info["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
new_output_ids.append(node_id)
for i in range(len(node_outputs)):
@@ -646,7 +648,7 @@ class PromptExecutor:
def iterate_obj_classes(prompt: dict[str, typing.Any]) -> typing.Generator[typing.Type[CustomNode], None, None]:
for _, node in prompt.items():
yield nodes.NODE_CLASS_MAPPINGS[node['class_type']]
yield get_nodes().NODE_CLASS_MAPPINGS[node['class_type']]
def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
@@ -658,7 +660,7 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
obj_class = get_nodes().NODE_CLASS_MAPPINGS[class_type]
class_inputs = obj_class.INPUT_TYPES()
valid_inputs = set(class_inputs.get('required', {})).union(set(class_inputs.get('optional', {})))
@@ -714,7 +716,7 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
o_id = val[0]
o_class_type = prompt[o_id]['class_type']
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
r = get_nodes().NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
received_type = r[val[1]]
received_types[x] = received_type
any_enum = received_type == [] and (isinstance(type_input, list) or isinstance(type_input, tuple))
@@ -929,7 +931,7 @@ def _validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple
return ValidationTuple(False, error, [], [])
class_type = prompt[x]['class_type']
class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
class_ = get_nodes().NODE_CLASS_MAPPINGS.get(class_type, None)
if class_ is None:
error = {
"type": "invalid_prompt",

View File

@@ -2,11 +2,10 @@ import asyncio
import itertools
import logging
import os
from concurrent.futures import ProcessPoolExecutor
from .main_pre import args
from .extra_model_paths import load_extra_path_config
from ..distributed.executors import ContextVarExecutor
from .main_pre import args
from ..distributed.executors import ContextVarExecutor, ContextVarProcessPoolExecutor
async def main():
@@ -42,9 +41,18 @@ async def main():
load_extra_path_config(config_path)
from ..distributed.distributed_prompt_worker import DistributedPromptWorker
if args.executor_factory in ("ThreadPoolExecutor", "ContextVarExecutor"):
executor = ContextVarExecutor()
elif args.executor_factory in ("ProcessPoolExecutor", "ContextVarProcessPoolExecutor"):
executor = ContextVarProcessPoolExecutor()
else:
# default executor
executor = ContextVarExecutor()
async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
queue_name=args.distributed_queue_name,
executor=ContextVarExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)):
executor=executor):
stop = asyncio.Event()
try:
await stop.wait()

View File

@@ -0,0 +1,196 @@
# Pickling support for contextvars.Context objects
# Copyright (c) 2021 Anselm Kruis
#
# This library is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# This library is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Suite 500, Boston, MA 02110-1335 USA.
'''
:mod:`cvpickle` --- make :class:`contextvars.Context` picklable
Pickling of :class:`~contextvars.Context` objects is not possible by default for two reasons, given in
https://www.python.org/dev/peps/pep-0567/#making-context-objects-picklable:
1. ContextVar objects do not have __module__ and __qualname__ attributes,
making straightforward pickling of Context objects impossible.
2. Not all context variables refer to picklable objects. Making a ContextVar
picklable must be an opt-in.
The module :mod:`cvpickle` provides a reducer (class :class:`ContextReducer`) for context objects.
You have to register a ContextVar with the reducer to get it pickled.
For convenience, the module provides a global :class:`ContextReducer` object in
:data:`cvpickle.global_context_reducer` and ContextVar (un-)registration functions
:func:`cvpickle.register_contextvar` and :func:`cvpickle.deregister_contextvar`
A minimal example:
>>> import cvpickle
>>> import contextvars
>>>
>>> my_context_var = contextvars.ContextVar("my_context_var")
>>> cvpickle.register_contextvar(my_context_var, __name__)
'''
import contextvars
import copyreg
import importlib
import types
from pickle import _getattribute
class _ContextVarProxy:
def __init__(self, module_name, qualname):
self.module_name = module_name
self.qualname = qualname
def _context_factory(cls, mapping):
if cls is None:
context = contextvars.Context()
else:
context = cls()
for (modulename, qualname), value in mapping.items():
module = importlib.import_module(modulename)
cv = _getattribute(module, qualname)[0]
context.run(cv.set, value)
return context
class ContextReducer:
"""A *ContextReducer* object is a "reduction" function for a :class:`~contextvars.Context` object.
A *ContextReducer* object knows which context variables can be pickled.
"""
def __init__(self, *, auto_register=False, factory_is_copy_context=False):
# contextvars.ContextVar is hashable, but it is not possible to create a weak reference
# to a ContextVar (as of Python 3.7.1). Therefore we use a regular dictionary instead of
# weakref.WeakKeyDictionary(). That's no problem, because deleting a ContextVar leaks
# references anyway
self.picklable_contextvars = {}
#: If set to :data:`True`, call :func:`copyreg.pickle` to declare this *ContextReducer* as
#: "reduction" function for :class:`~contextvars.Context` objects, when the
#: :meth:`register_contextvar` is called for the first time.
self.auto_register = auto_register
#: If set to :data:`True`, use :func:`contextvars.copy_context` to create a new
#: :class:`~contextvars.Context` object upon unpickling. This way the unpickled
#: context variables are added to the existing context variables.
self.factory_is_copy_context = factory_is_copy_context
def __call__(self, context):
"""Reduce a contextvars.Context object
"""
if not isinstance(context, contextvars.Context):
raise TypeError('Argument must be a Context object not {}'.format(type(context).__name__))
cvars = {}
for cv, value in context.items():
mod_and_name = self.picklable_contextvars.get(cv)
if mod_and_name is not None:
cvars[mod_and_name] = value
if self.factory_is_copy_context:
cls = contextvars.copy_context
else:
cls = type(context)
if cls is contextvars.Context:
# class contextvars.Context can't be pickled, because its __module__ is 'builtins' (Python 3.7.5)
cls = None
return _context_factory, (cls, cvars)
def register_contextvar(self, contextvar, module, qualname=None, *, validate=True):
"""Register *contextvar* with this :class:`ContextReducer`
Declares that the context variable *contextvar* can be pickled.
:param contextvar: a context variable
:type contextvar: :class:`~contextvars.ContextVar`
:param module: the module object or the module name, where *contextvar* is declared
:type module: :class:`~types.ModuleType` or :class:`str`
:param qualname: the qualified name of *contextvar* in *module*. If unset, *contextvar.name* is used.
:type qualname: :class:`str`
:param validate: if true, check that *contextvar* can be accessed as *module.qualname*.
:type validate: :class:`bool`
:raises TypeError: if *contextvar* is not an instance of :class:`~contextvars.ContextVar`
:raises ValueError: if *contextvar* is not *module.qualname*.
"""
if not isinstance(contextvar, contextvars.ContextVar):
raise TypeError('Argument 1 must be a ContextVar object not {}'.format(type(contextvar).__name__))
modulename = module
is_module = isinstance(module, types.ModuleType)
if is_module:
modulename = module.__name__
if qualname is None:
qualname = contextvar.name
if validate:
if not is_module:
module = importlib.import_module(modulename)
v = _getattribute(module, qualname)[0] # raises AttributeError
if v is not contextvar:
raise ValueError('Not the same object: ContextVar {} and global {}.{}'.format(contextvar.name, modulename, qualname))
self.picklable_contextvars[contextvar] = (modulename, qualname)
if self.auto_register:
self.auto_register = False
copyreg.pickle(contextvars.Context, self)
# in case of stackless python enable context pickling
try:
from stackless import PICKLEFLAGS_PICKLE_CONTEXT, pickle_flags, pickle_flags_default
except ImportError:
pass
else:
pickle_flags(PICKLEFLAGS_PICKLE_CONTEXT, PICKLEFLAGS_PICKLE_CONTEXT)
pickle_flags_default(PICKLEFLAGS_PICKLE_CONTEXT, PICKLEFLAGS_PICKLE_CONTEXT)
def deregister_contextvar(self, contextvar):
"""Deregister *contextvar* from this :class:`ContextReducer`
Declares that the context variable *contextvar* can't be pickled.
:param contextvar: a context variable
:type contextvar: :class:`~contextvars.ContextVar`
:raises KeyError: if *contextvar* hasn't been registered.
"""
del self.picklable_contextvars[contextvar]
#: A global :class:`ContextReducer` object.
#:
#: The attributes are set as follows
#:
#: * :attr:`~ContextReducer.auto_register`: :data:`True`
#: * :attr:`~ContextReducer.factory_is_copy_context`: :data:`True`
#:
#: :meta hide-value:
#:
global_context_reducer = ContextReducer(auto_register=True, factory_is_copy_context=True)
def register_contextvar(contextvar, module, qualname=None, *, validate=True):
"""Register *contextvar* with :data:`global_context_reducer`
See :meth:`ContextReducer.register_contextvar`.
"""
return global_context_reducer.register_contextvar(contextvar, module, qualname, validate=validate)
def deregister_contextvar(contextvar):
"""Deregister *contextvar* from :data:`global_context_reducer`
See :meth:`ContextReducer.deregister_contextvar`.
"""
return global_context_reducer.deregister_contextvar(contextvar)
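
A round-trip sketch using the global reducer (the ContextVar here is
illustrative, not part of the commit): once registered, a copied Context
survives pickling, which is what lets the process pool executor below
ship the execution context to worker processes.

    import contextvars
    import pickle

    from comfy.component_model import cvpickle

    request_id = contextvars.ContextVar("request_id")  # illustrative variable
    cvpickle.register_contextvar(request_id, __name__)

    request_id.set("abc123")
    data = pickle.dumps(contextvars.copy_context())  # reduced by global_context_reducer
    restored = pickle.loads(data)  # a fresh Context containing only registered vars
    assert restored.run(request_id.get) == "abc123"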

View File

@@ -1,7 +1,4 @@
import asyncio
import concurrent
import contextvars
import threading
import typing
from concurrent.futures import Future, ThreadPoolExecutor
from functools import partial
@@ -20,7 +17,4 @@ class ContextVarExecutor(ThreadPoolExecutor):
class ContextVarProcessPoolExecutor(ProcessPoolExecutor):
def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
# TODO: serialize the "comfyui_execution_context"
pass
pass

View File

@@ -1,15 +1,28 @@
import concurrent.futures
from typing import Callable
import contextvars
import multiprocessing
import pickle
from functools import partial
from typing import Callable, Any
from pebble import ProcessPool, ProcessFuture
from ..component_model.executor_types import Executor, ExecutePromptArgs
from ..component_model.executor_types import Executor
def _wrap_with_context(context_data: bytes, func: Callable, *args, **kwargs) -> Any:
new_ctx: contextvars.Context = pickle.loads(context_data)
return new_ctx.run(func, *args, **kwargs)
class ProcessPoolExecutor(ProcessPool, Executor):
def __init__(self, max_workers: int = 1):
super().__init__(max_workers=1)
def __init__(self,
max_workers: int = 1,
max_tasks: int = 0,
initializer: Callable = None,
initargs: list | tuple = (),
context: multiprocessing.context.BaseContext = multiprocessing):
super().__init__(max_workers=max_workers, max_tasks=max_tasks, initializer=initializer, initargs=initargs, context=context)
def shutdown(self, wait=True, *, cancel_futures=False):
if cancel_futures:
@@ -21,19 +34,16 @@ class ProcessPoolExecutor(ProcessPool, Executor):
return
def schedule(self, function: Callable,
args: list = (),
args: list | tuple = (),
kwargs=None,
timeout: float = None) -> ProcessFuture:
# todo: restart worker when there is insufficient VRAM or the workflows are sufficiently different
# try:
# args: ExecutePromptArgs
# prompt, prompt_id, client_id, span_context, progress_handler, configuration = args
#
# except ValueError:
# pass
if kwargs is None:
kwargs = {}
return super().schedule(function, args, kwargs, timeout)
context_bin = pickle.dumps(contextvars.copy_context())
unpack_context_then_run_function = partial(_wrap_with_context, context_bin, function)
return super().schedule(unpack_context_then_run_function, args=args, kwargs=kwargs, timeout=timeout)
def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None)
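
One implication of _wrap_with_context worth noting: everything submitted
crosses a process boundary, so the callable and its arguments must be
picklable (module-level functions rather than lambdas or closures), and
only ContextVars registered with cvpickle travel along. A hedged sketch
with a hypothetical module-level function:

    from comfy.distributed.process_pool_executor import ProcessPoolExecutor

    def double(x: int) -> int:  # module-level so it can be pickled
        return 2 * x

    with ProcessPoolExecutor(max_workers=1) as pool:
        assert pool.submit(double, 21).result() == 42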

View File

@@ -5,36 +5,41 @@ from contextvars import ContextVar
from dataclasses import dataclass, replace
from typing import Optional, Final
from .component_model import cvpickle
from .component_model.executor_types import ExecutorToClientProgress
from .component_model.folder_path_types import FolderNames
from .distributed.server_stub import ServerStub
from .nodes.package_typing import ExportedNodes, exported_nodes_view
_current_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
# enables context var propagation across process boundaries for process pool executors
cvpickle.register_contextvar(comfyui_execution_context, __name__)
@dataclass(frozen=True)
class ExecutionContext:
server: ExecutorToClientProgress
folder_names_and_paths: FolderNames
custom_nodes: ExportedNodes
node_id: Optional[str] = None
task_id: Optional[str] = None
inference_mode: bool = True
_current_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True)))
comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes()))
def current_execution_context() -> ExecutionContext:
return _current_context.get()
return comfyui_execution_context.get()
@contextmanager
def _new_execution_context(ctx: ExecutionContext):
token = _current_context.set(ctx)
token = comfyui_execution_context.set(ctx)
try:
yield ctx
finally:
_current_context.reset(token)
comfyui_execution_context.reset(token)
@contextmanager
@@ -59,3 +64,24 @@ def context_execute_node(node_id: str):
new_ctx = replace(current_ctx, node_id=node_id)
with _new_execution_context(new_ctx):
yield new_ctx
@contextmanager
def context_add_custom_nodes(exported_nodes: ExportedNodes):
"""
Adds custom nodes to the execution context
:param exported_nodes: an object that represents a gathering of custom node export symbols
:return: a context
"""
current_ctx = current_execution_context()
if len(exported_nodes) == 0:
yield current_ctx
return  # without this early return the generator would yield twice and @contextmanager would raise
if len(current_ctx.custom_nodes) == 0:
merged_custom_nodes = exported_nodes
else:
merged_custom_nodes = exported_nodes_view(current_ctx.custom_nodes, exported_nodes)
new_ctx = replace(current_ctx, custom_nodes=merged_custom_nodes)
with _new_execution_context(new_ctx):
yield new_ctx

View File

@@ -1,9 +1,9 @@
from typing import Optional
from .cmd.execution import nodes
from .component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
DependencyExecutionErrorMessage
from .graph_utils import is_link
from .nodes_context import get_nodes
class DynamicPrompt:
@@ -82,7 +82,7 @@ class TopologicalSort:
def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
return get_input_info(class_def, input_name)
def make_input_strong_link(self, to_node_id, to_input):
@@ -197,7 +197,7 @@ class ExecutionList(TopologicalSort):
# Some other heuristics could probably be used here to improve the UX further.
def is_output(node_id):
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
return True
return False

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
from collections import ChainMap
from dataclasses import dataclass, field
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
Callable, List, Type, runtime_checkable
Callable, List, Type, MutableMapping
from typing_extensions import TypedDict, NotRequired
@@ -159,9 +160,9 @@ class CustomNode(Protocol):
@dataclass
class ExportedNodes:
NODE_CLASS_MAPPINGS: Dict[str, CustomNode] = field(default_factory=dict)
NODE_DISPLAY_NAME_MAPPINGS: Dict[str, str] = field(default_factory=dict)
EXTENSION_WEB_DIRS: Dict[str, str] = field(default_factory=dict)
NODE_CLASS_MAPPINGS: MutableMapping[str, CustomNode] = field(default_factory=dict)
NODE_DISPLAY_NAME_MAPPINGS: MutableMapping[str, str] = field(default_factory=dict)
EXTENSION_WEB_DIRS: MutableMapping[str, str] = field(default_factory=dict)
def update(self, exported_nodes: ExportedNodes) -> ExportedNodes:
self.NODE_CLASS_MAPPINGS.update(exported_nodes.NODE_CLASS_MAPPINGS)
@@ -187,3 +188,28 @@ class ExportedNodes:
def __add__(self, other):
exported_nodes = ExportedNodes().update(self)
return exported_nodes.update(other)
class _ExportedNodesAsChainMap(ExportedNodes):
NODE_CLASS_MAPPINGS: ChainMap[str, CustomNode] = field(default_factory=ChainMap)
NODE_DISPLAY_NAME_MAPPINGS: ChainMap[str, str] = field(default_factory=ChainMap)
EXTENSION_WEB_DIRS: ChainMap[str, str] = field(default_factory=ChainMap)
@classmethod
def from_iter(cls, *exported_nodes: ExportedNodes):
en = _ExportedNodesAsChainMap()
en.NODE_CLASS_MAPPINGS = ChainMap(*[ncm.NODE_CLASS_MAPPINGS for ncm in exported_nodes])
en.NODE_DISPLAY_NAME_MAPPINGS = ChainMap(*[ncm.NODE_DISPLAY_NAME_MAPPINGS for ncm in exported_nodes])
en.EXTENSION_WEB_DIRS = ChainMap(*[ncm.EXTENSION_WEB_DIRS for ncm in exported_nodes])
return en
def update(self, exported_nodes: ExportedNodes) -> ExportedNodes:
self.NODE_CLASS_MAPPINGS = self.NODE_CLASS_MAPPINGS.new_child(exported_nodes.NODE_CLASS_MAPPINGS)
self.NODE_DISPLAY_NAME_MAPPINGS = self.NODE_DISPLAY_NAME_MAPPINGS.new_child(exported_nodes.NODE_DISPLAY_NAME_MAPPINGS)
self.EXTENSION_WEB_DIRS = self.EXTENSION_WEB_DIRS.new_child(exported_nodes.EXTENSION_WEB_DIRS)
return self
def exported_nodes_view(*exported_nodes: ExportedNodes) -> ExportedNodes:
"""Gets a view of all the provided exported nodes, concatenating them together using a ChainMap internally"""
return _ExportedNodesAsChainMap.from_iter(*exported_nodes)
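
The view is a layered, copy-free merge: the first argument's mappings
shadow later ones on key collisions, and none of the underlying dicts
are copied. A sketch with hypothetical node classes:

    from comfy.nodes.package_typing import ExportedNodes, exported_nodes_view

    class NodeA: ...  # hypothetical custom nodes
    class NodeB: ...

    base = ExportedNodes(NODE_CLASS_MAPPINGS={"NodeA": NodeA})
    extra = ExportedNodes(NODE_CLASS_MAPPINGS={"NodeB": NodeB})
    view = exported_nodes_view(extra, base)  # extra takes precedence over base
    assert set(view.NODE_CLASS_MAPPINGS) == {"NodeA", "NodeB"}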

comfy/nodes_context.py Normal file
View File

@@ -0,0 +1,15 @@
# todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive
import lazy_object_proxy
from comfy.execution_context import current_execution_context
from comfy.nodes.package import import_all_nodes_in_workspace
from comfy.nodes.package_typing import ExportedNodes, exported_nodes_view
nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
def get_nodes() -> ExportedNodes:
current_ctx = current_execution_context()
if len(current_ctx.custom_nodes) == 0:
return nodes
return exported_nodes_view(nodes, current_ctx.custom_nodes)
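
Taken together with context_add_custom_nodes above, this gives tests and
embedders a scoped way to register nodes for a single execution instead
of mutating the global registry. A sketch with a hypothetical node class:

    from comfy.execution_context import context_add_custom_nodes
    from comfy.nodes.package_typing import ExportedNodes
    from comfy.nodes_context import get_nodes

    class MyNode: ...  # hypothetical custom node

    extras = ExportedNodes(NODE_CLASS_MAPPINGS={"MyNode": MyNode})
    with context_add_custom_nodes(extras):
        assert "MyNode" in get_nodes().NODE_CLASS_MAPPINGS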

View File

@@ -13,6 +13,7 @@ from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.component_model.executor_types import SendSyncEvent, SendSyncData, ExecutingMessage, ExecutionErrorMessage, \
DependencyCycleError
from comfy.distributed.server_stub import ServerStub
from comfy.execution_context import context_add_custom_nodes
from comfy.graph_utils import GraphBuilder, Node
from comfy.nodes.package_typing import ExportedNodes
@@ -108,16 +109,15 @@ class TestExecution:
(100,),
])
async def client(self, request) -> ComfyClient:
from comfy.cmd.execution import nodes
from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS))
lru_size, = request.param
configuration = default_configuration()
configuration.cache_lru = lru_size
progress_handler = _ProgressHandler()
async with EmbeddedComfyClient(configuration, progress_handler=progress_handler) as embedded_client:
yield ComfyClient(embedded_client, progress_handler)
with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS)):
async with EmbeddedComfyClient(configuration, progress_handler=progress_handler) as embedded_client:
yield ComfyClient(embedded_client, progress_handler)
@fixture
def builder(self, request):

View File

@@ -7,9 +7,11 @@ import torch
from comfy.cli_args_types import Configuration
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.cmd.execution import nodes
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.tensor_types import RGBImageBatch
from comfy.distributed.executors import ContextVarExecutor
from comfy.distributed.process_pool_executor import ProcessPoolExecutor
from comfy.execution_context import context_add_custom_nodes
from comfy.nodes.package_typing import CustomNode, ExportedNodes
@@ -80,7 +82,7 @@ async def test_separate_thread_callback():
# Custom test exception that we'll configure to panic on
class TestUnrecoverableError(Exception):
class UnrecoverableError(Exception):
pass
@@ -102,7 +104,7 @@ class TestExceptionNode(CustomNode):
def raise_exception(self, should_raise=True) -> tuple[RGBImageBatch]:
if should_raise:
raise TestUnrecoverableError("Test exception from node")
raise UnrecoverableError("Test exception from node")
else:
# Return a dummy image if not raising
return (torch.zeros([1, 64, 64, 3]),)
@@ -117,6 +119,11 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestExceptionNode": "Test Exception Node",
}
EXECUTOR_FACTORIES = [
(ContextVarExecutor, {"max_workers": 1}),
(ProcessPoolExecutor, {"max_workers": 1}),
]
def create_failing_workflow():
"""Create a workflow that uses our test node to raise an exception"""
@@ -131,52 +138,113 @@ def create_failing_workflow():
@pytest.mark.asyncio
async def test_panic_on_exception():
# Set up the test nodes
nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS))
@pytest.mark.parametrize("executor_cls,executor_kwargs", EXECUTOR_FACTORIES)
async def test_panic_on_exception_with_executor(executor_cls, executor_kwargs):
"""Test panic behavior with different executor types"""
# Create configuration with our test exception in panic_when
config = Configuration()
config.panic_when = [f"{__name__}.TestUnrecoverableError"]
config.panic_when = [f"{__name__}.UnrecoverableError"]
# Initialize the specific executor
executor = executor_cls(**executor_kwargs)
# Mock sys.exit to prevent actual exit and verify it's called
with patch('sys.exit') as mock_exit:
with (context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS,
NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)),
patch('sys.exit') as mock_exit):
sys_exit_called = False
try:
async with EmbeddedComfyClient(configuration=config) as client:
async with EmbeddedComfyClient(configuration=config, executor=executor) as client:
# Queue our failing workflow
await client.queue_prompt(create_failing_workflow())
except TestUnrecoverableError:
except SystemExit:
sys_exit_called = True
except UnrecoverableError:
# We expect the exception to be raised here
pass
# Give the event loop a chance to process the exit callback
await asyncio.sleep(0)
# Verify sys.exit was called with code 1
mock_exit.assert_called_once_with(1)
if executor_cls == ProcessPoolExecutor:
assert sys_exit_called
else:
mock_exit.assert_called_once_with(1)
@pytest.mark.asyncio
async def test_no_panic_when_disabled():
"""Verify that the same exception doesn't trigger exit when not in panic_when"""
# Set up the test nodes
nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS))
@pytest.mark.parametrize("executor_cls,executor_kwargs", EXECUTOR_FACTORIES)
async def test_no_panic_when_disabled_with_executor(executor_cls, executor_kwargs):
"""Test no-panic behavior with different executor types"""
# Create configuration without the exception in panic_when
config = Configuration()
# Initialize the specific executor
executor = executor_cls(**executor_kwargs)
# Mock sys.exit to verify it's not called
with patch('sys.exit') as mock_exit:
with (context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS,
NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)),
patch('sys.exit') as mock_exit):
sys_exit_called = False
try:
async with EmbeddedComfyClient(configuration=config) as client:
async with EmbeddedComfyClient(configuration=config, executor=executor) as client:
# Queue our failing workflow
await client.queue_prompt(create_failing_workflow())
except TestUnrecoverableError:
except SystemExit:
sys_exit_called = True
except UnrecoverableError:
# We expect the exception to be raised here
pass
# Give the event loop a chance to process any callbacks
await asyncio.sleep(0.1)
await asyncio.sleep(0)
# Verify sys.exit was not called
mock_exit.assert_not_called()
assert not sys_exit_called
@pytest.mark.asyncio
@pytest.mark.parametrize("executor_cls,executor_kwargs", EXECUTOR_FACTORIES)
async def test_executor_cleanup(executor_cls, executor_kwargs):
"""Test that executors are properly cleaned up after use"""
executor = executor_cls(**executor_kwargs)
with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS,
NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)):
async with EmbeddedComfyClient(executor=executor) as client:
# Create a simple workflow that doesn't raise
workflow = create_failing_workflow()
workflow["1"]["inputs"]["should_raise"] = False
# Run it
result = await client.queue_prompt(workflow)
assert isinstance(result, dict), "Expected workflow to return results"
# Add a test for parallel execution to verify multi-worker behavior
@pytest.mark.asyncio
@pytest.mark.parametrize("executor_cls,executor_kwargs", [
(ContextVarExecutor, {"max_workers": 2}),
(ProcessPoolExecutor, {"max_workers": 2}),
])
async def test_parallel_execution(executor_cls, executor_kwargs):
"""Test that executors can handle multiple workflows in parallel"""
executor = executor_cls(**executor_kwargs)
with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS,
NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)):
async with EmbeddedComfyClient(executor=executor) as client:
# Create multiple non-failing workflows
workflow = create_failing_workflow()
workflow["1"]["inputs"]["should_raise"] = False
# Run multiple workflows concurrently
results = await asyncio.gather(*[
client.queue_prompt(workflow)
for _ in range(3)
])
assert len(results) == 3, "Expected all workflows to complete"
assert all(isinstance(r, dict) for r in results), "Expected all workflows to return results"

View File

@@ -0,0 +1,29 @@
import contextvars
import pytest
from comfy.component_model import cvpickle
from comfy.distributed.process_pool_executor import ProcessPoolExecutor
# Example context variable
my_var = contextvars.ContextVar('my_var', default=None)
cvpickle.register_contextvar(my_var, module=__name__)
def worker_function():
"""Function that runs in worker process and accesses context"""
return my_var.get()
@pytest.mark.asyncio
async def test_context_preservation():
# Set context in parent
my_var.set("test_value")
# Create pool and submit work
with ProcessPoolExecutor(max_workers=1) as executor:
future = executor.submit(worker_function)
result = future.result()
# Verify context was preserved
assert result == "test_value"

View File

@@ -5,7 +5,8 @@ import pytest
from pytest_mock import MockerFixture
from comfy.cli_args import args
from comfy.cmd.execution import validate_prompt, nodes
from comfy.cmd.execution import validate_prompt
from comfy.nodes_context import nodes
valid_prompt: Final[dict] = {
"1": {