Fixes to tests and configuration, making library use more durable

This commit is contained in:
doctorpangloss 2025-10-23 19:46:40 -07:00
parent 67f9d3e693
commit 058e5dc634
17 changed files with 182 additions and 84 deletions

View File

@ -9,15 +9,17 @@ from typing import Optional
import configargparse as argparse import configargparse as argparse
from . import __version__ from . import __version__
from . import options
from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, EnumAction, \ from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, EnumAction, \
EnhancedConfigArgParser, PerformanceFeature, is_valid_directory, db_config, FlattenAndAppendAction EnhancedConfigArgParser, PerformanceFeature, is_valid_directory, db_config, FlattenAndAppendAction
from .component_model.module_property import create_module_properties
# todo: move this # todo: move this
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest" DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_module_properties = create_module_properties()
def _create_parser() -> EnhancedConfigArgParser: def _create_parser() -> EnhancedConfigArgParser:
parser = EnhancedConfigArgParser(default_config_files=['config.yaml', 'config.json', 'config.cfg', 'config.ini'], parser = EnhancedConfigArgParser(default_config_files=['config.yaml', 'config.json', 'config.cfg', 'config.ini'],
@ -28,7 +30,7 @@ def _create_parser() -> EnhancedConfigArgParser:
parser.add_argument('-w', "--cwd", type=str, default=None, parser.add_argument('-w', "--cwd", type=str, default=None,
help="Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default.") help="Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default.")
parser.add_argument("--base-paths", type=str, nargs='+', default=[], help="Additional base paths for custom nodes, models and inputs.") parser.add_argument("--base-paths", type=str, nargs='+', default=[], action=FlattenAndAppendAction, help="Additional base paths for custom nodes, models and inputs.")
parser.add_argument('-H', "--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", parser.add_argument('-H', "--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::",
help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)") help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
@ -149,8 +151,8 @@ def _create_parser() -> EnhancedConfigArgParser:
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.") parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.") parser.add_argument("--whitelist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
parser.add_argument("--blacklist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.") parser.add_argument("--blacklist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.") parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
@ -208,7 +210,8 @@ def _create_parser() -> EnhancedConfigArgParser:
parser.add_argument( parser.add_argument(
'--panic-when', '--panic-when',
action='append', action=FlattenAndAppendAction,
nargs='+',
help=""" help="""
List of fully qualified exception class names to panic (sys.exit(1)) when a workflow raises it. List of fully qualified exception class names to panic (sys.exit(1)) when a workflow raises it.
Example: --panic-when=torch.cuda.OutOfMemoryError. Can be specified multiple times or as a Example: --panic-when=torch.cuda.OutOfMemoryError. Can be specified multiple times or as a
@ -270,7 +273,7 @@ def _create_parser() -> EnhancedConfigArgParser:
default_db_url = db_config() default_db_url = db_config()
parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--workflows", type=str, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.") parser.add_argument("--workflows", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")
# now give plugins a chance to add configuration # now give plugins a chance to add configuration
for entry_point in entry_points().select(group='comfyui.custom_config'): for entry_point in entry_points().select(group='comfyui.custom_config'):
@ -317,4 +320,19 @@ def default_configuration() -> Configuration:
return _parse_args(_create_parser()) return _parse_args(_create_parser())
args = _parse_args(args_parsing=options.args_parsing) def cli_args_configuration() -> Configuration:
return _parse_args(args_parsing=True)
@_module_properties.getter
def _args() -> Configuration:
from .execution_context import current_execution_context
return current_execution_context().configuration
__all__ = [
"args", # pylint: disable=undefined-all-variable, type: Configuration
"default_configuration",
"cli_args_configuration",
"DEFAULT_VERSION_STRING"
]

View File

@ -29,7 +29,7 @@ from ..distributed.executors import ContextVarExecutor
from ..distributed.history import History from ..distributed.history import History
from ..distributed.process_pool_executor import ProcessPoolExecutor from ..distributed.process_pool_executor import ProcessPoolExecutor
from ..distributed.server_stub import ServerStub from ..distributed.server_stub import ServerStub
from ..execution_context import current_execution_context from ..execution_context import current_execution_context, context_configuration
_prompt_executor = threading.local() _prompt_executor = threading.local()
@ -57,7 +57,6 @@ def _execute_prompt(
finally: finally:
detach(token) detach(token)
async def __execute_prompt( async def __execute_prompt(
prompt: dict, prompt: dict,
prompt_id: str, prompt_id: str,
@ -66,7 +65,16 @@ async def __execute_prompt(
progress_handler: ExecutorToClientProgress | None, progress_handler: ExecutorToClientProgress | None,
configuration: Configuration | None, configuration: Configuration | None,
partial_execution_targets: list[str] | None) -> dict: partial_execution_targets: list[str] | None) -> dict:
from .. import options with context_configuration(configuration):
return await ___execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, partial_execution_targets)
async def ___execute_prompt(
prompt: dict,
prompt_id: str,
client_id: str,
span_context: Context,
progress_handler: ExecutorToClientProgress | None,
partial_execution_targets: list[str] | None) -> dict:
from ..cmd.execution import PromptExecutor from ..cmd.execution import PromptExecutor
progress_handler = progress_handler or ServerStub() progress_handler = progress_handler or ServerStub()
@ -74,13 +82,6 @@ async def __execute_prompt(
try: try:
prompt_executor: PromptExecutor = _prompt_executor.executor prompt_executor: PromptExecutor = _prompt_executor.executor
except (LookupError, AttributeError): except (LookupError, AttributeError):
if configuration is None:
options.enable_args_parsing()
else:
from ..cmd.main_pre import args
args.clear()
args.update(configuration)
with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context): with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context):
# todo: deal with new caching features # todo: deal with new caching features
prompt_executor = PromptExecutor(progress_handler) prompt_executor = PromptExecutor(progress_handler)
@ -117,6 +118,7 @@ async def __execute_prompt(
def _cleanup(): def _cleanup():
from ..cmd.execution import PromptExecutor from ..cmd.execution import PromptExecutor
from ..nodes_context import invalidate
try: try:
prompt_executor: PromptExecutor = _prompt_executor.executor prompt_executor: PromptExecutor = _prompt_executor.executor
# this should clear all references to output tensors and make it easier to collect back the memory # this should clear all references to output tensors and make it easier to collect back the memory
@ -130,6 +132,10 @@ def _cleanup():
model_management.soft_empty_cache() model_management.soft_empty_cache()
except: except:
pass pass
try:
invalidate()
except:
pass
class Comfy: class Comfy:
@ -172,6 +178,8 @@ class Comfy:
self._task_count_lock = RLock() self._task_count_lock = RLock()
self._task_count = 0 self._task_count = 0
self._history = History() self._history = History()
self._context_stack = []
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
@ -183,6 +191,9 @@ class Comfy:
def __enter__(self): def __enter__(self):
self._is_running = True self._is_running = True
cm = context_configuration(self._configuration)
cm.__enter__()
self._context_stack.append(cm)
return self return self
@property @property
@ -193,9 +204,13 @@ class Comfy:
get_event_loop().run_in_executor(self._executor, _cleanup) get_event_loop().run_in_executor(self._executor, _cleanup)
self._executor.shutdown(wait=True) self._executor.shutdown(wait=True)
self._is_running = False self._is_running = False
self._context_stack.pop().__exit__(*args)
async def __aenter__(self): async def __aenter__(self):
self._is_running = True self._is_running = True
cm = context_configuration(self._configuration)
cm.__enter__()
self._context_stack.append(cm)
return self return self
async def __aexit__(self, *args): async def __aexit__(self, *args):
@ -207,6 +222,7 @@ class Comfy:
self._executor.shutdown(wait=True) self._executor.shutdown(wait=True)
self._is_running = False self._is_running = False
self._context_stack.pop().__exit__(*args)
async def queue_prompt_api(self, async def queue_prompt_api(self,
prompt: PromptDict | str | dict, prompt: PromptDict | str | dict,

View File

@ -30,9 +30,9 @@ from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.utils import CurrentNodeContext from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io from comfy_api.latest import io
from ..execution_context import current_execution_context
from .. import interruption from .. import interruption
from .. import model_management from .. import model_management
from ..cli_args import args
from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \ from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \ ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
@ -675,7 +675,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
"current_inputs": input_data_formatted "current_inputs": input_data_formatted
} }
if should_panic_on_exception(ex, args.panic_when): if should_panic_on_exception(ex, current_execution_context().configuration.panic_when):
logger.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit") logger.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
def sys_exit(*args): def sys_exit(*args):

View File

@ -148,6 +148,13 @@ def setup_database():
async def _start_comfyui(from_script_dir: Optional[Path] = None): async def _start_comfyui(from_script_dir: Optional[Path] = None):
from ..execution_context import context_configuration
from ..cli_args import cli_args_configuration
with context_configuration(cli_args_configuration()):
await __start_comfyui(from_script_dir=from_script_dir)
async def __start_comfyui(from_script_dir: Optional[Path] = None):
""" """
Runs ComfyUI's frontend and backend like upstream. Runs ComfyUI's frontend and backend like upstream.
:param from_script_dir: when set to a path, assumes that you are running ComfyUI's legacy main.py entrypoint at the root of the git repository located at the path :param from_script_dir: when set to a path, assumes that you are running ComfyUI's legacy main.py entrypoint at the root of the git repository located at the path

View File

@ -18,6 +18,7 @@ import fsspec
from .. import options from .. import options
from ..app import logger from ..app import logger
from ..cli_args_types import Configuration
from ..component_model import package_filesystem from ..component_model import package_filesystem
os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1' os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
@ -164,6 +165,7 @@ def _register_fsspec_fs():
package_filesystem.PkgResourcesFileSystem, package_filesystem.PkgResourcesFileSystem,
) )
args: Configuration
_configure_logging() _configure_logging()
_fix_pytorch_240() _fix_pytorch_240()

View File

@ -2,6 +2,17 @@ import sys
from functools import wraps from functools import wraps
def create_module_properties(): def create_module_properties():
"""
Example:
>>> _module_properties = create_module_properties()
>>> @_module_properties.getter
>>> def _nodes():
>>> return ...
This creates nodes as a property
:return:
"""
properties = {} properties = {}
patched_modules = set() patched_modules = set()

View File

@ -5,6 +5,8 @@ from contextvars import ContextVar
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from typing import Optional, Final from typing import Optional, Final
from .cli_args import cli_args_configuration
from .cli_args_types import Configuration
from .component_model import cvpickle from .component_model import cvpickle
from .component_model.executor_types import ExecutorToClientProgress from .component_model.executor_types import ExecutorToClientProgress
from .component_model.folder_path_types import FolderNames from .component_model.folder_path_types import FolderNames
@ -23,6 +25,7 @@ class ExecutionContext:
list_index: Optional[int] = None list_index: Optional[int] = None
inference_mode: bool = True inference_mode: bool = True
progress_registry: Optional[AbstractProgressRegistry] = None progress_registry: Optional[AbstractProgressRegistry] = None
configuration: Optional[Configuration] = None
def __iter__(self): def __iter__(self):
""" """
@ -34,7 +37,7 @@ class ExecutionContext:
yield self.list_index yield self.list_index
comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context", default=ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub())) comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context", default=ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub(), configuration=cli_args_configuration()))
# enables context var propagation across process boundaries for process pool executors # enables context var propagation across process boundaries for process pool executors
cvpickle.register_contextvar(comfyui_execution_context, __name__) cvpickle.register_contextvar(comfyui_execution_context, __name__)
@ -52,6 +55,17 @@ def _new_execution_context(ctx: ExecutionContext):
comfyui_execution_context.reset(token) comfyui_execution_context.reset(token)
@contextmanager
def context_configuration(configuration: Optional[Configuration] = None):
current_ctx = current_execution_context()
if configuration is None:
from .cli_args import cli_args_configuration
configuration = cli_args_configuration()
new_ctx = replace(current_ctx, configuration=configuration)
with _new_execution_context(new_ctx):
yield new_ctx
@contextmanager @contextmanager
def context_folder_names_and_paths(folder_names_and_paths: FolderNames): def context_folder_names_and_paths(folder_names_and_paths: FolderNames):
current_ctx = current_execution_context() current_ctx = current_execution_context()

View File

@ -4,10 +4,12 @@ import importlib
import logging import logging
import os import os
import pkgutil import pkgutil
import threading
import time import time
import types import types
from functools import reduce from functools import reduce
from importlib.metadata import entry_points from importlib.metadata import entry_points
from threading import RLock
from opentelemetry.trace import Span, Status, StatusCode from opentelemetry.trace import Span, Status, StatusCode
@ -17,8 +19,9 @@ from comfy_api.version_list import supported_versions
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
from .package_typing import ExportedNodes from .package_typing import ExportedNodes
from ..component_model.files import get_package_as_path from ..component_model.files import get_package_as_path
from ..execution_context import current_execution_context
_nodes_available_at_startup: ExportedNodes = ExportedNodes() _nodes_local = threading.local()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,8 +49,6 @@ def _import_nodes_in_module(module: types.ModuleType) -> ExportedNodes:
return exported_nodes return exported_nodes
def _import_and_enumerate_nodes_in_module(module: types.ModuleType, def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
print_import_times=False, print_import_times=False,
raise_on_failure=False, raise_on_failure=False,
@ -116,7 +117,11 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
@tracer.start_as_current_span("Import All Nodes In Workspace") @tracer.start_as_current_span("Import All Nodes In Workspace")
def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes: def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
# now actually import the nodes, to improve control of node loading order # now actually import the nodes, to improve control of node loading order
from ..cli_args import args try:
_nodes_available_at_startup = _nodes_local.nodes
except (LookupError, AttributeError):
_nodes_available_at_startup = _nodes_local.nodes = ExportedNodes()
args = current_execution_context().configuration
# todo: this is some truly braindead stuff # todo: this is some truly braindead stuff
register_versions([ register_versions([
@ -126,48 +131,47 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
) for v in supported_versions ) for v in supported_versions
]) ])
# only load these nodes once _nodes_available_at_startup.clear()
if len(_nodes_available_at_startup) == 0:
# import base_nodes first # import base_nodes first
from . import base_nodes from . import base_nodes
from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
base_and_extra = reduce(lambda x, y: x.update(y), base_and_extra = reduce(lambda x, y: x.update(y),
map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [ map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [
# this is the list of default nodes to import # this is the list of default nodes to import
base_nodes, base_nodes,
comfy_extras_nodes comfy_extras_nodes
]), ]),
ExportedNodes()) ExportedNodes())
custom_nodes_mappings = ExportedNodes() custom_nodes_mappings = ExportedNodes()
if args.disable_all_custom_nodes: if args.disable_all_custom_nodes:
logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded") logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded")
_nodes_available_at_startup.update(base_and_extra) _nodes_available_at_startup.update(base_and_extra)
return _nodes_available_at_startup return _nodes_available_at_startup
# load from entrypoints # load from entrypoints
for entry_point in entry_points().select(group='comfyui.custom_nodes'): for entry_point in entry_points().select(group='comfyui.custom_nodes'):
# Load the module associated with the current entry point # Load the module associated with the current entry point
try: try:
module = entry_point.load() module = entry_point.load()
except ModuleNotFoundError as module_not_found_error: except ModuleNotFoundError as module_not_found_error:
logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error) logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
continue continue
# Ensure that what we've loaded is indeed a module # Ensure that what we've loaded is indeed a module
if isinstance(module, types.ModuleType): if isinstance(module, types.ModuleType):
custom_nodes_mappings.update( custom_nodes_mappings.update(
_import_and_enumerate_nodes_in_module(module, print_import_times=True)) _import_and_enumerate_nodes_in_module(module, print_import_times=True))
# load the vanilla custom nodes last # load the vanilla custom nodes last
if vanilla_custom_nodes: if vanilla_custom_nodes:
custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes() custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes()
# don't allow custom nodes to overwrite base nodes # don't allow custom nodes to overwrite base nodes
custom_nodes_mappings -= base_and_extra custom_nodes_mappings -= base_and_extra
_nodes_available_at_startup.update(base_and_extra + custom_nodes_mappings) _nodes_available_at_startup.update(base_and_extra + custom_nodes_mappings)
return _nodes_available_at_startup return _nodes_available_at_startup

