diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 22ae17315..7aa1addba 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -9,15 +9,17 @@ from typing import Optional import configargparse as argparse from . import __version__ -from . import options from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, EnumAction, \ EnhancedConfigArgParser, PerformanceFeature, is_valid_directory, db_config, FlattenAndAppendAction +from .component_model.module_property import create_module_properties # todo: move this DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest" logger = logging.getLogger(__name__) +_module_properties = create_module_properties() + def _create_parser() -> EnhancedConfigArgParser: parser = EnhancedConfigArgParser(default_config_files=['config.yaml', 'config.json', 'config.cfg', 'config.ini'], @@ -28,7 +30,7 @@ def _create_parser() -> EnhancedConfigArgParser: parser.add_argument('-w', "--cwd", type=str, default=None, help="Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default.") - parser.add_argument("--base-paths", type=str, nargs='+', default=[], help="Additional base paths for custom nodes, models and inputs.") + parser.add_argument("--base-paths", type=str, nargs='+', default=[], action=FlattenAndAppendAction, help="Additional base paths for custom nodes, models and inputs.") parser.add_argument('-H', "--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") @@ -149,8 +151,8 @@ def _create_parser() -> EnhancedConfigArgParser: parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.") - parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.") - parser.add_argument("--blacklist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.") + parser.add_argument("--whitelist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.") + parser.add_argument("--blacklist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.") parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") @@ -208,7 +210,8 @@ def _create_parser() -> EnhancedConfigArgParser: parser.add_argument( '--panic-when', - action='append', + action=FlattenAndAppendAction, + nargs='+', help=""" List of fully qualified exception class names to panic (sys.exit(1)) when a workflow raises it. Example: --panic-when=torch.cuda.OutOfMemoryError. Can be specified multiple times or as a @@ -270,7 +273,7 @@ def _create_parser() -> EnhancedConfigArgParser: default_db_url = db_config() parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") - parser.add_argument("--workflows", type=str, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.") + parser.add_argument("--workflows", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.") # now give plugins a chance to add configuration for entry_point in entry_points().select(group='comfyui.custom_config'): @@ -317,4 +320,19 @@ def default_configuration() -> Configuration: return _parse_args(_create_parser()) -args = _parse_args(args_parsing=options.args_parsing) +def cli_args_configuration() -> Configuration: + return _parse_args(args_parsing=True) + + +@_module_properties.getter +def _args() -> Configuration: + from .execution_context import current_execution_context + return current_execution_context().configuration + + +__all__ = [ + "args", # pylint: disable=undefined-all-variable, type: Configuration + "default_configuration", + "cli_args_configuration", + "DEFAULT_VERSION_STRING" +] diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index 80445aa49..3ff4dbeae 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -29,7 +29,7 @@ from ..distributed.executors import ContextVarExecutor from ..distributed.history import History from ..distributed.process_pool_executor import ProcessPoolExecutor from ..distributed.server_stub import ServerStub -from ..execution_context import current_execution_context +from ..execution_context import current_execution_context, context_configuration _prompt_executor = threading.local() @@ -57,7 +57,6 @@ def _execute_prompt( finally: detach(token) - async def __execute_prompt( prompt: dict, prompt_id: str, @@ -66,7 +65,16 @@ async def __execute_prompt( progress_handler: ExecutorToClientProgress | None, configuration: Configuration | None, partial_execution_targets: list[str] | None) -> dict: - from .. import options + with context_configuration(configuration): + return await ___execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, partial_execution_targets) + +async def ___execute_prompt( + prompt: dict, + prompt_id: str, + client_id: str, + span_context: Context, + progress_handler: ExecutorToClientProgress | None, + partial_execution_targets: list[str] | None) -> dict: from ..cmd.execution import PromptExecutor progress_handler = progress_handler or ServerStub() @@ -74,13 +82,6 @@ async def __execute_prompt( try: prompt_executor: PromptExecutor = _prompt_executor.executor except (LookupError, AttributeError): - if configuration is None: - options.enable_args_parsing() - else: - from ..cmd.main_pre import args - args.clear() - args.update(configuration) - with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context): # todo: deal with new caching features prompt_executor = PromptExecutor(progress_handler) @@ -117,6 +118,7 @@ async def __execute_prompt( def _cleanup(): from ..cmd.execution import PromptExecutor + from ..nodes_context import invalidate try: prompt_executor: PromptExecutor = _prompt_executor.executor # this should clear all references to output tensors and make it easier to collect back the memory @@ -130,6 +132,10 @@ def _cleanup(): model_management.soft_empty_cache() except: pass + try: + invalidate() + except: + pass class Comfy: @@ -172,6 +178,8 @@ class Comfy: self._task_count_lock = RLock() self._task_count = 0 self._history = History() + self._context_stack = [] + @property def is_running(self) -> bool: @@ -183,6 +191,9 @@ class Comfy: def __enter__(self): self._is_running = True + cm = context_configuration(self._configuration) + cm.__enter__() + self._context_stack.append(cm) return self @property @@ -193,9 +204,13 @@ class Comfy: get_event_loop().run_in_executor(self._executor, _cleanup) self._executor.shutdown(wait=True) self._is_running = False + self._context_stack.pop().__exit__(*args) async def __aenter__(self): self._is_running = True + cm = context_configuration(self._configuration) + cm.__enter__() + self._context_stack.append(cm) return self async def __aexit__(self, *args): @@ -207,6 +222,7 @@ class Comfy: self._executor.shutdown(wait=True) self._is_running = False + self._context_stack.pop().__exit__(*args) async def queue_prompt_api(self, prompt: PromptDict | str | dict, diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index f05efaaff..0f9700995 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -30,9 +30,9 @@ from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io +from ..execution_context import current_execution_context from .. import interruption from .. import model_management -from ..cli_args import args from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \ ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \ @@ -675,7 +675,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra "current_inputs": input_data_formatted } - if should_panic_on_exception(ex, args.panic_when): + if should_panic_on_exception(ex, current_execution_context().configuration.panic_when): logger.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit") def sys_exit(*args): diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index c5748964e..731c87136 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -148,6 +148,13 @@ def setup_database(): async def _start_comfyui(from_script_dir: Optional[Path] = None): + from ..execution_context import context_configuration + from ..cli_args import cli_args_configuration + with context_configuration(cli_args_configuration()): + await __start_comfyui(from_script_dir=from_script_dir) + + +async def __start_comfyui(from_script_dir: Optional[Path] = None): """ Runs ComfyUI's frontend and backend like upstream. :param from_script_dir: when set to a path, assumes that you are running ComfyUI's legacy main.py entrypoint at the root of the git repository located at the path diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py index 3bc3e9f0d..993823442 100644 --- a/comfy/cmd/main_pre.py +++ b/comfy/cmd/main_pre.py @@ -18,6 +18,7 @@ import fsspec from .. import options from ..app import logger +from ..cli_args_types import Configuration from ..component_model import package_filesystem os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1' @@ -164,6 +165,7 @@ def _register_fsspec_fs(): package_filesystem.PkgResourcesFileSystem, ) +args: Configuration _configure_logging() _fix_pytorch_240() diff --git a/comfy/component_model/module_property.py b/comfy/component_model/module_property.py index 3d6c18b5e..fa15823e7 100644 --- a/comfy/component_model/module_property.py +++ b/comfy/component_model/module_property.py @@ -2,6 +2,17 @@ import sys from functools import wraps def create_module_properties(): + """ + Example: + >>> _module_properties = create_module_properties() + + >>> @_module_properties.getter + >>> def _nodes(): + >>> return ... + + This creates nodes as a property + :return: + """ properties = {} patched_modules = set() diff --git a/comfy/execution_context.py b/comfy/execution_context.py index bfd34d5ea..400764831 100644 --- a/comfy/execution_context.py +++ b/comfy/execution_context.py @@ -5,6 +5,8 @@ from contextvars import ContextVar from dataclasses import dataclass, replace from typing import Optional, Final +from .cli_args import cli_args_configuration +from .cli_args_types import Configuration from .component_model import cvpickle from .component_model.executor_types import ExecutorToClientProgress from .component_model.folder_path_types import FolderNames @@ -23,6 +25,7 @@ class ExecutionContext: list_index: Optional[int] = None inference_mode: bool = True progress_registry: Optional[AbstractProgressRegistry] = None + configuration: Optional[Configuration] = None def __iter__(self): """ @@ -34,7 +37,7 @@ class ExecutionContext: yield self.list_index -comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context", default=ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub())) +comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context", default=ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub(), configuration=cli_args_configuration())) # enables context var propagation across process boundaries for process pool executors cvpickle.register_contextvar(comfyui_execution_context, __name__) @@ -52,6 +55,17 @@ def _new_execution_context(ctx: ExecutionContext): comfyui_execution_context.reset(token) +@contextmanager +def context_configuration(configuration: Optional[Configuration] = None): + current_ctx = current_execution_context() + if configuration is None: + from .cli_args import cli_args_configuration + configuration = cli_args_configuration() + new_ctx = replace(current_ctx, configuration=configuration) + with _new_execution_context(new_ctx): + yield new_ctx + + @contextmanager def context_folder_names_and_paths(folder_names_and_paths: FolderNames): current_ctx = current_execution_context() diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py index a6dd8f8c6..500ef7286 100644 --- a/comfy/nodes/package.py +++ b/comfy/nodes/package.py @@ -4,10 +4,12 @@ import importlib import logging import os import pkgutil +import threading import time import types from functools import reduce from importlib.metadata import entry_points +from threading import RLock from opentelemetry.trace import Span, Status, StatusCode @@ -17,8 +19,9 @@ from comfy_api.version_list import supported_versions from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports from .package_typing import ExportedNodes from ..component_model.files import get_package_as_path +from ..execution_context import current_execution_context -_nodes_available_at_startup: ExportedNodes = ExportedNodes() +_nodes_local = threading.local() logger = logging.getLogger(__name__) @@ -46,8 +49,6 @@ def _import_nodes_in_module(module: types.ModuleType) -> ExportedNodes: return exported_nodes - - def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import_times=False, raise_on_failure=False, @@ -116,7 +117,11 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, @tracer.start_as_current_span("Import All Nodes In Workspace") def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes: # now actually import the nodes, to improve control of node loading order - from ..cli_args import args + try: + _nodes_available_at_startup = _nodes_local.nodes + except (LookupError, AttributeError): + _nodes_available_at_startup = _nodes_local.nodes = ExportedNodes() + args = current_execution_context().configuration # todo: this is some truly braindead stuff register_versions([ @@ -126,48 +131,47 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa ) for v in supported_versions ]) - # only load these nodes once - if len(_nodes_available_at_startup) == 0: + _nodes_available_at_startup.clear() - # import base_nodes first - from . import base_nodes - from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used - from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes + # import base_nodes first + from . import base_nodes + from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used + from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes - base_and_extra = reduce(lambda x, y: x.update(y), - map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [ - # this is the list of default nodes to import - base_nodes, - comfy_extras_nodes - ]), - ExportedNodes()) - custom_nodes_mappings = ExportedNodes() + base_and_extra = reduce(lambda x, y: x.update(y), + map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [ + # this is the list of default nodes to import + base_nodes, + comfy_extras_nodes + ]), + ExportedNodes()) + custom_nodes_mappings = ExportedNodes() - if args.disable_all_custom_nodes: - logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded") - _nodes_available_at_startup.update(base_and_extra) - return _nodes_available_at_startup + if args.disable_all_custom_nodes: + logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded") + _nodes_available_at_startup.update(base_and_extra) + return _nodes_available_at_startup - # load from entrypoints - for entry_point in entry_points().select(group='comfyui.custom_nodes'): - # Load the module associated with the current entry point - try: - module = entry_point.load() - except ModuleNotFoundError as module_not_found_error: - logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error) - continue + # load from entrypoints + for entry_point in entry_points().select(group='comfyui.custom_nodes'): + # Load the module associated with the current entry point + try: + module = entry_point.load() + except ModuleNotFoundError as module_not_found_error: + logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error) + continue - # Ensure that what we've loaded is indeed a module - if isinstance(module, types.ModuleType): - custom_nodes_mappings.update( - _import_and_enumerate_nodes_in_module(module, print_import_times=True)) + # Ensure that what we've loaded is indeed a module + if isinstance(module, types.ModuleType): + custom_nodes_mappings.update( + _import_and_enumerate_nodes_in_module(module, print_import_times=True)) - # load the vanilla custom nodes last - if vanilla_custom_nodes: - custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes() + # load the vanilla custom nodes last + if vanilla_custom_nodes: + custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes() - # don't allow custom nodes to overwrite base nodes - custom_nodes_mappings -= base_and_extra + # don't allow custom nodes to overwrite base nodes + custom_nodes_mappings -= base_and_extra - _nodes_available_at_startup.update(base_and_extra + custom_nodes_mappings) + _nodes_available_at_startup.update(base_and_extra + custom_nodes_mappings) return _nodes_available_at_startup diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index 70bc94111..7b51fdf8a 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -198,6 +198,10 @@ class ExportedNodes: def __bool__(self): return len(self.NODE_CLASS_MAPPINGS) + len(self.NODE_DISPLAY_NAME_MAPPINGS) + len(self.EXTENSION_WEB_DIRS) > 0 + def clear(self): + self.NODE_CLASS_MAPPINGS.clear() + self.EXTENSION_WEB_DIRS.clear() + self.NODE_DISPLAY_NAME_MAPPINGS.clear() class _ExportedNodesAsChainMap(ExportedNodes): @classmethod diff --git a/comfy/nodes_context.py b/comfy/nodes_context.py index e9e08dd6f..800f9fb2d 100644 --- a/comfy/nodes_context.py +++ b/comfy/nodes_context.py @@ -1,15 +1,26 @@ # todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive +import threading + import lazy_object_proxy from .execution_context import current_execution_context from .nodes.package import import_all_nodes_in_workspace from .nodes.package_typing import ExportedNodes, exported_nodes_view -nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace) +_nodes_local = threading.local() + + +def invalidate(): + _nodes_local.nodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace) def get_nodes() -> ExportedNodes: current_ctx = current_execution_context() + try: + nodes = _nodes_local.nodes + except (LookupError, AttributeError): + nodes = _nodes_local.nodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace) + if len(current_ctx.custom_nodes) == 0: return nodes return exported_nodes_view(nodes, current_ctx.custom_nodes) diff --git a/comfy/options.py b/comfy/options.py index f7f8af41e..62d588cca 100644 --- a/comfy/options.py +++ b/comfy/options.py @@ -1,6 +1,5 @@ +args_parsing = True -args_parsing = False def enable_args_parsing(enable=True): - global args_parsing - args_parsing = enable + pass diff --git a/comfy/utils.py b/comfy/utils.py index 5800e002a..67c02f2c2 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -128,12 +128,11 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal if return_metadata: metadata = f.metadata() except Exception as e: - if len(e.args) > 0: - message = e.args[0] - if "HeaderTooLarge" in message: - raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.") - if "MetadataIncompleteBuffer" in message or "InvalidHeaderDeserialization" in message: - raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.") + message = str(e) + if "HeaderTooLarge" in message: + raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.") + if "MetadataIncompleteBuffer" in message or "InvalidHeaderDeserialization" in message: + raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.") raise e elif ckpt.lower().endswith("index.json"): # from accelerate diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index 25d5e86ae..b7d7ad2ca 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -187,7 +187,7 @@ class TestExecution: await client.run(g) mask.inputs['value'] = 0.4 - result2 = client.run(g) + result2 = await client.run(g) assert not result2.did_run(input1), "Input1 should have been cached" assert not result2.did_run(input2), "Input2 should have been cached" diff --git a/tests/issues/test_25_respect_cwd_param.py b/tests/issues/test_25_respect_cwd_param.py index 5270c7bc0..181d9e9f4 100644 --- a/tests/issues/test_25_respect_cwd_param.py +++ b/tests/issues/test_25_respect_cwd_param.py @@ -5,6 +5,7 @@ from importlib.resources import files import pytest from comfy.api.components.schema.prompt import Prompt +from comfy.cli_args import default_configuration from comfy.cli_args_types import Configuration from comfy.client.embedded_comfy_client import Comfy @@ -23,14 +24,17 @@ _TEST_WORKFLOW = { async def test_respect_cwd_param(): with tempfile.TemporaryDirectory() as tmp_dir: cwd = str(tmp_dir) - config = Configuration(cwd=cwd) - # for finding the custom nodes - config.base_paths = [files(__package__)] + config = default_configuration() + config.cwd = cwd + from comfy.cmd.folder_paths import models_dir assert os.path.commonpath([os.getcwd(), models_dir]) == os.getcwd(), "at the time models_dir is accessed, the cwd should be the actual cwd, since there is no other configuration" - client = Comfy(config) - prompt = Prompt.validate(_TEST_WORKFLOW) - outputs = await client.queue_prompt_api(prompt) - path_as_imported = outputs.outputs["0"]["path"][0] - assert os.path.commonpath([path_as_imported, cwd]) == cwd, "at the time the node is imported, the cwd should be the temporary directory" + # for finding the custom nodes + config.base_paths = [str(files(__package__))] + + async with Comfy(config) as client: + prompt = Prompt.validate(_TEST_WORKFLOW) + outputs = await client.queue_prompt_api(prompt) + path_as_imported = outputs.outputs["0"]["path"][0] + assert os.path.commonpath([path_as_imported, cwd]) == cwd, "at the time the node is imported, the cwd should be the temporary directory" diff --git a/tests/issues/test_46_blacklist_nodes.py b/tests/issues/test_46_blacklist_nodes.py index 57d10902a..b2e97b15b 100644 --- a/tests/issues/test_46_blacklist_nodes.py +++ b/tests/issues/test_46_blacklist_nodes.py @@ -1,9 +1,10 @@ -import pytest from importlib.resources import files -from comfy.api.components.schema.prompt import Prompt -from comfy.cli_args_types import Configuration +import pytest + +from comfy.cli_args import default_configuration from comfy.client.embedded_comfy_client import Comfy +from comfy.execution_context import context_configuration _TEST_WORKFLOW_1 = { "0": { @@ -35,10 +36,15 @@ _TEST_WORKFLOW_2 = { @pytest.mark.asyncio async def test_blacklist_node(): - config = Configuration(blacklist_custom_nodes=['issue_46']) + config = default_configuration() + config.blacklist_custom_nodes = ['issue_46'] # for finding the custom nodes config.base_paths = [str(files(__package__))] + with context_configuration(config): + from comfy.nodes_context import get_nodes + nodes = get_nodes() + assert "ShouldNotExist" not in nodes.NODE_CLASS_MAPPINGS async with Comfy(config) as client: from comfy.cmd.execution import validate_prompt res = await validate_prompt("1", prompt=_TEST_WORKFLOW_1, partial_execution_list=[]) diff --git a/tests/unit/test_panics.py b/tests/unit/test_panics.py index c27439ffe..190484727 100644 --- a/tests/unit/test_panics.py +++ b/tests/unit/test_panics.py @@ -190,6 +190,8 @@ async def test_no_panic_when_disabled_with_executor(executor_cls, executor_kwarg patch('sys.exit') as mock_exit): try: async with Comfy(configuration=config, executor=executor) as client: + from comfy.cli_args import args + assert len(args.panic_when) == 0 # Queue our failing workflow await client.queue_prompt(create_failing_workflow()) except SystemExit: diff --git a/tests/unit/test_validation.py b/tests/unit/test_validation.py index 77efef87e..ddac5c767 100644 --- a/tests/unit/test_validation.py +++ b/tests/unit/test_validation.py @@ -6,7 +6,7 @@ from pytest_mock import MockerFixture from comfy.cli_args import args from comfy.cmd.execution import validate_prompt -from comfy.nodes_context import nodes +from comfy.nodes_context import get_nodes import uuid @@ -75,6 +75,7 @@ known_models: ContextVar[list[str]] = ContextVar('known_models', default=[]) @pytest.fixture def mock_nodes(mocker: MockerFixture): + nodes = get_nodes() class MockCheckpointLoaderSimple: @staticmethod def INPUT_TYPES():