Fixes to tests and configuration, making library use more durable
This commit is contained in:
parent 67f9d3e693 · commit 058e5dc634

@@ -9,15 +9,17 @@ from typing import Optional
 import configargparse as argparse
 
 from . import __version__
-from . import options
 from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, EnumAction, \
     EnhancedConfigArgParser, PerformanceFeature, is_valid_directory, db_config, FlattenAndAppendAction
+from .component_model.module_property import create_module_properties
 
 # todo: move this
 DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
 
 logger = logging.getLogger(__name__)
 
+_module_properties = create_module_properties()
+
 
 def _create_parser() -> EnhancedConfigArgParser:
     parser = EnhancedConfigArgParser(default_config_files=['config.yaml', 'config.json', 'config.cfg', 'config.ini'],

@@ -28,7 +30,7 @@ def _create_parser() -> EnhancedConfigArgParser:
 
     parser.add_argument('-w', "--cwd", type=str, default=None,
                         help="Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default.")
-    parser.add_argument("--base-paths", type=str, nargs='+', default=[], help="Additional base paths for custom nodes, models and inputs.")
+    parser.add_argument("--base-paths", type=str, nargs='+', default=[], action=FlattenAndAppendAction, help="Additional base paths for custom nodes, models and inputs.")
     parser.add_argument('-H', "--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::",
                         help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
     parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")

@@ -149,8 +151,8 @@ def _create_parser() -> EnhancedConfigArgParser:
 
     parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
     parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
-    parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
-    parser.add_argument("--blacklist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.")
+    parser.add_argument("--whitelist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
+    parser.add_argument("--blacklist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.")
     parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
 
     parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")

@@ -208,7 +210,8 @@ def _create_parser() -> EnhancedConfigArgParser:
 
     parser.add_argument(
         '--panic-when',
-        action='append',
+        action=FlattenAndAppendAction,
+        nargs='+',
         help="""
             List of fully qualified exception class names to panic (sys.exit(1)) when a workflow raises it.
             Example: --panic-when=torch.cuda.OutOfMemoryError. Can be specified multiple times or as a

@@ -270,7 +273,7 @@ def _create_parser() -> EnhancedConfigArgParser:
 
     default_db_url = db_config()
     parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
-    parser.add_argument("--workflows", type=str, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")
+    parser.add_argument("--workflows", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")
 
     # now give plugins a chance to add configuration
     for entry_point in entry_points().select(group='comfyui.custom_config'):
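
Note on the repeated FlattenAndAppendAction changes above: with stock argparse, `action='append'` combined with `nargs='+'` collects a list of lists, and a repeated plain `nargs='+'` flag overwrites the earlier occurrence instead of accumulating. A flatten-and-append action fixes both. The real class lives in `cli_args_types` and is not shown in this diff; a minimal sketch of what such an action can look like:

    import argparse

    class FlattenAndAppendAction(argparse.Action):
        """Sketch: accumulate values from repeated flags into one flat list."""

        def __call__(self, parser, namespace, values, option_string=None):
            # copy rather than mutate, so the shared default list stays clean
            items = list(getattr(namespace, self.dest, None) or [])
            if isinstance(values, list):  # nargs='+' delivers a list
                items.extend(values)
            else:
                items.append(values)
            setattr(namespace, self.dest, items)

    parser = argparse.ArgumentParser()
    parser.add_argument("--base-paths", nargs='+', default=[], action=FlattenAndAppendAction)
    assert parser.parse_args("--base-paths a b --base-paths c".split()).base_paths == ["a", "b", "c"]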

@@ -317,4 +320,19 @@ def default_configuration() -> Configuration:
     return _parse_args(_create_parser())
 
 
-args = _parse_args(args_parsing=options.args_parsing)
+def cli_args_configuration() -> Configuration:
+    return _parse_args(args_parsing=True)
+
+
+@_module_properties.getter
+def _args() -> Configuration:
+    from .execution_context import current_execution_context
+    return current_execution_context().configuration
+
+
+__all__ = [
+    "args",  # pylint: disable=undefined-all-variable, type: Configuration
+    "default_configuration",
+    "cli_args_configuration",
+    "DEFAULT_VERSION_STRING"
+]
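
With this change `comfy.cli_args.args` is no longer parsed once at import time: `_args` is registered as a module property that reads the `Configuration` off the current execution context, so existing `from comfy.cli_args import args` call sites keep working while resolving per access. A hypothetical usage sketch (`my_config` and the `port` assignment are illustrations, not code from this commit):

    import comfy.cli_args as cli_args
    from comfy.execution_context import context_configuration

    my_config = cli_args.default_configuration()
    my_config.port = 9000  # assumed field; mirrors the --port flag

    with context_configuration(my_config):
        # attribute access goes through the module property and therefore
        # sees my_config while the context is active
        print(cli_args.args.port)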

@@ -29,7 +29,7 @@ from ..distributed.executors import ContextVarExecutor
 from ..distributed.history import History
 from ..distributed.process_pool_executor import ProcessPoolExecutor
 from ..distributed.server_stub import ServerStub
-from ..execution_context import current_execution_context
+from ..execution_context import current_execution_context, context_configuration
 
 _prompt_executor = threading.local()
 

@@ -57,7 +57,6 @@ def _execute_prompt(
     finally:
         detach(token)
 
-
 async def __execute_prompt(
         prompt: dict,
         prompt_id: str,

@@ -66,7 +65,16 @@ async def __execute_prompt(
         progress_handler: ExecutorToClientProgress | None,
         configuration: Configuration | None,
         partial_execution_targets: list[str] | None) -> dict:
-    from .. import options
+    with context_configuration(configuration):
+        return await ___execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, partial_execution_targets)
+
+async def ___execute_prompt(
+        prompt: dict,
+        prompt_id: str,
+        client_id: str,
+        span_context: Context,
+        progress_handler: ExecutorToClientProgress | None,
+        partial_execution_targets: list[str] | None) -> dict:
     from ..cmd.execution import PromptExecutor
 
     progress_handler = progress_handler or ServerStub()

@@ -74,13 +82,6 @@ async def __execute_prompt(
     try:
         prompt_executor: PromptExecutor = _prompt_executor.executor
     except (LookupError, AttributeError):
-        if configuration is None:
-            options.enable_args_parsing()
-        else:
-            from ..cmd.main_pre import args
-            args.clear()
-            args.update(configuration)
-
         with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context):
             # todo: deal with new caching features
             prompt_executor = PromptExecutor(progress_handler)
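
The deleted branch above smuggled a per-client configuration into the executor by mutating the process-wide `args` mapping (and flipping the global `options.args_parsing` switch); the replacement wraps the whole execution in `context_configuration(...)`, so the configuration rides a `ContextVar` and cannot leak between concurrently executing prompts. The difference, reduced to a standard-library sketch:

    import contextvars

    # old style: one module-level dict that every caller mutated in place
    GLOBAL_ARGS: dict = {}

    # new style: a context variable scoped to the current thread/task context
    _config: contextvars.ContextVar = contextvars.ContextVar("config", default=None)

    def run_with(config, fn):
        token = _config.set(config)   # visible only in this context
        try:
            return fn()
        finally:
            _config.reset(token)      # previous value restored; nothing leaks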

@@ -117,6 +118,7 @@ async def __execute_prompt(
 
 def _cleanup():
     from ..cmd.execution import PromptExecutor
+    from ..nodes_context import invalidate
     try:
         prompt_executor: PromptExecutor = _prompt_executor.executor
         # this should clear all references to output tensors and make it easier to collect back the memory

@@ -130,6 +132,10 @@ def _cleanup():
         model_management.soft_empty_cache()
     except:
         pass
+    try:
+        invalidate()
+    except:
+        pass
 
 
 class Comfy:

@@ -172,6 +178,8 @@ class Comfy:
         self._task_count_lock = RLock()
         self._task_count = 0
         self._history = History()
+        self._context_stack = []
+
 
     @property
     def is_running(self) -> bool:

@@ -183,6 +191,9 @@ class Comfy:
 
     def __enter__(self):
         self._is_running = True
+        cm = context_configuration(self._configuration)
+        cm.__enter__()
+        self._context_stack.append(cm)
         return self
 
     @property

@@ -193,9 +204,13 @@ class Comfy:
         get_event_loop().run_in_executor(self._executor, _cleanup)
         self._executor.shutdown(wait=True)
         self._is_running = False
+        self._context_stack.pop().__exit__(*args)
 
     async def __aenter__(self):
         self._is_running = True
+        cm = context_configuration(self._configuration)
+        cm.__enter__()
+        self._context_stack.append(cm)
         return self
 
     async def __aexit__(self, *args):

@@ -207,6 +222,7 @@ class Comfy:
 
         self._executor.shutdown(wait=True)
         self._is_running = False
+        self._context_stack.pop().__exit__(*args)
 
     async def queue_prompt_api(self,
                                prompt: PromptDict | str | dict,
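
Because `__enter__`/`__aenter__` now push a `context_configuration` frame (popped again on exit), merely constructing `Comfy(config)` no longer applies the configuration; the client has to actually be entered, which is exactly what the updated cwd test below switches to. Usage sketch (`prompt` stands in for any validated API workflow):

    from comfy.cli_args import default_configuration
    from comfy.client.embedded_comfy_client import Comfy

    async def run(prompt):
        config = default_configuration()
        async with Comfy(config) as client:  # pushes the configuration frame
            # inside the block, current_execution_context().configuration is config
            return await client.queue_prompt_api(prompt)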

@@ -30,9 +30,9 @@ from comfy_execution.graph_utils import is_link, GraphBuilder
 from comfy_execution.utils import CurrentNodeContext
 from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
 from comfy_api.latest import io
+from ..execution_context import current_execution_context
 from .. import interruption
 from .. import model_management
-from ..cli_args import args
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
     ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \

@@ -675,7 +675,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
                 "current_inputs": input_data_formatted
             }
 
-            if should_panic_on_exception(ex, args.panic_when):
+            if should_panic_on_exception(ex, current_execution_context().configuration.panic_when):
                 logger.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
 
                 def sys_exit(*args):
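
`should_panic_on_exception` now reads `panic_when` from the ambient configuration instead of the import-time `args` global, so a `panic_when` list supplied by an embedded client applies only within that client's context. The lookup pattern, sketched with a defensive fallback that the real call site does not need (the default context now always carries a configuration):

    from comfy.execution_context import current_execution_context

    def _panic_when():
        cfg = current_execution_context().configuration
        return cfg.panic_when if cfg is not None else ()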

@@ -148,6 +148,13 @@ def setup_database():
 
 
 async def _start_comfyui(from_script_dir: Optional[Path] = None):
+    from ..execution_context import context_configuration
+    from ..cli_args import cli_args_configuration
+    with context_configuration(cli_args_configuration()):
+        await __start_comfyui(from_script_dir=from_script_dir)
+
+
+async def __start_comfyui(from_script_dir: Optional[Path] = None):
     """
     Runs ComfyUI's frontend and backend like upstream.
     :param from_script_dir: when set to a path, assumes that you are running ComfyUI's legacy main.py entrypoint at the root of the git repository located at the path

@@ -18,6 +18,7 @@ import fsspec
 
 from .. import options
 from ..app import logger
+from ..cli_args_types import Configuration
 from ..component_model import package_filesystem
 
 os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'

@@ -164,6 +165,7 @@ def _register_fsspec_fs():
         package_filesystem.PkgResourcesFileSystem,
     )
 
+args: Configuration
 
 _configure_logging()
 _fix_pytorch_240()
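
The bare `args: Configuration` added here is a PEP 526 annotation-only statement: it declares the name's type for static analysis without binding a value at import time; this hunk does not show the mechanism that supplies the value at runtime, which presumably rides on the module-property machinery introduced in `cli_args` above. The idiom in isolation:

    class Configuration:          # stand-in for the real class
        pass

    args: Configuration           # annotation only; no value bound yet

    assert "args" not in globals()
    assert __annotations__["args"] is Configuration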

@@ -2,6 +2,17 @@ import sys
 from functools import wraps
 
 def create_module_properties():
+    """
+    Example:
+    >>> _module_properties = create_module_properties()
+
+    >>> @_module_properties.getter
+    >>> def _nodes():
+    >>>     return ...
+
+    This creates nodes as a property
+    :return:
+    """
     properties = {}
     patched_modules = set()
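
The docstring above shows the intended use of `create_module_properties`; its implementation is not part of this diff. The same effect (module attributes recomputed on every access) can be had with PEP 562's module-level `__getattr__`, shown here as a hedged, self-contained sketch rather than the project's actual mechanism:

    # mymodule.py
    _getters = {}

    def getter(fn):
        # register function `_name` as module attribute `name`
        _getters[fn.__name__.lstrip("_")] = fn
        return fn

    def __getattr__(name):  # PEP 562: called when normal lookup fails
        if name in _getters:
            return _getters[name]()
        raise AttributeError(name)

    @getter
    def _answer():
        return 42

    # elsewhere: `import mymodule; mymodule.answer` -> 42, recomputed per access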

@@ -5,6 +5,8 @@ from contextvars import ContextVar
 from dataclasses import dataclass, replace
 from typing import Optional, Final
 
+from .cli_args import cli_args_configuration
+from .cli_args_types import Configuration
 from .component_model import cvpickle
 from .component_model.executor_types import ExecutorToClientProgress
 from .component_model.folder_path_types import FolderNames

@@ -23,6 +25,7 @@ class ExecutionContext:
     list_index: Optional[int] = None
     inference_mode: bool = True
     progress_registry: Optional[AbstractProgressRegistry] = None
+    configuration: Optional[Configuration] = None
 
     def __iter__(self):
         """

@@ -34,7 +37,7 @@ class ExecutionContext:
         yield self.list_index
 
 
-comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context", default=ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub()))
+comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context", default=ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub(), configuration=cli_args_configuration()))
 # enables context var propagation across process boundaries for process pool executors
 cvpickle.register_contextvar(comfyui_execution_context, __name__)
 

@@ -52,6 +55,17 @@ def _new_execution_context(ctx: ExecutionContext):
         comfyui_execution_context.reset(token)
 
 
+@contextmanager
+def context_configuration(configuration: Optional[Configuration] = None):
+    current_ctx = current_execution_context()
+    if configuration is None:
+        from .cli_args import cli_args_configuration
+        configuration = cli_args_configuration()
+    new_ctx = replace(current_ctx, configuration=configuration)
+    with _new_execution_context(new_ctx):
+        yield new_ctx
+
+
 @contextmanager
 def context_folder_names_and_paths(folder_names_and_paths: FolderNames):
     current_ctx = current_execution_context()
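
`context_configuration` copies the current `ExecutionContext` with `dataclasses.replace`, so overrides nest and unwind automatically, and passing `None` falls back to the CLI-parsed configuration. Usage sketch (field values are illustrative):

    from comfy.cli_args import default_configuration
    from comfy.execution_context import context_configuration, current_execution_context

    outer = default_configuration()
    inner = default_configuration()
    inner.disable_all_custom_nodes = True

    with context_configuration(outer):
        with context_configuration(inner):
            assert current_execution_context().configuration is inner
        # inner frame unwound; outer configuration is ambient again
        assert current_execution_context().configuration is outer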

@@ -4,10 +4,12 @@ import importlib
 import logging
 import os
 import pkgutil
+import threading
 import time
 import types
 from functools import reduce
 from importlib.metadata import entry_points
+from threading import RLock
 
 from opentelemetry.trace import Span, Status, StatusCode
 

@@ -17,8 +19,9 @@ from comfy_api.version_list import supported_versions
 from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
 from .package_typing import ExportedNodes
 from ..component_model.files import get_package_as_path
+from ..execution_context import current_execution_context
 
-_nodes_available_at_startup: ExportedNodes = ExportedNodes()
+_nodes_local = threading.local()
 
 logger = logging.getLogger(__name__)
 

@@ -46,8 +49,6 @@ def _import_nodes_in_module(module: types.ModuleType) -> ExportedNodes:
     return exported_nodes
 
 
-
-
 def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
                                           print_import_times=False,
                                           raise_on_failure=False,

@@ -116,7 +117,11 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
 @tracer.start_as_current_span("Import All Nodes In Workspace")
 def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
     # now actually import the nodes, to improve control of node loading order
-    from ..cli_args import args
+    try:
+        _nodes_available_at_startup = _nodes_local.nodes
+    except (LookupError, AttributeError):
+        _nodes_available_at_startup = _nodes_local.nodes = ExportedNodes()
+    args = current_execution_context().configuration
 
     # todo: this is some truly braindead stuff
     register_versions([

@@ -126,48 +131,47 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
         ) for v in supported_versions
     ])
 
-    # only load these nodes once
-    if len(_nodes_available_at_startup) == 0:
+    _nodes_available_at_startup.clear()
 
     # import base_nodes first
     from . import base_nodes
     from comfy_extras import nodes as comfy_extras_nodes  # pylint: disable=absolute-import-used
     from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
 
     base_and_extra = reduce(lambda x, y: x.update(y),
                             map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [
                                 # this is the list of default nodes to import
                                 base_nodes,
                                 comfy_extras_nodes
                             ]),
                             ExportedNodes())
     custom_nodes_mappings = ExportedNodes()
 
     if args.disable_all_custom_nodes:
         logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded")
         _nodes_available_at_startup.update(base_and_extra)
         return _nodes_available_at_startup
 
     # load from entrypoints
     for entry_point in entry_points().select(group='comfyui.custom_nodes'):
         # Load the module associated with the current entry point
         try:
             module = entry_point.load()
         except ModuleNotFoundError as module_not_found_error:
             logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
             continue
 
         # Ensure that what we've loaded is indeed a module
         if isinstance(module, types.ModuleType):
             custom_nodes_mappings.update(
                 _import_and_enumerate_nodes_in_module(module, print_import_times=True))
 
     # load the vanilla custom nodes last
     if vanilla_custom_nodes:
         custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes()
 
     # don't allow custom nodes to overwrite base nodes
     custom_nodes_mappings -= base_and_extra
 
     _nodes_available_at_startup.update(base_and_extra + custom_nodes_mappings)
     return _nodes_available_at_startup
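
Replacing the single module-level `_nodes_available_at_startup` with `threading.local()` gives each executor thread its own registry, and the old "only load these nodes once" guard becomes an unconditional `clear()` plus re-import, which is what makes `invalidate()` (below) effective. The try/except initialization idiom used here, as a self-contained sketch:

    import threading

    _local = threading.local()

    def get_registry() -> dict:
        # threading.local attributes raise AttributeError in threads that
        # have not set them yet, hence try/except instead of a lock
        try:
            return _local.registry
        except AttributeError:
            _local.registry = {}
            return _local.registry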

@@ -198,6 +198,10 @@ class ExportedNodes:
     def __bool__(self):
         return len(self.NODE_CLASS_MAPPINGS) + len(self.NODE_DISPLAY_NAME_MAPPINGS) + len(self.EXTENSION_WEB_DIRS) > 0
 
+    def clear(self):
+        self.NODE_CLASS_MAPPINGS.clear()
+        self.EXTENSION_WEB_DIRS.clear()
+        self.NODE_DISPLAY_NAME_MAPPINGS.clear()
 
 class _ExportedNodesAsChainMap(ExportedNodes):
     @classmethod

@@ -1,15 +1,26 @@
 # todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive
+import threading
 
 import lazy_object_proxy
 
 from .execution_context import current_execution_context
 from .nodes.package import import_all_nodes_in_workspace
 from .nodes.package_typing import ExportedNodes, exported_nodes_view
 
-nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
+_nodes_local = threading.local()
 
+
+def invalidate():
+    _nodes_local.nodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
+
+
 def get_nodes() -> ExportedNodes:
     current_ctx = current_execution_context()
+    try:
+        nodes = _nodes_local.nodes
+    except (LookupError, AttributeError):
+        nodes = _nodes_local.nodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
+
     if len(current_ctx.custom_nodes) == 0:
         return nodes
     return exported_nodes_view(nodes, current_ctx.custom_nodes)
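
Callers now obtain nodes through `get_nodes()` (per-thread, lazily imported via `lazy_object_proxy.Proxy`) and can call `invalidate()` to force the next access to re-import the workspace, as `_cleanup` in the embedded client now does. Usage sketch (the node name is illustrative):

    from comfy.nodes_context import get_nodes, invalidate

    nodes = get_nodes()  # first attribute access triggers the workspace import
    loader = nodes.NODE_CLASS_MAPPINGS.get("CheckpointLoaderSimple")

    invalidate()         # next get_nodes() access re-imports the workspace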

@@ -1,6 +1,5 @@
+args_parsing = True
 
-args_parsing = False
 
 def enable_args_parsing(enable=True):
-    global args_parsing
-    args_parsing = enable
+    pass

@@ -128,12 +128,11 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
                 if return_metadata:
                     metadata = f.metadata()
         except Exception as e:
-            if len(e.args) > 0:
-                message = e.args[0]
-                if "HeaderTooLarge" in message:
-                    raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.")
-                if "MetadataIncompleteBuffer" in message or "InvalidHeaderDeserialization" in message:
-                    raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.")
+            message = str(e)
+            if "HeaderTooLarge" in message:
+                raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.")
+            if "MetadataIncompleteBuffer" in message or "InvalidHeaderDeserialization" in message:
+                raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.")
             raise e
         elif ckpt.lower().endswith("index.json"):
             # from accelerate
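
`e.args[0]` was fragile: exceptions raised with no arguments skipped the matching entirely, and the first arg need not be a string. `str(e)` always produces a searchable message. A small illustration:

    class WeirdError(Exception):
        pass

    try:
        raise WeirdError()                     # raised with no args
    except Exception as e:
        assert e.args == ()                    # old guard len(e.args) > 0 skipped matching
        assert "HeaderTooLarge" not in str(e)  # new code still matches safely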

@@ -187,7 +187,7 @@ class TestExecution:
 
         await client.run(g)
         mask.inputs['value'] = 0.4
-        result2 = client.run(g)
+        result2 = await client.run(g)
         assert not result2.did_run(input1), "Input1 should have been cached"
         assert not result2.did_run(input2), "Input2 should have been cached"
 

@@ -5,6 +5,7 @@ from importlib.resources import files
 import pytest
 
 from comfy.api.components.schema.prompt import Prompt
+from comfy.cli_args import default_configuration
 from comfy.cli_args_types import Configuration
 from comfy.client.embedded_comfy_client import Comfy
 

@@ -23,14 +24,17 @@ _TEST_WORKFLOW = {
 async def test_respect_cwd_param():
     with tempfile.TemporaryDirectory() as tmp_dir:
         cwd = str(tmp_dir)
-        config = Configuration(cwd=cwd)
-        # for finding the custom nodes
-        config.base_paths = [files(__package__)]
+        config = default_configuration()
+        config.cwd = cwd
+
         from comfy.cmd.folder_paths import models_dir
         assert os.path.commonpath([os.getcwd(), models_dir]) == os.getcwd(), "at the time models_dir is accessed, the cwd should be the actual cwd, since there is no other configuration"
 
-        client = Comfy(config)
-        prompt = Prompt.validate(_TEST_WORKFLOW)
-        outputs = await client.queue_prompt_api(prompt)
-        path_as_imported = outputs.outputs["0"]["path"][0]
-        assert os.path.commonpath([path_as_imported, cwd]) == cwd, "at the time the node is imported, the cwd should be the temporary directory"
+        # for finding the custom nodes
+        config.base_paths = [str(files(__package__))]
+
+        async with Comfy(config) as client:
+            prompt = Prompt.validate(_TEST_WORKFLOW)
+            outputs = await client.queue_prompt_api(prompt)
+            path_as_imported = outputs.outputs["0"]["path"][0]
+            assert os.path.commonpath([path_as_imported, cwd]) == cwd, "at the time the node is imported, the cwd should be the temporary directory"

@@ -1,9 +1,10 @@
-import pytest
 from importlib.resources import files
 
-from comfy.api.components.schema.prompt import Prompt
-from comfy.cli_args_types import Configuration
+import pytest
+
+from comfy.cli_args import default_configuration
 from comfy.client.embedded_comfy_client import Comfy
+from comfy.execution_context import context_configuration
 
 _TEST_WORKFLOW_1 = {
     "0": {

@@ -35,10 +36,15 @@ _TEST_WORKFLOW_2 = {
 
 @pytest.mark.asyncio
 async def test_blacklist_node():
-    config = Configuration(blacklist_custom_nodes=['issue_46'])
+    config = default_configuration()
+    config.blacklist_custom_nodes = ['issue_46']
     # for finding the custom nodes
     config.base_paths = [str(files(__package__))]
 
+    with context_configuration(config):
+        from comfy.nodes_context import get_nodes
+        nodes = get_nodes()
+        assert "ShouldNotExist" not in nodes.NODE_CLASS_MAPPINGS
     async with Comfy(config) as client:
         from comfy.cmd.execution import validate_prompt
         res = await validate_prompt("1", prompt=_TEST_WORKFLOW_1, partial_execution_list=[])

@@ -190,6 +190,8 @@ async def test_no_panic_when_disabled_with_executor(executor_cls, executor_kwarg
           patch('sys.exit') as mock_exit):
         try:
             async with Comfy(configuration=config, executor=executor) as client:
+                from comfy.cli_args import args
+                assert len(args.panic_when) == 0
                 # Queue our failing workflow
                 await client.queue_prompt(create_failing_workflow())
         except SystemExit:

@@ -6,7 +6,7 @@ from pytest_mock import MockerFixture
 
 from comfy.cli_args import args
 from comfy.cmd.execution import validate_prompt
-from comfy.nodes_context import nodes
+from comfy.nodes_context import get_nodes
 
 import uuid
 

@@ -75,6 +75,7 @@ known_models: ContextVar[list[str]] = ContextVar('known_models', default=[])
 
 @pytest.fixture
 def mock_nodes(mocker: MockerFixture):
+    nodes = get_nodes()
     class MockCheckpointLoaderSimple:
         @staticmethod
         def INPUT_TYPES():