View File

@ -198,6 +198,10 @@ class ExportedNodes:
def __bool__(self): def __bool__(self):
return len(self.NODE_CLASS_MAPPINGS) + len(self.NODE_DISPLAY_NAME_MAPPINGS) + len(self.EXTENSION_WEB_DIRS) > 0 return len(self.NODE_CLASS_MAPPINGS) + len(self.NODE_DISPLAY_NAME_MAPPINGS) + len(self.EXTENSION_WEB_DIRS) > 0
def clear(self):
self.NODE_CLASS_MAPPINGS.clear()
self.EXTENSION_WEB_DIRS.clear()
self.NODE_DISPLAY_NAME_MAPPINGS.clear()
class _ExportedNodesAsChainMap(ExportedNodes): class _ExportedNodesAsChainMap(ExportedNodes):
@classmethod @classmethod

View File

@ -1,15 +1,26 @@
# todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive # todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive
import threading
import lazy_object_proxy import lazy_object_proxy
from .execution_context import current_execution_context from .execution_context import current_execution_context
from .nodes.package import import_all_nodes_in_workspace from .nodes.package import import_all_nodes_in_workspace
from .nodes.package_typing import ExportedNodes, exported_nodes_view from .nodes.package_typing import ExportedNodes, exported_nodes_view
nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace) _nodes_local = threading.local()
def invalidate():
_nodes_local.nodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
def get_nodes() -> ExportedNodes: def get_nodes() -> ExportedNodes:
current_ctx = current_execution_context() current_ctx = current_execution_context()
try:
nodes = _nodes_local.nodes
except (LookupError, AttributeError):
nodes = _nodes_local.nodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
if len(current_ctx.custom_nodes) == 0: if len(current_ctx.custom_nodes) == 0:
return nodes return nodes
return exported_nodes_view(nodes, current_ctx.custom_nodes) return exported_nodes_view(nodes, current_ctx.custom_nodes)

