diff --git a/comfy/caching.py b/comfy/caching.py index 77c3e98f0..c52f47575 100644 --- a/comfy/caching.py +++ b/comfy/caching.py @@ -1,9 +1,9 @@ import itertools from typing import Sequence, Mapping, Dict -from .cmd.execution import nodes from .graph import DynamicPrompt from .graph_utils import is_link +from .nodes_context import get_nodes NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {} @@ -11,7 +11,7 @@ NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {} def include_unique_id_in_input(class_type: str) -> bool: if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID: return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type] NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values() return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] @@ -110,7 +110,7 @@ class CacheKeySetInputSignature(CacheKeySet): return [float("NaN")] node = dynprompt.get_node(node_id) class_type = node["class_type"] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type] signature = [class_type, self.is_changed_cache.get(node_id)] if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): signature.append(node_id) diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 4e9059763..1800daffa 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -14,13 +14,13 @@ from contextlib import nullcontext from os import PathLike from typing import List, Optional, Tuple -import lazy_object_proxy import torch from opentelemetry.trace import get_current_span, StatusCode, Status from .main_pre import tracer from .. import interruption from .. import model_management +from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID from ..cli_args import args from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \ @@ -28,22 +28,24 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \ HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage from ..component_model.files import canonicalize_path +from ..component_model.module_property import create_module_properties from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus from ..execution_context import context_execute_node, context_execute_prompt from ..execution_ext import should_panic_on_exception -from ..nodes.package import import_all_nodes_in_workspace -from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode - -# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the -# various functions that are declared here. It should have been a context in the first place. -nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace) - # order matters from ..graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from ..graph_utils import is_link, GraphBuilder -from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID +from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode +from ..nodes_context import get_nodes from ..validation import validate_node_input +_module_properties = create_module_properties() + + +@_module_properties.getter +def _nodes(): + return get_nodes() + class IsChangedCache: def __init__(self, dynprompt, outputs_cache): @@ -57,7 +59,7 @@ class IsChangedCache: node = self.dynprompt.get_node(node_id) class_type = node["class_type"] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type] if not hasattr(class_def, "IS_CHANGED"): self.is_changed[node_id] = False return self.is_changed[node_id] @@ -325,7 +327,7 @@ def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, parent_node_id = dynprompt.get_parent_node_id(unique_id) inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type] if caches.outputs.get(unique_id) is not None: if server.client_id is not None: cached_output = caches.ui.get(unique_id) or {} @@ -433,7 +435,7 @@ def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id) # Figure out if the newly created node is an output node class_type = node_info["class_type"] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type] if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: new_output_ids.append(node_id) for i in range(len(node_outputs)): @@ -646,7 +648,7 @@ class PromptExecutor: def iterate_obj_classes(prompt: dict[str, typing.Any]) -> typing.Generator[typing.Type[CustomNode], None, None]: for _, node in prompt.items(): - yield nodes.NODE_CLASS_MAPPINGS[node['class_type']] + yield get_nodes().NODE_CLASS_MAPPINGS[node['class_type']] def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple: @@ -658,7 +660,7 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] - obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] + obj_class = get_nodes().NODE_CLASS_MAPPINGS[class_type] class_inputs = obj_class.INPUT_TYPES() valid_inputs = set(class_inputs.get('required', {})).union(set(class_inputs.get('optional', {}))) @@ -714,7 +716,7 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl o_id = val[0] o_class_type = prompt[o_id]['class_type'] - r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES + r = get_nodes().NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES received_type = r[val[1]] received_types[x] = received_type any_enum = received_type == [] and (isinstance(type_input, list) or isinstance(type_input, tuple)) @@ -929,7 +931,7 @@ def _validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple return ValidationTuple(False, error, [], []) class_type = prompt[x]['class_type'] - class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None) + class_ = get_nodes().NODE_CLASS_MAPPINGS.get(class_type, None) if class_ is None: error = { "type": "invalid_prompt", diff --git a/comfy/cmd/worker.py b/comfy/cmd/worker.py index 1cecdda88..2cebd747f 100644 --- a/comfy/cmd/worker.py +++ b/comfy/cmd/worker.py @@ -2,11 +2,10 @@ import asyncio import itertools import logging import os -from concurrent.futures import ProcessPoolExecutor -from .main_pre import args from .extra_model_paths import load_extra_path_config -from ..distributed.executors import ContextVarExecutor +from .main_pre import args +from ..distributed.executors import ContextVarExecutor, ContextVarProcessPoolExecutor async def main(): @@ -42,9 +41,18 @@ async def main(): load_extra_path_config(config_path) from ..distributed.distributed_prompt_worker import DistributedPromptWorker + + if args.executor_factory in ("ThreadPoolExecutor", "ContextVarExecutor"): + executor = ContextVarExecutor() + elif args.executor_factory in ("ProcessPoolExecutor", "ContextVarProcessPoolExecutor"): + executor = ContextVarProcessPoolExecutor() + else: + # default executor + executor = ContextVarExecutor() + async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri, queue_name=args.distributed_queue_name, - executor=ContextVarExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)): + executor=executor): stop = asyncio.Event() try: await stop.wait() diff --git a/comfy/component_model/cvpickle.py b/comfy/component_model/cvpickle.py new file mode 100644 index 000000000..e48a7589a --- /dev/null +++ b/comfy/component_model/cvpickle.py @@ -0,0 +1,196 @@ +# Pickling support for contextvars.Context objects +# Copyright (c) 2021 Anselm Kruis +# +# This library is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# This library is distributed in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Suite 500, Boston, MA 02110-1335 USA. + +''' +:mod:`cvpickle` --- make :class:`contextvars.Context` picklable + +Pickling of :class:`~contextvars.Context` objects is not possible by default for two reasons, given in +https://www.python.org/dev/peps/pep-0567/#making-context-objects-picklable: + + 1. ContextVar objects do not have __module__ and __qualname__ attributes, + making straightforward pickling of Context objects impossible. + 2. Not all context variables refer to picklable objects. Making a ContextVar + picklable must be an opt-in. + +The module :mod:`cvpickle` provides a reducer (class :class:`ContextReducer`) for context objects. +You have to register a ContextVar with the reducer to get it pickled. + +For convenience, the module provides a global :class:`ContextReducer` object in +:data:`cvpickle.global_context_reducer` and ContextVar (un-)registration functions +:func:`cvpickle.register_contextvar` and :func:`cvpickle.deregister_contextvar` + +A minimal example: + + >>> import cvpickle + >>> import contextvars + >>> + >>> my_context_var = contextvars.ContextVar("my_context_var") + >>> cvpickle.register_contextvar(my_context_var, __name__) + +''' + +import contextvars +import copyreg +import importlib +import types +from pickle import _getattribute + + +class _ContextVarProxy: + def __init__(self, module_name, qualname): + self.module_name = module_name + self.qualname = qualname + + +def _context_factory(cls, mapping): + if cls is None: + context = contextvars.Context() + else: + context = cls() + + for (modulename, qualname), value in mapping.items(): + module = importlib.import_module(modulename) + cv = _getattribute(module, qualname)[0] + context.run(cv.set, value) + return context + + +class ContextReducer: + """A *ContextReducer* object is a "reduction" function for a :class:`~contextvars.Context` object. + + An *ContextReducer* object knows which context variables can be pickled. + """ + + def __init__(self, *, auto_register=False, factory_is_copy_context=False): + # contextvars.ContextVar is hashable, but it is not possible to create a weak reference + # to a ContextVar (as of Python 3.7.1). Therefore we use a regular dictionary instead of + # weakref.WeakKeyDictionary(). That's no problem, because deleting a ContextVar leaks + # references anyway + self.picklable_contextvars = {} + + #: If set to :data:`True`, call :func:`copyreg.pickle` to declare this *ContextReducer* as + #: "reduction" function for :class:`~contextvars.Context` objects, when the + #: :meth:`register_contextvar` is called for the first time. + self.auto_register = auto_register + + #: If set to :data:`True`, use :func:`contextvars.copy_context` to create a new + #: :class:`~contextvars.Context` object upon unpickling. This way the unpickled + #: context variables are added to the existing context variables. + self.factory_is_copy_context = factory_is_copy_context + + def __call__(self, context): + """Reduce a contextvars.Context object + """ + if not isinstance(context, contextvars.Context): + raise TypeError('Argument must be a Context object not {}'.format(type(context).__name__)) + cvars = {} + for cv, value in context.items(): + mod_and_name = self.picklable_contextvars.get(cv) + if mod_and_name is not None: + cvars[mod_and_name] = value + + if self.factory_is_copy_context: + cls = contextvars.copy_context + else: + cls = type(context) + if cls is contextvars.Context: + # class contextvars.Context can't be pickled, because its __module__ is 'builtins' (Python 3.7.5) + cls = None + return _context_factory, (cls, cvars) + + def register_contextvar(self, contextvar, module, qualname=None, *, validate=True): + """Register *contextvar* with this :class:`ContextReducer` + + Declare, that the context variable *contextvar* can be pickled. + + :param contextvar: a context variable + :type contextvar: :class:`~contextvars.ContextVar` + :param module: the module object or the module name, where *contextvar* is declared + :type module: :class:`~types.ModuleType` or :class:`str` + :param qualname: the qualified name of *contextvar* in *module*. If unset, *contextvar.name* is used. + :type qualname: :class:`str` + :param validate: if true, check that *contextvar* can be accessed as *module.qualname*. + :type validate: :class:`boolean` + :raises TypeError: if *contextvar* is not an instance of :class:`~contextvars.ContextVar` + :raises ValueError: if *contextvar* is not *module.qualname*. + """ + if not isinstance(contextvar, contextvars.ContextVar): + raise TypeError('Argument 1 must be a ContextVar object not {}'.format(type(contextvar).__name__)) + + modulename = module + is_module = isinstance(module, types.ModuleType) + if is_module: + modulename = module.__name__ + if qualname is None: + qualname = contextvar.name + if validate: + if not is_module: + module = importlib.import_module(modulename) + v = _getattribute(module, qualname)[0] # raises AttributeError + if v is not contextvar: + raise ValueError('Not the same object: ContextVar {} and global {}.{}'.format(contextvar.name, modulename, qualname)) + self.picklable_contextvars[contextvar] = (modulename, qualname) + if self.auto_register: + self.auto_register = False + copyreg.pickle(contextvars.Context, self) + # in case of stackless python enable context pickling + try: + from stackless import PICKLEFLAGS_PICKLE_CONTEXT, pickle_flags, pickle_flags_default + except ImportError: + pass + else: + pickle_flags(PICKLEFLAGS_PICKLE_CONTEXT, PICKLEFLAGS_PICKLE_CONTEXT) + pickle_flags_default(PICKLEFLAGS_PICKLE_CONTEXT, PICKLEFLAGS_PICKLE_CONTEXT) + + def deregister_contextvar(self, contextvar): + """Deregister *contextvar* from this :class:`ContextReducer` + + Declare, that the context variable *contextvar* can't be pickled. + + :param contextvar: a context variable + :type contextvar: :class:`~contextvars.ContextVar` + :raises KeyError: if *contextvar* hasn't been registered. + """ + del self.picklable_contextvars[contextvar] + + +#: A global :class:`ContextReducer` object. +#: +#: The attributes are set as follows +#: +#: * :attr:`~ContextReducer.auto_register`: :data:`True` +#: * :attr:`~ContextReducer.factory_is_copy_context`: :data:`True` +#: +#: :meta hide-value: +#: +global_context_reducer = ContextReducer(auto_register=True, factory_is_copy_context=True) + + +def register_contextvar(contextvar, module, qualname=None, *, validate=True): + """Register *contextvar* with :data:`global_context_reducer` + + See :meth:`ContextReducer.register_contextvar`. + """ + return global_context_reducer.register_contextvar(contextvar, module, qualname, validate=validate) + + +def deregister_contextvar(contextvar): + """Deregister *contextvar* from :data:`global_context_reducer` + + See :meth:`ContextReducer.deregister_contextvar`. + """ + return global_context_reducer.deregister_contextvar(contextvar) diff --git a/comfy/distributed/executors.py b/comfy/distributed/executors.py index 97f1a4798..5275e76c0 100644 --- a/comfy/distributed/executors.py +++ b/comfy/distributed/executors.py @@ -1,7 +1,4 @@ -import asyncio -import concurrent import contextvars -import threading import typing from concurrent.futures import Future, ThreadPoolExecutor from functools import partial @@ -20,7 +17,4 @@ class ContextVarExecutor(ThreadPoolExecutor): class ContextVarProcessPoolExecutor(ProcessPoolExecutor): - - def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future: - # TODO: serialize the "comfyui_execution_context" - pass + pass diff --git a/comfy/distributed/process_pool_executor.py b/comfy/distributed/process_pool_executor.py index 504650bb2..96f86ac77 100644 --- a/comfy/distributed/process_pool_executor.py +++ b/comfy/distributed/process_pool_executor.py @@ -1,15 +1,28 @@ import concurrent.futures -from typing import Callable +import contextvars +import multiprocessing +import pickle +from functools import partial +from typing import Callable, Any from pebble import ProcessPool, ProcessFuture -from ..component_model.executor_types import Executor, ExecutePromptArgs +from ..component_model.executor_types import Executor + + +def _wrap_with_context(context_data: bytes, func: Callable, *args, **kwargs) -> Any: + new_ctx: contextvars.Context = pickle.loads(context_data) + return new_ctx.run(func, *args, **kwargs) class ProcessPoolExecutor(ProcessPool, Executor): - def __init__(self, max_workers: int = 1): - super().__init__(max_workers=1) - + def __init__(self, + max_workers: int = 1, + max_tasks: int = 0, + initializer: Callable = None, + initargs: list | tuple = (), + context: multiprocessing.context.BaseContext = multiprocessing): + super().__init__(max_workers=max_workers, max_tasks=max_tasks, initializer=initializer, initargs=initargs, context=context) def shutdown(self, wait=True, *, cancel_futures=False): if cancel_futures: @@ -21,19 +34,16 @@ class ProcessPoolExecutor(ProcessPool, Executor): return def schedule(self, function: Callable, - args: list = (), + args: list | tuple = (), kwargs=None, timeout: float = None) -> ProcessFuture: - # todo: restart worker when there is insufficient VRAM or the workflows are sufficiently different - # try: - # args: ExecutePromptArgs - # prompt, prompt_id, client_id, span_context, progress_handler, configuration = args - # - # except ValueError: - # pass if kwargs is None: kwargs = {} - return super().schedule(function, args, kwargs, timeout) + + context_bin = pickle.dumps(contextvars.copy_context()) + unpack_context_then_run_function = partial(_wrap_with_context, context_bin, function) + + return super().schedule(unpack_context_then_run_function, args=args, kwargs=kwargs, timeout=timeout) def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future: return self.schedule(fn, args=list(args), kwargs=kwargs, timeout=None) diff --git a/comfy/execution_context.py b/comfy/execution_context.py index 6bf66c2f6..4c1b2a656 100644 --- a/comfy/execution_context.py +++ b/comfy/execution_context.py @@ -5,36 +5,41 @@ from contextvars import ContextVar from dataclasses import dataclass, replace from typing import Optional, Final +from .component_model import cvpickle from .component_model.executor_types import ExecutorToClientProgress from .component_model.folder_path_types import FolderNames from .distributed.server_stub import ServerStub +from .nodes.package_typing import ExportedNodes, exported_nodes_view -_current_context: Final[ContextVar] = ContextVar("comfyui_execution_context") +comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context") +# enables context var propagation across process boundaries for process pool executors +cvpickle.register_contextvar(comfyui_execution_context, __name__) @dataclass(frozen=True) class ExecutionContext: server: ExecutorToClientProgress folder_names_and_paths: FolderNames + custom_nodes: ExportedNodes node_id: Optional[str] = None task_id: Optional[str] = None inference_mode: bool = True -_current_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True))) +comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes())) def current_execution_context() -> ExecutionContext: - return _current_context.get() + return comfyui_execution_context.get() @contextmanager def _new_execution_context(ctx: ExecutionContext): - token = _current_context.set(ctx) + token = comfyui_execution_context.set(ctx) try: yield ctx finally: - _current_context.reset(token) + comfyui_execution_context.reset(token) @contextmanager @@ -59,3 +64,24 @@ def context_execute_node(node_id: str): new_ctx = replace(current_ctx, node_id=node_id) with _new_execution_context(new_ctx): yield new_ctx + + +@contextmanager +def context_add_custom_nodes(exported_nodes: ExportedNodes): + """ + Adds custom nodes to the execution context + :param exported_nodes: an object that represents a gathering of custom node export symbols + :return: a context + """ + current_ctx = current_execution_context() + if len(exported_nodes) == 0: + yield current_ctx + + if len(current_ctx.custom_nodes) == 0: + merged_custom_nodes = exported_nodes + else: + merged_custom_nodes = exported_nodes_view(current_ctx.custom_nodes, exported_nodes) + + new_ctx = replace(current_ctx, custom_nodes=merged_custom_nodes) + with _new_execution_context(new_ctx): + yield new_ctx \ No newline at end of file diff --git a/comfy/graph.py b/comfy/graph.py index a24f134ae..e8d8c90f1 100644 --- a/comfy/graph.py +++ b/comfy/graph.py @@ -1,9 +1,9 @@ from typing import Optional -from .cmd.execution import nodes from .component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \ DependencyExecutionErrorMessage from .graph_utils import is_link +from .nodes_context import get_nodes class DynamicPrompt: @@ -82,7 +82,7 @@ class TopologicalSort: def get_input_info(self, unique_id, input_name): class_type = self.dynprompt.get_node(unique_id)["class_type"] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type] return get_input_info(class_def, input_name) def make_input_strong_link(self, to_node_id, to_input): @@ -197,7 +197,7 @@ class ExecutionList(TopologicalSort): # Some other heuristics could probably be used here to improve the UX further. def is_output(node_id): class_type = self.dynprompt.get_node(node_id)["class_type"] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type] if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: return True return False diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index d83c37830..3fbb50ab5 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -1,8 +1,9 @@ from __future__ import annotations +from collections import ChainMap from dataclasses import dataclass, field from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \ - Callable, List, Type, runtime_checkable + Callable, List, Type, MutableMapping from typing_extensions import TypedDict, NotRequired @@ -159,9 +160,9 @@ class CustomNode(Protocol): @dataclass class ExportedNodes: - NODE_CLASS_MAPPINGS: Dict[str, CustomNode] = field(default_factory=dict) - NODE_DISPLAY_NAME_MAPPINGS: Dict[str, str] = field(default_factory=dict) - EXTENSION_WEB_DIRS: Dict[str, str] = field(default_factory=dict) + NODE_CLASS_MAPPINGS: MutableMapping[str, CustomNode] = field(default_factory=dict) + NODE_DISPLAY_NAME_MAPPINGS: MutableMapping[str, str] = field(default_factory=dict) + EXTENSION_WEB_DIRS: MutableMapping[str, str] = field(default_factory=dict) def update(self, exported_nodes: ExportedNodes) -> ExportedNodes: self.NODE_CLASS_MAPPINGS.update(exported_nodes.NODE_CLASS_MAPPINGS) @@ -187,3 +188,28 @@ class ExportedNodes: def __add__(self, other): exported_nodes = ExportedNodes().update(self) return exported_nodes.update(other) + + +class _ExportedNodesAsChainMap(ExportedNodes): + NODE_CLASS_MAPPINGS: ChainMap[str, CustomNode] = field(default_factory=ChainMap) + NODE_DISPLAY_NAME_MAPPINGS: ChainMap[str, str] = field(default_factory=ChainMap) + EXTENSION_WEB_DIRS: ChainMap[str, str] = field(default_factory=ChainMap) + + @classmethod + def from_iter(cls, *exported_nodes: ExportedNodes): + en = _ExportedNodesAsChainMap() + en.NODE_CLASS_MAPPINGS = ChainMap(*[ncm.NODE_CLASS_MAPPINGS for ncm in exported_nodes]) + en.NODE_DISPLAY_NAME_MAPPINGS = ChainMap(*[ncm.NODE_DISPLAY_NAME_MAPPINGS for ncm in exported_nodes]) + en.EXTENSION_WEB_DIRS = ChainMap(*[ncm.EXTENSION_WEB_DIRS for ncm in exported_nodes]) + return en + + def update(self, exported_nodes: ExportedNodes) -> ExportedNodes: + self.NODE_CLASS_MAPPINGS = self.NODE_CLASS_MAPPINGS.new_child(exported_nodes.NODE_CLASS_MAPPINGS) + self.NODE_DISPLAY_NAME_MAPPINGS = self.NODE_DISPLAY_NAME_MAPPINGS.new_child(exported_nodes.NODE_DISPLAY_NAME_MAPPINGS) + self.EXTENSION_WEB_DIRS = self.EXTENSION_WEB_DIRS.new_child(exported_nodes.EXTENSION_WEB_DIRS) + return self + + +def exported_nodes_view(*exported_nodes: ExportedNodes) -> ExportedNodes: + """Gets a view of all the provided exported nodes, concatenating them together using a ChainMap internally""" + return _ExportedNodesAsChainMap.from_iter(*exported_nodes) diff --git a/comfy/nodes_context.py b/comfy/nodes_context.py new file mode 100644 index 000000000..f68d71133 --- /dev/null +++ b/comfy/nodes_context.py @@ -0,0 +1,15 @@ +# todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive +import lazy_object_proxy + +from comfy.execution_context import current_execution_context +from comfy.nodes.package import import_all_nodes_in_workspace +from comfy.nodes.package_typing import ExportedNodes, exported_nodes_view + +nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace) + + +def get_nodes() -> ExportedNodes: + current_ctx = current_execution_context() + if len(current_ctx.custom_nodes) == 0: + return nodes + return exported_nodes_view(nodes, current_ctx.custom_nodes) diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 29ae9c637..a33f20e62 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -13,6 +13,7 @@ from comfy.client.embedded_comfy_client import EmbeddedComfyClient from comfy.component_model.executor_types import SendSyncEvent, SendSyncData, ExecutingMessage, ExecutionErrorMessage, \ DependencyCycleError from comfy.distributed.server_stub import ServerStub +from comfy.execution_context import context_add_custom_nodes from comfy.graph_utils import GraphBuilder, Node from comfy.nodes.package_typing import ExportedNodes @@ -108,16 +109,15 @@ class TestExecution: (100,), ]) async def client(self, request) -> ComfyClient: - from comfy.cmd.execution import nodes from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS - nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS)) lru_size, = request.param configuration = default_configuration() configuration.cache_lru = lru_size progress_handler = _ProgressHandler() - async with EmbeddedComfyClient(configuration, progress_handler=progress_handler) as embedded_client: - yield ComfyClient(embedded_client, progress_handler) + with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS)): + async with EmbeddedComfyClient(configuration, progress_handler=progress_handler) as embedded_client: + yield ComfyClient(embedded_client, progress_handler) @fixture def builder(self, request): diff --git a/tests/unit/test_panics.py b/tests/unit/test_panics.py index 3966db7ec..7d195b689 100644 --- a/tests/unit/test_panics.py +++ b/tests/unit/test_panics.py @@ -7,9 +7,11 @@ import torch from comfy.cli_args_types import Configuration from comfy.client.embedded_comfy_client import EmbeddedComfyClient -from comfy.cmd.execution import nodes from comfy.component_model.make_mutable import make_mutable from comfy.component_model.tensor_types import RGBImageBatch +from comfy.distributed.executors import ContextVarExecutor +from comfy.distributed.process_pool_executor import ProcessPoolExecutor +from comfy.execution_context import context_add_custom_nodes from comfy.nodes.package_typing import CustomNode, ExportedNodes @@ -80,7 +82,7 @@ async def test_separate_thread_callback(): # Custom test exception that we'll configure to panic on -class TestUnrecoverableError(Exception): +class UnrecoverableError(Exception): pass @@ -102,7 +104,7 @@ class TestExceptionNode(CustomNode): def raise_exception(self, should_raise=True) -> tuple[RGBImageBatch]: if should_raise: - raise TestUnrecoverableError("Test exception from node") + raise UnrecoverableError("Test exception from node") else: # Return a dummy image if not raising return (torch.zeros([1, 64, 64, 3]),) @@ -117,6 +119,11 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestExceptionNode": "Test Exception Node", } +EXECUTOR_FACTORIES = [ + (ContextVarExecutor, {"max_workers": 1}), + (ProcessPoolExecutor, {"max_workers": 1}), +] + def create_failing_workflow(): """Create a workflow that uses our test node to raise an exception""" @@ -131,52 +138,113 @@ def create_failing_workflow(): @pytest.mark.asyncio -async def test_panic_on_exception(): - # Set up the test nodes - nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)) - +@pytest.mark.parametrize("executor_cls,executor_kwargs", EXECUTOR_FACTORIES) +async def test_panic_on_exception_with_executor(executor_cls, executor_kwargs): + """Test panic behavior with different executor types""" # Create configuration with our test exception in panic_when config = Configuration() - config.panic_when = [f"{__name__}.TestUnrecoverableError"] + config.panic_when = [f"{__name__}.UnrecoverableError"] + + # Initialize the specific executor + executor = executor_cls(**executor_kwargs) # Mock sys.exit to prevent actual exit and verify it's called - with patch('sys.exit') as mock_exit: + with (context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS, + NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)), + patch('sys.exit') as mock_exit): try: - async with EmbeddedComfyClient(configuration=config) as client: + async with EmbeddedComfyClient(configuration=config, executor=executor) as client: # Queue our failing workflow await client.queue_prompt(create_failing_workflow()) - except TestUnrecoverableError: + except SystemExit: + sys_exit_called = True + except UnrecoverableError: # We expect the exception to be raised here - pass + sys_exit_called = False # Give the event loop a chance to process the exit callback await asyncio.sleep(0) # Verify sys.exit was called with code 1 - mock_exit.assert_called_once_with(1) + if executor_cls == ProcessPoolExecutor: + assert sys_exit_called + else: + mock_exit.assert_called_once_with(1) @pytest.mark.asyncio -async def test_no_panic_when_disabled(): - """Verify that the same exception doesn't trigger exit when not in panic_when""" - # Set up the test nodes - nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)) +@pytest.mark.parametrize("executor_cls,executor_kwargs", EXECUTOR_FACTORIES) +async def test_no_panic_when_disabled_with_executor(executor_cls, executor_kwargs): + """Test no-panic behavior with different executor types""" # Create configuration without the exception in panic_when config = Configuration() + # Initialize the specific executor + executor = executor_cls(**executor_kwargs) + # Mock sys.exit to verify it's not called - with patch('sys.exit') as mock_exit: + with (context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS, + NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)), + patch('sys.exit') as mock_exit): try: - async with EmbeddedComfyClient(configuration=config) as client: + async with EmbeddedComfyClient(configuration=config, executor=executor) as client: # Queue our failing workflow await client.queue_prompt(create_failing_workflow()) - except TestUnrecoverableError: + except SystemExit: + sys_exit_called = True + except UnrecoverableError: # We expect the exception to be raised here - pass + sys_exit_called = False # Give the event loop a chance to process any callbacks - await asyncio.sleep(0.1) + await asyncio.sleep(0) # Verify sys.exit was not called mock_exit.assert_not_called() + assert not sys_exit_called + + +@pytest.mark.asyncio +@pytest.mark.parametrize("executor_cls,executor_kwargs", EXECUTOR_FACTORIES) +async def test_executor_cleanup(executor_cls, executor_kwargs): + """Test that executors are properly cleaned up after use""" + executor = executor_cls(**executor_kwargs) + + with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS, + NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)): + async with EmbeddedComfyClient(executor=executor) as client: + # Create a simple workflow that doesn't raise + workflow = create_failing_workflow() + workflow["1"]["inputs"]["should_raise"] = False + + # Run it + result = await client.queue_prompt(workflow) + assert isinstance(result, dict), "Expected workflow to return results" + + +# Add a test for parallel execution to verify multi-worker behavior +@pytest.mark.asyncio +@pytest.mark.parametrize("executor_cls,executor_kwargs", [ + (ContextVarExecutor, {"max_workers": 2}), + (ProcessPoolExecutor, {"max_workers": 2}), +]) +async def test_parallel_execution(executor_cls, executor_kwargs): + """Test that executors can handle multiple workflows in parallel""" + executor = executor_cls(**executor_kwargs) + + with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=TEST_NODE_CLASS_MAPPINGS, + NODE_DISPLAY_NAME_MAPPINGS=TEST_NODE_DISPLAY_NAME_MAPPINGS)): + async with EmbeddedComfyClient(executor=executor) as client: + # Create multiple non-failing workflows + workflow = create_failing_workflow() + workflow["1"]["inputs"]["should_raise"] = False + + # Run multiple workflows concurrently + results = await asyncio.gather(*[ + client.queue_prompt(workflow) + for _ in range(3) + ]) + + assert len(results) == 3, "Expected all workflows to complete" + assert all(isinstance(r, dict) for r in results), "Expected all workflows to return results" diff --git a/tests/unit/test_process_pool.py b/tests/unit/test_process_pool.py new file mode 100644 index 000000000..e7a3773e7 --- /dev/null +++ b/tests/unit/test_process_pool.py @@ -0,0 +1,29 @@ +import contextvars + +import pytest + +from comfy.component_model import cvpickle +from comfy.distributed.process_pool_executor import ProcessPoolExecutor + +# Example context variable +my_var = contextvars.ContextVar('my_var', default=None) +cvpickle.register_contextvar(my_var, module=__name__) + + +def worker_function(): + """Function that runs in worker process and accesses context""" + return my_var.get() + + +@pytest.mark.asyncio +async def test_context_preservation(): + # Set context in parent + my_var.set("test_value") + + # Create pool and submit work + with ProcessPoolExecutor(max_workers=1) as executor: + future = executor.submit(worker_function) + result = future.result() + + # Verify context was preserved + assert result == "test_value" diff --git a/tests/unit/test_validation.py b/tests/unit/test_validation.py index 1cc879dc4..536cfe314 100644 --- a/tests/unit/test_validation.py +++ b/tests/unit/test_validation.py @@ -5,7 +5,8 @@ import pytest from pytest_mock import MockerFixture from comfy.cli_args import args -from comfy.cmd.execution import validate_prompt, nodes +from comfy.cmd.execution import validate_prompt +from comfy.nodes_context import nodes valid_prompt: Final[dict] = { "1": {