Fixes to tests and configuration, making library use more durable

This commit is contained in:
doctorpangloss 2025-10-23 19:46:40 -07:00
parent 67f9d3e693
commit 058e5dc634
17 changed files with 182 additions and 84 deletions

View File

@ -9,15 +9,17 @@ from typing import Optional
import configargparse as argparse
from . import __version__
from . import options
from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, EnumAction, \
EnhancedConfigArgParser, PerformanceFeature, is_valid_directory, db_config, FlattenAndAppendAction
from .component_model.module_property import create_module_properties
# todo: move this
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
logger = logging.getLogger(__name__)
_module_properties = create_module_properties()
def _create_parser() -> EnhancedConfigArgParser:
parser = EnhancedConfigArgParser(default_config_files=['config.yaml', 'config.json', 'config.cfg', 'config.ini'],
@ -28,7 +30,7 @@ def _create_parser() -> EnhancedConfigArgParser:
parser.add_argument('-w', "--cwd", type=str, default=None,
help="Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default.")
parser.add_argument("--base-paths", type=str, nargs='+', default=[], help="Additional base paths for custom nodes, models and inputs.")
parser.add_argument("--base-paths", type=str, nargs='+', default=[], action=FlattenAndAppendAction, help="Additional base paths for custom nodes, models and inputs.")
parser.add_argument('-H', "--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::",
help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
@ -149,8 +151,8 @@ def _create_parser() -> EnhancedConfigArgParser:
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
parser.add_argument("--blacklist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.")
parser.add_argument("--whitelist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
parser.add_argument("--blacklist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
@ -208,7 +210,8 @@ def _create_parser() -> EnhancedConfigArgParser:
parser.add_argument(
'--panic-when',
action='append',
action=FlattenAndAppendAction,
nargs='+',
help="""
List of fully qualified exception class names to panic (sys.exit(1)) when a workflow raises it.
Example: --panic-when=torch.cuda.OutOfMemoryError. Can be specified multiple times or as a
@ -270,7 +273,7 @@ def _create_parser() -> EnhancedConfigArgParser:
default_db_url = db_config()
parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--workflows", type=str, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")
parser.add_argument("--workflows", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")
# now give plugins a chance to add configuration
for entry_point in entry_points().select(group='comfyui.custom_config'):
@ -317,4 +320,19 @@ def default_configuration() -> Configuration:
return _parse_args(_create_parser())
args = _parse_args(args_parsing=options.args_parsing)
def cli_args_configuration() -> Configuration:
return _parse_args(args_parsing=True)
@_module_properties.getter
def _args() -> Configuration:
from .execution_context import current_execution_context
return current_execution_context().configuration
__all__ = [
"args", # pylint: disable=undefined-all-variable, type: Configuration
"default_configuration",
"cli_args_configuration",
"DEFAULT_VERSION_STRING"
]

View File

@ -29,7 +29,7 @@ from ..distributed.executors import ContextVarExecutor
from ..distributed.history import History
from ..distributed.process_pool_executor import ProcessPoolExecutor
from ..distributed.server_stub import ServerStub
from ..execution_context import current_execution_context
from ..execution_context import current_execution_context, context_configuration
_prompt_executor = threading.local()
@ -57,7 +57,6 @@ def _execute_prompt(
finally:
detach(token)
async def __execute_prompt(
prompt: dict,
prompt_id: str,
@ -66,7 +65,16 @@ async def __execute_prompt(
progress_handler: ExecutorToClientProgress | None,
configuration: Configuration | None,
partial_execution_targets: list[str] | None) -> dict:
from .. import options
with context_configuration(configuration):
return await ___execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, partial_execution_targets)
async def ___execute_prompt(
prompt: dict,
prompt_id: str,
client_id: str,
span_context: Context,
progress_handler: ExecutorToClientProgress | None,
partial_execution_targets: list[str] | None) -> dict:
from ..cmd.execution import PromptExecutor
progress_handler = progress_handler or ServerStub()
@ -74,13 +82,6 @@ async def __execute_prompt(
try:
prompt_executor: PromptExecutor = _prompt_executor.executor
except (LookupError, AttributeError):
if configuration is None:
options.enable_args_parsing()
else:
from ..cmd.main_pre import args
args.clear()
args.update(configuration)
with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context):
# todo: deal with new caching features
prompt_executor = PromptExecutor(progress_handler)
@ -117,6 +118,7 @@ async def __execute_prompt(
def _cleanup():
from ..cmd.execution import PromptExecutor
from ..nodes_context import invalidate
try:
prompt_executor: PromptExecutor = _prompt_executor.executor
# this should clear all references to output tensors and make it easier to collect back the memory
@ -130,6 +132,10 @@ def _cleanup():
model_management.soft_empty_cache()
except:
pass
try:
invalidate()
except:
pass
class Comfy:
@ -172,6 +178,8 @@ class Comfy:
self._task_count_lock = RLock()
self._task_count = 0
self._history = History()
self._context_stack = []
@property
def is_running(self) -> bool:
@ -183,6 +191,9 @@ class Comfy:
def __enter__(self):
self._is_running = True
cm = context_configuration(self._configuration)
cm.__enter__()
self._context_stack.append(cm)
return self
@property
@ -193,9 +204,13 @@ class Comfy:
get_event_loop().run_in_executor(self._executor, _cleanup)
self._executor.shutdown(wait=True)
self._is_running = False
self._context_stack.pop().__exit__(*args)
async def __aenter__(self):
self._is_running = True
cm = context_configuration(self._configuration)
cm.__enter__()
self._context_stack.append(cm)
return self
async def __aexit__(self, *args):
@ -207,6 +222,7 @@ class Comfy:
self._executor.shutdown(wait=True)
self._is_running = False
self._context_stack.pop().__exit__(*args)
async def queue_prompt_api(self,
prompt: PromptDict | str | dict,

View File

@ -30,9 +30,9 @@ from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io
from ..execution_context import current_execution_context
from .. import interruption
from .. import model_management
from ..cli_args import args
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
@ -675,7 +675,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
"current_inputs": input_data_formatted
}
if should_panic_on_exception(ex, args.panic_when):
if should_panic_on_exception(ex, current_execution_context().configuration.panic_when):
logger.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
def sys_exit(*args):

View File

@ -148,6 +148,13 @@ def setup_database():
async def _start_comfyui(from_script_dir: Optional[Path] = None):
from ..execution_context import context_configuration
from ..cli_args import cli_args_configuration
with context_configuration(cli_args_configuration()):
await __start_comfyui(from_script_dir=from_script_dir)
async def __start_comfyui(from_script_dir: Optional[Path] = None):
"""
Runs ComfyUI's frontend and backend like upstream.
:param from_script_dir: when set to a path, assumes that you are running ComfyUI's legacy main.py entrypoint at the root of the git repository located at the path

View File

@ -18,6 +18,7 @@ import fsspec
from .. import options
from ..app import logger
from ..cli_args_types import Configuration
from ..component_model import package_filesystem
os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
@ -164,6 +165,7 @@ def _register_fsspec_fs():
package_filesystem.PkgResourcesFileSystem,
)
args: Configuration
_configure_logging()
_fix_pytorch_240()

View File

@ -2,6 +2,17 @@ import sys
from functools import wraps
def create_module_properties():
"""
Example:
>>> _module_properties = create_module_properties()
>>> @_module_properties.getter
>>> def _nodes():
>>> return ...
This creates nodes as a property
:return:
"""
properties = {}
patched_modules = set()

View File

@ -5,6 +5,8 @@ from contextvars import ContextVar
from dataclasses import dataclass, replace
from typing import Optional, Final
from .cli_args import cli_args_configuration
from .cli_args_types import Configuration
from .component_model import cvpickle
from .component_model.executor_types import ExecutorToClientProgress
from .component_model.folder_path_types import FolderNames
@ -23,6 +25,7 @@ class ExecutionContext:
list_index: Optional[int] = None
inference_mode: bool = True
progress_registry: Optional[AbstractProgressRegistry] = None
configuration: Optional[Configuration] = None
def __iter__(self):
"""
@ -34,7 +37,7 @@ class ExecutionContext:
yield self.list_index
comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context", default=ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub()))
comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context", default=ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub(), configuration=cli_args_configuration()))
# enables context var propagation across process boundaries for process pool executors
cvpickle.register_contextvar(comfyui_execution_context, __name__)
@ -52,6 +55,17 @@ def _new_execution_context(ctx: ExecutionContext):
comfyui_execution_context.reset(token)
@contextmanager
def context_configuration(configuration: Optional[Configuration] = None):
current_ctx = current_execution_context()
if configuration is None:
from .cli_args import cli_args_configuration
configuration = cli_args_configuration()
new_ctx = replace(current_ctx, configuration=configuration)
with _new_execution_context(new_ctx):
yield new_ctx
@contextmanager
def context_folder_names_and_paths(folder_names_and_paths: FolderNames):
current_ctx = current_execution_context()

View File

@ -4,10 +4,12 @@ import importlib
import logging
import os
import pkgutil
import threading
import time
import types
from functools import reduce
from importlib.metadata import entry_points
from threading import RLock
from opentelemetry.trace import Span, Status, StatusCode
@ -17,8 +19,9 @@ from comfy_api.version_list import supported_versions
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
from .package_typing import ExportedNodes
from ..component_model.files import get_package_as_path
from ..execution_context import current_execution_context
_nodes_available_at_startup: ExportedNodes = ExportedNodes()
_nodes_local = threading.local()
logger = logging.getLogger(__name__)
@ -46,8 +49,6 @@ def _import_nodes_in_module(module: types.ModuleType) -> ExportedNodes:
return exported_nodes
def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
print_import_times=False,
raise_on_failure=False,
@ -116,7 +117,11 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
@tracer.start_as_current_span("Import All Nodes In Workspace")
def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
# now actually import the nodes, to improve control of node loading order
from ..cli_args import args
try:
_nodes_available_at_startup = _nodes_local.nodes
except (LookupError, AttributeError):
_nodes_available_at_startup = _nodes_local.nodes = ExportedNodes()
args = current_execution_context().configuration
# todo: this is some truly braindead stuff
register_versions([
@ -126,48 +131,47 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
) for v in supported_versions
])
# only load these nodes once
if len(_nodes_available_at_startup) == 0:
_nodes_available_at_startup.clear()
# import base_nodes first
from . import base_nodes
from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
# import base_nodes first
from . import base_nodes
from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
base_and_extra = reduce(lambda x, y: x.update(y),
map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [
# this is the list of default nodes to import
base_nodes,
comfy_extras_nodes
]),
ExportedNodes())
custom_nodes_mappings = ExportedNodes()
base_and_extra = reduce(lambda x, y: x.update(y),
map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [
# this is the list of default nodes to import
base_nodes,
comfy_extras_nodes
]),
ExportedNodes())
custom_nodes_mappings = ExportedNodes()
if args.disable_all_custom_nodes:
logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded")
_nodes_available_at_startup.update(base_and_extra)
return _nodes_available_at_startup
if args.disable_all_custom_nodes:
logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded")
_nodes_available_at_startup.update(base_and_extra)
return _nodes_available_at_startup
# load from entrypoints
for entry_point in entry_points().select(group='comfyui.custom_nodes'):
# Load the module associated with the current entry point
try:
module = entry_point.load()
except ModuleNotFoundError as module_not_found_error:
logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
continue
# load from entrypoints
for entry_point in entry_points().select(group='comfyui.custom_nodes'):
# Load the module associated with the current entry point
try:
module = entry_point.load()
except ModuleNotFoundError as module_not_found_error:
logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
continue
# Ensure that what we've loaded is indeed a module
if isinstance(module, types.ModuleType):
custom_nodes_mappings.update(
_import_and_enumerate_nodes_in_module(module, print_import_times=True))
# Ensure that what we've loaded is indeed a module
if isinstance(module, types.ModuleType):
custom_nodes_mappings.update(
_import_and_enumerate_nodes_in_module(module, print_import_times=True))
# load the vanilla custom nodes last
if vanilla_custom_nodes:
custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes()
# load the vanilla custom nodes last
if vanilla_custom_nodes:
custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes()
# don't allow custom nodes to overwrite base nodes
custom_nodes_mappings -= base_and_extra
# don't allow custom nodes to overwrite base nodes
custom_nodes_mappings -= base_and_extra
_nodes_available_at_startup.update(base_and_extra + custom_nodes_mappings)
_nodes_available_at_startup.update(base_and_extra + custom_nodes_mappings)
return _nodes_available_at_startup

View File

@ -198,6 +198,10 @@ class ExportedNodes:
def __bool__(self):
return len(self.NODE_CLASS_MAPPINGS) + len(self.NODE_DISPLAY_NAME_MAPPINGS) + len(self.EXTENSION_WEB_DIRS) > 0
def clear(self):
self.NODE_CLASS_MAPPINGS.clear()
self.EXTENSION_WEB_DIRS.clear()
self.NODE_DISPLAY_NAME_MAPPINGS.clear()
class _ExportedNodesAsChainMap(ExportedNodes):
@classmethod

View File

@ -1,15 +1,26 @@
# todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive
import threading
import lazy_object_proxy
from .execution_context import current_execution_context
from .nodes.package import import_all_nodes_in_workspace
from .nodes.package_typing import ExportedNodes, exported_nodes_view
nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
_nodes_local = threading.local()
def invalidate():
_nodes_local.nodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
def get_nodes() -> ExportedNodes:
current_ctx = current_execution_context()
try:
nodes = _nodes_local.nodes
except (LookupError, AttributeError):
nodes = _nodes_local.nodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
if len(current_ctx.custom_nodes) == 0:
return nodes
return exported_nodes_view(nodes, current_ctx.custom_nodes)

View File

@ -1,6 +1,5 @@
args_parsing = True
args_parsing = False
def enable_args_parsing(enable=True):
global args_parsing
args_parsing = enable
pass

View File

@ -128,12 +128,11 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
if return_metadata:
metadata = f.metadata()
except Exception as e:
if len(e.args) > 0:
message = e.args[0]
if "HeaderTooLarge" in message:
raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.")
if "MetadataIncompleteBuffer" in message or "InvalidHeaderDeserialization" in message:
raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.")
message = str(e)
if "HeaderTooLarge" in message:
raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.")
if "MetadataIncompleteBuffer" in message or "InvalidHeaderDeserialization" in message:
raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.")
raise e
elif ckpt.lower().endswith("index.json"):
# from accelerate

View File

@ -187,7 +187,7 @@ class TestExecution:
await client.run(g)
mask.inputs['value'] = 0.4
result2 = client.run(g)
result2 = await client.run(g)
assert not result2.did_run(input1), "Input1 should have been cached"
assert not result2.did_run(input2), "Input2 should have been cached"

View File

@ -5,6 +5,7 @@ from importlib.resources import files
import pytest
from comfy.api.components.schema.prompt import Prompt
from comfy.cli_args import default_configuration
from comfy.cli_args_types import Configuration
from comfy.client.embedded_comfy_client import Comfy
@ -23,14 +24,17 @@ _TEST_WORKFLOW = {
async def test_respect_cwd_param():
with tempfile.TemporaryDirectory() as tmp_dir:
cwd = str(tmp_dir)
config = Configuration(cwd=cwd)
# for finding the custom nodes
config.base_paths = [files(__package__)]
config = default_configuration()
config.cwd = cwd
from comfy.cmd.folder_paths import models_dir
assert os.path.commonpath([os.getcwd(), models_dir]) == os.getcwd(), "at the time models_dir is accessed, the cwd should be the actual cwd, since there is no other configuration"
client = Comfy(config)
prompt = Prompt.validate(_TEST_WORKFLOW)
outputs = await client.queue_prompt_api(prompt)
path_as_imported = outputs.outputs["0"]["path"][0]
assert os.path.commonpath([path_as_imported, cwd]) == cwd, "at the time the node is imported, the cwd should be the temporary directory"
# for finding the custom nodes
config.base_paths = [str(files(__package__))]
async with Comfy(config) as client:
prompt = Prompt.validate(_TEST_WORKFLOW)
outputs = await client.queue_prompt_api(prompt)
path_as_imported = outputs.outputs["0"]["path"][0]
assert os.path.commonpath([path_as_imported, cwd]) == cwd, "at the time the node is imported, the cwd should be the temporary directory"

View File

@ -1,9 +1,10 @@
import pytest
from importlib.resources import files
from comfy.api.components.schema.prompt import Prompt
from comfy.cli_args_types import Configuration
import pytest
from comfy.cli_args import default_configuration
from comfy.client.embedded_comfy_client import Comfy
from comfy.execution_context import context_configuration
_TEST_WORKFLOW_1 = {
"0": {
@ -35,10 +36,15 @@ _TEST_WORKFLOW_2 = {
@pytest.mark.asyncio
async def test_blacklist_node():
config = Configuration(blacklist_custom_nodes=['issue_46'])
config = default_configuration()
config.blacklist_custom_nodes = ['issue_46']
# for finding the custom nodes
config.base_paths = [str(files(__package__))]
with context_configuration(config):
from comfy.nodes_context import get_nodes
nodes = get_nodes()
assert "ShouldNotExist" not in nodes.NODE_CLASS_MAPPINGS
async with Comfy(config) as client:
from comfy.cmd.execution import validate_prompt
res = await validate_prompt("1", prompt=_TEST_WORKFLOW_1, partial_execution_list=[])

View File

@ -190,6 +190,8 @@ async def test_no_panic_when_disabled_with_executor(executor_cls, executor_kwarg
patch('sys.exit') as mock_exit):
try:
async with Comfy(configuration=config, executor=executor) as client:
from comfy.cli_args import args
assert len(args.panic_when) == 0
# Queue our failing workflow
await client.queue_prompt(create_failing_workflow())
except SystemExit:

View File

@ -6,7 +6,7 @@ from pytest_mock import MockerFixture
from comfy.cli_args import args
from comfy.cmd.execution import validate_prompt
from comfy.nodes_context import nodes
from comfy.nodes_context import get_nodes
import uuid
@ -75,6 +75,7 @@ known_models: ContextVar[list[str]] = ContextVar('known_models', default=[])
@pytest.fixture
def mock_nodes(mocker: MockerFixture):
nodes = get_nodes()
class MockCheckpointLoaderSimple:
@staticmethod
def INPUT_TYPES():