View File

@ -1,6 +1,5 @@
args_parsing = True
args_parsing = False
def enable_args_parsing(enable=True): def enable_args_parsing(enable=True):
global args_parsing pass
args_parsing = enable

View File

@ -128,12 +128,11 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
if return_metadata: if return_metadata:
metadata = f.metadata() metadata = f.metadata()
except Exception as e: except Exception as e:
if len(e.args) > 0: message = str(e)
message = e.args[0] if "HeaderTooLarge" in message:
if "HeaderTooLarge" in message: raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.")
raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.") if "MetadataIncompleteBuffer" in message or "InvalidHeaderDeserialization" in message:
if "MetadataIncompleteBuffer" in message or "InvalidHeaderDeserialization" in message: raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.")
raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.")
raise e raise e
elif ckpt.lower().endswith("index.json"): elif ckpt.lower().endswith("index.json"):
# from accelerate # from accelerate

View File

@ -187,7 +187,7 @@ class TestExecution:
await client.run(g) await client.run(g)
mask.inputs['value'] = 0.4 mask.inputs['value'] = 0.4
result2 = client.run(g) result2 = await client.run(g)
assert not result2.did_run(input1), "Input1 should have been cached" assert not result2.did_run(input1), "Input1 should have been cached"
assert not result2.did_run(input2), "Input2 should have been cached" assert not result2.did_run(input2), "Input2 should have been cached"

View File

@ -5,6 +5,7 @@ from importlib.resources import files
import pytest import pytest
from comfy.api.components.schema.prompt import Prompt from comfy.api.components.schema.prompt import Prompt
from comfy.cli_args import default_configuration
from comfy.cli_args_types import Configuration from comfy.cli_args_types import Configuration
from comfy.client.embedded_comfy_client import Comfy from comfy.client.embedded_comfy_client import Comfy
@ -23,14 +24,17 @@ _TEST_WORKFLOW = {
async def test_respect_cwd_param(): async def test_respect_cwd_param():
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
cwd = str(tmp_dir) cwd = str(tmp_dir)
config = Configuration(cwd=cwd) config = default_configuration()
# for finding the custom nodes config.cwd = cwd
config.base_paths = [files(__package__)]
from comfy.cmd.folder_paths import models_dir from comfy.cmd.folder_paths import models_dir
assert os.path.commonpath([os.getcwd(), models_dir]) == os.getcwd(), "at the time models_dir is accessed, the cwd should be the actual cwd, since there is no other configuration" assert os.path.commonpath([os.getcwd(), models_dir]) == os.getcwd(), "at the time models_dir is accessed, the cwd should be the actual cwd, since there is no other configuration"
client = Comfy(config) # for finding the custom nodes
prompt = Prompt.validate(_TEST_WORKFLOW) config.base_paths = [str(files(__package__))]
outputs = await client.queue_prompt_api(prompt)
path_as_imported = outputs.outputs["0"]["path"][0] async with Comfy(config) as client:
assert os.path.commonpath([path_as_imported, cwd]) == cwd, "at the time the node is imported, the cwd should be the temporary directory" prompt = Prompt.validate(_TEST_WORKFLOW)
outputs = await client.queue_prompt_api(prompt)
path_as_imported = outputs.outputs["0"]["path"][0]
assert os.path.commonpath([path_as_imported, cwd]) == cwd, "at the time the node is imported, the cwd should be the temporary directory"

View File

@ -1,9 +1,10 @@
import pytest
from importlib.resources import files from importlib.resources import files
from comfy.api.components.schema.prompt import Prompt import pytest
from comfy.cli_args_types import Configuration
from comfy.cli_args import default_configuration
from comfy.client.embedded_comfy_client import Comfy from comfy.client.embedded_comfy_client import Comfy
from comfy.execution_context import context_configuration
_TEST_WORKFLOW_1 = { _TEST_WORKFLOW_1 = {
"0": { "0": {
@ -35,10 +36,15 @@ _TEST_WORKFLOW_2 = {
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_blacklist_node(): async def test_blacklist_node():
config = Configuration(blacklist_custom_nodes=['issue_46']) config = default_configuration()
config.blacklist_custom_nodes = ['issue_46']
# for finding the custom nodes # for finding the custom nodes
config.base_paths = [str(files(__package__))] config.base_paths = [str(files(__package__))]
with context_configuration(config):
from comfy.nodes_context import get_nodes
nodes = get_nodes()
assert "ShouldNotExist" not in nodes.NODE_CLASS_MAPPINGS
async with Comfy(config) as client: async with Comfy(config) as client:
from comfy.cmd.execution import validate_prompt from comfy.cmd.execution import validate_prompt
res = await validate_prompt("1", prompt=_TEST_WORKFLOW_1, partial_execution_list=[]) res = await validate_prompt("1", prompt=_TEST_WORKFLOW_1, partial_execution_list=[])

View File

@ -190,6 +190,8 @@ async def test_no_panic_when_disabled_with_executor(executor_cls, executor_kwarg
patch('sys.exit') as mock_exit): patch('sys.exit') as mock_exit):
try: try:
async with Comfy(configuration=config, executor=executor) as client: async with Comfy(configuration=config, executor=executor) as client:
from comfy.cli_args import args
assert len(args.panic_when) == 0
# Queue our failing workflow # Queue our failing workflow
await client.queue_prompt(create_failing_workflow()) await client.queue_prompt(create_failing_workflow())
except SystemExit: except SystemExit:

View File

@ -6,7 +6,7 @@ from pytest_mock import MockerFixture
from comfy.cli_args import args from comfy.cli_args import args
from comfy.cmd.execution import validate_prompt from comfy.cmd.execution import validate_prompt
from comfy.nodes_context import nodes from comfy.nodes_context import get_nodes
import uuid import uuid
@ -75,6 +75,7 @@ known_models: ContextVar[list[str]] = ContextVar('known_models', default=[])
@pytest.fixture @pytest.fixture
def mock_nodes(mocker: MockerFixture): def mock_nodes(mocker: MockerFixture):
nodes = get_nodes()
class MockCheckpointLoaderSimple: class MockCheckpointLoaderSimple:
@staticmethod @staticmethod
def INPUT_TYPES(): def INPUT_TYPES():