mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Fixes to tests and configuration, making library use more durable
This commit is contained in:
parent
67f9d3e693
commit
058e5dc634
@ -9,15 +9,17 @@ from typing import Optional
|
||||
import configargparse as argparse
|
||||
|
||||
from . import __version__
|
||||
from . import options
|
||||
from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, EnumAction, \
|
||||
EnhancedConfigArgParser, PerformanceFeature, is_valid_directory, db_config, FlattenAndAppendAction
|
||||
from .component_model.module_property import create_module_properties
|
||||
|
||||
# todo: move this
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_module_properties = create_module_properties()
|
||||
|
||||
|
||||
def _create_parser() -> EnhancedConfigArgParser:
|
||||
parser = EnhancedConfigArgParser(default_config_files=['config.yaml', 'config.json', 'config.cfg', 'config.ini'],
|
||||
@ -28,7 +30,7 @@ def _create_parser() -> EnhancedConfigArgParser:
|
||||
|
||||
parser.add_argument('-w', "--cwd", type=str, default=None,
|
||||
help="Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default.")
|
||||
parser.add_argument("--base-paths", type=str, nargs='+', default=[], help="Additional base paths for custom nodes, models and inputs.")
|
||||
parser.add_argument("--base-paths", type=str, nargs='+', default=[], action=FlattenAndAppendAction, help="Additional base paths for custom nodes, models and inputs.")
|
||||
parser.add_argument('-H', "--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::",
|
||||
help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
|
||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||
@ -149,8 +151,8 @@ def _create_parser() -> EnhancedConfigArgParser:
|
||||
|
||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
||||
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
|
||||
parser.add_argument("--blacklist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.")
|
||||
parser.add_argument("--whitelist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
|
||||
parser.add_argument("--blacklist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
@ -208,7 +210,8 @@ def _create_parser() -> EnhancedConfigArgParser:
|
||||
|
||||
parser.add_argument(
|
||||
'--panic-when',
|
||||
action='append',
|
||||
action=FlattenAndAppendAction,
|
||||
nargs='+',
|
||||
help="""
|
||||
List of fully qualified exception class names to panic (sys.exit(1)) when a workflow raises it.
|
||||
Example: --panic-when=torch.cuda.OutOfMemoryError. Can be specified multiple times or as a
|
||||
@ -270,7 +273,7 @@ def _create_parser() -> EnhancedConfigArgParser:
|
||||
|
||||
default_db_url = db_config()
|
||||
parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--workflows", type=str, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")
|
||||
parser.add_argument("--workflows", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")
|
||||
|
||||
# now give plugins a chance to add configuration
|
||||
for entry_point in entry_points().select(group='comfyui.custom_config'):
|
||||
@ -317,4 +320,19 @@ def default_configuration() -> Configuration:
|
||||
return _parse_args(_create_parser())
|
||||
|
||||
|
||||
args = _parse_args(args_parsing=options.args_parsing)
|
||||
def cli_args_configuration() -> Configuration:
|
||||
return _parse_args(args_parsing=True)
|
||||
|
||||
|
||||
@_module_properties.getter
|
||||
def _args() -> Configuration:
|
||||
from .execution_context import current_execution_context
|
||||
return current_execution_context().configuration
|
||||
|
||||
|
||||
__all__ = [
|
||||
"args", # pylint: disable=undefined-all-variable, type: Configuration
|
||||
"default_configuration",
|
||||
"cli_args_configuration",
|
||||
"DEFAULT_VERSION_STRING"
|
||||
]
|
||||
|
||||
@ -29,7 +29,7 @@ from ..distributed.executors import ContextVarExecutor
|
||||
from ..distributed.history import History
|
||||
from ..distributed.process_pool_executor import ProcessPoolExecutor
|
||||
from ..distributed.server_stub import ServerStub
|
||||
from ..execution_context import current_execution_context
|
||||
from ..execution_context import current_execution_context, context_configuration
|
||||
|
||||
_prompt_executor = threading.local()
|
||||
|
||||
@ -57,7 +57,6 @@ def _execute_prompt(
|
||||
finally:
|
||||
detach(token)
|
||||
|
||||
|
||||
async def __execute_prompt(
|
||||
prompt: dict,
|
||||
prompt_id: str,
|
||||
@ -66,7 +65,16 @@ async def __execute_prompt(
|
||||
progress_handler: ExecutorToClientProgress | None,
|
||||
configuration: Configuration | None,
|
||||
partial_execution_targets: list[str] | None) -> dict:
|
||||
from .. import options
|
||||
with context_configuration(configuration):
|
||||
return await ___execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, partial_execution_targets)
|
||||
|
||||
async def ___execute_prompt(
|
||||
prompt: dict,
|
||||
prompt_id: str,
|
||||
client_id: str,
|
||||
span_context: Context,
|
||||
progress_handler: ExecutorToClientProgress | None,
|
||||
partial_execution_targets: list[str] | None) -> dict:
|
||||
from ..cmd.execution import PromptExecutor
|
||||
|
||||
progress_handler = progress_handler or ServerStub()
|
||||
@ -74,13 +82,6 @@ async def __execute_prompt(
|
||||
try:
|
||||
prompt_executor: PromptExecutor = _prompt_executor.executor
|
||||
except (LookupError, AttributeError):
|
||||
if configuration is None:
|
||||
options.enable_args_parsing()
|
||||
else:
|
||||
from ..cmd.main_pre import args
|
||||
args.clear()
|
||||
args.update(configuration)
|
||||
|
||||
with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context):
|
||||
# todo: deal with new caching features
|
||||
prompt_executor = PromptExecutor(progress_handler)
|
||||
@ -117,6 +118,7 @@ async def __execute_prompt(
|
||||
|
||||
def _cleanup():
|
||||
from ..cmd.execution import PromptExecutor
|
||||
from ..nodes_context import invalidate
|
||||
try:
|
||||
prompt_executor: PromptExecutor = _prompt_executor.executor
|
||||
# this should clear all references to output tensors and make it easier to collect back the memory
|
||||
@ -130,6 +132,10 @@ def _cleanup():
|
||||
model_management.soft_empty_cache()
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
invalidate()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
class Comfy:
|
||||
@ -172,6 +178,8 @@ class Comfy:
|
||||
self._task_count_lock = RLock()
|
||||
self._task_count = 0
|
||||
self._history = History()
|
||||
self._context_stack = []
|
||||
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
@ -183,6 +191,9 @@ class Comfy:
|
||||
|
||||
def __enter__(self):
|
||||
self._is_running = True
|
||||
cm = context_configuration(self._configuration)
|
||||
cm.__enter__()
|
||||
self._context_stack.append(cm)
|
||||
return self
|
||||
|
||||
@property
|
||||
@ -193,9 +204,13 @@ class Comfy:
|
||||
get_event_loop().run_in_executor(self._executor, _cleanup)
|
||||
self._executor.shutdown(wait=True)
|
||||
self._is_running = False
|
||||
self._context_stack.pop().__exit__(*args)
|
||||
|
||||
async def __aenter__(self):
|
||||
self._is_running = True
|
||||
cm = context_configuration(self._configuration)
|
||||
cm.__enter__()
|
||||
self._context_stack.append(cm)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
@ -207,6 +222,7 @@ class Comfy:
|
||||
|
||||
self._executor.shutdown(wait=True)
|
||||
self._is_running = False
|
||||
self._context_stack.pop().__exit__(*args)
|
||||
|
||||
async def queue_prompt_api(self,
|
||||
prompt: PromptDict | str | dict,
|
||||
|
||||
@ -30,9 +30,9 @@ from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.latest import io
|
||||
from ..execution_context import current_execution_context
|
||||
from .. import interruption
|
||||
from .. import model_management
|
||||
from ..cli_args import args
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
|
||||
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
|
||||
@ -675,7 +675,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
|
||||
"current_inputs": input_data_formatted
|
||||
}
|
||||
|
||||
if should_panic_on_exception(ex, args.panic_when):
|
||||
if should_panic_on_exception(ex, current_execution_context().configuration.panic_when):
|
||||
logger.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
|
||||
|
||||
def sys_exit(*args):
|
||||
|
||||
@ -148,6 +148,13 @@ def setup_database():
|
||||
|
||||
|
||||
async def _start_comfyui(from_script_dir: Optional[Path] = None):
|
||||
from ..execution_context import context_configuration
|
||||
from ..cli_args import cli_args_configuration
|
||||
with context_configuration(cli_args_configuration()):
|
||||
await __start_comfyui(from_script_dir=from_script_dir)
|
||||
|
||||
|
||||
async def __start_comfyui(from_script_dir: Optional[Path] = None):
|
||||
"""
|
||||
Runs ComfyUI's frontend and backend like upstream.
|
||||
:param from_script_dir: when set to a path, assumes that you are running ComfyUI's legacy main.py entrypoint at the root of the git repository located at the path
|
||||
|
||||
@ -18,6 +18,7 @@ import fsspec
|
||||
|
||||
from .. import options
|
||||
from ..app import logger
|
||||
from ..cli_args_types import Configuration
|
||||
from ..component_model import package_filesystem
|
||||
|
||||
os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
|
||||
@ -164,6 +165,7 @@ def _register_fsspec_fs():
|
||||
package_filesystem.PkgResourcesFileSystem,
|
||||
)
|
||||
|
||||
args: Configuration
|
||||
|
||||
_configure_logging()
|
||||
_fix_pytorch_240()
|
||||
|
||||
@ -2,6 +2,17 @@ import sys
|
||||
from functools import wraps
|
||||
|
||||
def create_module_properties():
|
||||
"""
|
||||
Example:
|
||||
>>> _module_properties = create_module_properties()
|
||||
|
||||
>>> @_module_properties.getter
|
||||
>>> def _nodes():
|
||||
>>> return ...
|
||||
|
||||
This creates nodes as a property
|
||||
:return:
|
||||
"""
|
||||
properties = {}
|
||||
patched_modules = set()
|
||||
|
||||
|
||||
@ -5,6 +5,8 @@ from contextvars import ContextVar
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Optional, Final
|
||||
|
||||
from .cli_args import cli_args_configuration
|
||||
from .cli_args_types import Configuration
|
||||
from .component_model import cvpickle
|
||||
from .component_model.executor_types import ExecutorToClientProgress
|
||||
from .component_model.folder_path_types import FolderNames
|
||||
@ -23,6 +25,7 @@ class ExecutionContext:
|
||||
list_index: Optional[int] = None
|
||||
inference_mode: bool = True
|
||||
progress_registry: Optional[AbstractProgressRegistry] = None
|
||||
configuration: Optional[Configuration] = None
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
@ -34,7 +37,7 @@ class ExecutionContext:
|
||||
yield self.list_index
|
||||
|
||||
|
||||
comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context", default=ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub()))
|
||||
comfyui_execution_context: Final[ContextVar] = ContextVar("comfyui_execution_context", default=ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub(), configuration=cli_args_configuration()))
|
||||
# enables context var propagation across process boundaries for process pool executors
|
||||
cvpickle.register_contextvar(comfyui_execution_context, __name__)
|
||||
|
||||
@ -52,6 +55,17 @@ def _new_execution_context(ctx: ExecutionContext):
|
||||
comfyui_execution_context.reset(token)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def context_configuration(configuration: Optional[Configuration] = None):
|
||||
current_ctx = current_execution_context()
|
||||
if configuration is None:
|
||||
from .cli_args import cli_args_configuration
|
||||
configuration = cli_args_configuration()
|
||||
new_ctx = replace(current_ctx, configuration=configuration)
|
||||
with _new_execution_context(new_ctx):
|
||||
yield new_ctx
|
||||
|
||||
|
||||
@contextmanager
|
||||
def context_folder_names_and_paths(folder_names_and_paths: FolderNames):
|
||||
current_ctx = current_execution_context()
|
||||
|
||||
@ -4,10 +4,12 @@ import importlib
|
||||
import logging
|
||||
import os
|
||||
import pkgutil
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
from functools import reduce
|
||||
from importlib.metadata import entry_points
|
||||
from threading import RLock
|
||||
|
||||
from opentelemetry.trace import Span, Status, StatusCode
|
||||
|
||||
@ -17,8 +19,9 @@ from comfy_api.version_list import supported_versions
|
||||
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
|
||||
from .package_typing import ExportedNodes
|
||||
from ..component_model.files import get_package_as_path
|
||||
from ..execution_context import current_execution_context
|
||||
|
||||
_nodes_available_at_startup: ExportedNodes = ExportedNodes()
|
||||
_nodes_local = threading.local()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -46,8 +49,6 @@ def _import_nodes_in_module(module: types.ModuleType) -> ExportedNodes:
|
||||
return exported_nodes
|
||||
|
||||
|
||||
|
||||
|
||||
def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
||||
print_import_times=False,
|
||||
raise_on_failure=False,
|
||||
@ -116,7 +117,11 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
||||
@tracer.start_as_current_span("Import All Nodes In Workspace")
|
||||
def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
|
||||
# now actually import the nodes, to improve control of node loading order
|
||||
from ..cli_args import args
|
||||
try:
|
||||
_nodes_available_at_startup = _nodes_local.nodes
|
||||
except (LookupError, AttributeError):
|
||||
_nodes_available_at_startup = _nodes_local.nodes = ExportedNodes()
|
||||
args = current_execution_context().configuration
|
||||
|
||||
# todo: this is some truly braindead stuff
|
||||
register_versions([
|
||||
@ -126,48 +131,47 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
|
||||
) for v in supported_versions
|
||||
])
|
||||
|
||||
# only load these nodes once
|
||||
if len(_nodes_available_at_startup) == 0:
|
||||
_nodes_available_at_startup.clear()
|
||||
|
||||
# import base_nodes first
|
||||
from . import base_nodes
|
||||
from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used
|
||||
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
|
||||
# import base_nodes first
|
||||
from . import base_nodes
|
||||
from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used
|
||||
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
|
||||
|
||||
base_and_extra = reduce(lambda x, y: x.update(y),
|
||||
map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [
|
||||
# this is the list of default nodes to import
|
||||
base_nodes,
|
||||
comfy_extras_nodes
|
||||
]),
|
||||
ExportedNodes())
|
||||
custom_nodes_mappings = ExportedNodes()
|
||||
base_and_extra = reduce(lambda x, y: x.update(y),
|
||||
map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [
|
||||
# this is the list of default nodes to import
|
||||
base_nodes,
|
||||
comfy_extras_nodes
|
||||
]),
|
||||
ExportedNodes())
|
||||
custom_nodes_mappings = ExportedNodes()
|
||||
|
||||
if args.disable_all_custom_nodes:
|
||||
logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded")
|
||||
_nodes_available_at_startup.update(base_and_extra)
|
||||
return _nodes_available_at_startup
|
||||
if args.disable_all_custom_nodes:
|
||||
logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded")
|
||||
_nodes_available_at_startup.update(base_and_extra)
|
||||
return _nodes_available_at_startup
|
||||
|
||||
# load from entrypoints
|
||||
for entry_point in entry_points().select(group='comfyui.custom_nodes'):
|
||||
# Load the module associated with the current entry point
|
||||
try:
|
||||
module = entry_point.load()
|
||||
except ModuleNotFoundError as module_not_found_error:
|
||||
logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
|
||||
continue
|
||||
# load from entrypoints
|
||||
for entry_point in entry_points().select(group='comfyui.custom_nodes'):
|
||||
# Load the module associated with the current entry point
|
||||
try:
|
||||
module = entry_point.load()
|
||||
except ModuleNotFoundError as module_not_found_error:
|
||||
logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
|
||||
continue
|
||||
|
||||
# Ensure that what we've loaded is indeed a module
|
||||
if isinstance(module, types.ModuleType):
|
||||
custom_nodes_mappings.update(
|
||||
_import_and_enumerate_nodes_in_module(module, print_import_times=True))
|
||||
# Ensure that what we've loaded is indeed a module
|
||||
if isinstance(module, types.ModuleType):
|
||||
custom_nodes_mappings.update(
|
||||
_import_and_enumerate_nodes_in_module(module, print_import_times=True))
|
||||
|
||||
# load the vanilla custom nodes last
|
||||
if vanilla_custom_nodes:
|
||||
custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes()
|
||||
# load the vanilla custom nodes last
|
||||
if vanilla_custom_nodes:
|
||||
custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes()
|
||||
|
||||
# don't allow custom nodes to overwrite base nodes
|
||||
custom_nodes_mappings -= base_and_extra
|
||||
# don't allow custom nodes to overwrite base nodes
|
||||
custom_nodes_mappings -= base_and_extra
|
||||
|
||||
_nodes_available_at_startup.update(base_and_extra + custom_nodes_mappings)
|
||||
_nodes_available_at_startup.update(base_and_extra + custom_nodes_mappings)
|
||||
return _nodes_available_at_startup
|
||||
|
||||
@ -198,6 +198,10 @@ class ExportedNodes:
|
||||
def __bool__(self):
|
||||
return len(self.NODE_CLASS_MAPPINGS) + len(self.NODE_DISPLAY_NAME_MAPPINGS) + len(self.EXTENSION_WEB_DIRS) > 0
|
||||
|
||||
def clear(self):
|
||||
self.NODE_CLASS_MAPPINGS.clear()
|
||||
self.EXTENSION_WEB_DIRS.clear()
|
||||
self.NODE_DISPLAY_NAME_MAPPINGS.clear()
|
||||
|
||||
class _ExportedNodesAsChainMap(ExportedNodes):
|
||||
@classmethod
|
||||
|
||||
@ -1,15 +1,26 @@
|
||||
# todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive
|
||||
import threading
|
||||
|
||||
import lazy_object_proxy
|
||||
|
||||
from .execution_context import current_execution_context
|
||||
from .nodes.package import import_all_nodes_in_workspace
|
||||
from .nodes.package_typing import ExportedNodes, exported_nodes_view
|
||||
|
||||
nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
|
||||
_nodes_local = threading.local()
|
||||
|
||||
|
||||
def invalidate():
|
||||
_nodes_local.nodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
|
||||
|
||||
|
||||
def get_nodes() -> ExportedNodes:
|
||||
current_ctx = current_execution_context()
|
||||
try:
|
||||
nodes = _nodes_local.nodes
|
||||
except (LookupError, AttributeError):
|
||||
nodes = _nodes_local.nodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
|
||||
|
||||
if len(current_ctx.custom_nodes) == 0:
|
||||
return nodes
|
||||
return exported_nodes_view(nodes, current_ctx.custom_nodes)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
args_parsing = True
|
||||
|
||||
args_parsing = False
|
||||
|
||||
def enable_args_parsing(enable=True):
|
||||
global args_parsing
|
||||
args_parsing = enable
|
||||
pass
|
||||
|
||||
@ -128,12 +128,11 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
|
||||
if return_metadata:
|
||||
metadata = f.metadata()
|
||||
except Exception as e:
|
||||
if len(e.args) > 0:
|
||||
message = e.args[0]
|
||||
if "HeaderTooLarge" in message:
|
||||
raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.")
|
||||
if "MetadataIncompleteBuffer" in message or "InvalidHeaderDeserialization" in message:
|
||||
raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.")
|
||||
message = str(e)
|
||||
if "HeaderTooLarge" in message:
|
||||
raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.")
|
||||
if "MetadataIncompleteBuffer" in message or "InvalidHeaderDeserialization" in message:
|
||||
raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.")
|
||||
raise e
|
||||
elif ckpt.lower().endswith("index.json"):
|
||||
# from accelerate
|
||||
|
||||
@ -187,7 +187,7 @@ class TestExecution:
|
||||
|
||||
await client.run(g)
|
||||
mask.inputs['value'] = 0.4
|
||||
result2 = client.run(g)
|
||||
result2 = await client.run(g)
|
||||
assert not result2.did_run(input1), "Input1 should have been cached"
|
||||
assert not result2.did_run(input2), "Input2 should have been cached"
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from importlib.resources import files
|
||||
import pytest
|
||||
|
||||
from comfy.api.components.schema.prompt import Prompt
|
||||
from comfy.cli_args import default_configuration
|
||||
from comfy.cli_args_types import Configuration
|
||||
from comfy.client.embedded_comfy_client import Comfy
|
||||
|
||||
@ -23,14 +24,17 @@ _TEST_WORKFLOW = {
|
||||
async def test_respect_cwd_param():
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
cwd = str(tmp_dir)
|
||||
config = Configuration(cwd=cwd)
|
||||
# for finding the custom nodes
|
||||
config.base_paths = [files(__package__)]
|
||||
config = default_configuration()
|
||||
config.cwd = cwd
|
||||
|
||||
from comfy.cmd.folder_paths import models_dir
|
||||
assert os.path.commonpath([os.getcwd(), models_dir]) == os.getcwd(), "at the time models_dir is accessed, the cwd should be the actual cwd, since there is no other configuration"
|
||||
|
||||
client = Comfy(config)
|
||||
prompt = Prompt.validate(_TEST_WORKFLOW)
|
||||
outputs = await client.queue_prompt_api(prompt)
|
||||
path_as_imported = outputs.outputs["0"]["path"][0]
|
||||
assert os.path.commonpath([path_as_imported, cwd]) == cwd, "at the time the node is imported, the cwd should be the temporary directory"
|
||||
# for finding the custom nodes
|
||||
config.base_paths = [str(files(__package__))]
|
||||
|
||||
async with Comfy(config) as client:
|
||||
prompt = Prompt.validate(_TEST_WORKFLOW)
|
||||
outputs = await client.queue_prompt_api(prompt)
|
||||
path_as_imported = outputs.outputs["0"]["path"][0]
|
||||
assert os.path.commonpath([path_as_imported, cwd]) == cwd, "at the time the node is imported, the cwd should be the temporary directory"
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import pytest
|
||||
from importlib.resources import files
|
||||
|
||||
from comfy.api.components.schema.prompt import Prompt
|
||||
from comfy.cli_args_types import Configuration
|
||||
import pytest
|
||||
|
||||
from comfy.cli_args import default_configuration
|
||||
from comfy.client.embedded_comfy_client import Comfy
|
||||
from comfy.execution_context import context_configuration
|
||||
|
||||
_TEST_WORKFLOW_1 = {
|
||||
"0": {
|
||||
@ -35,10 +36,15 @@ _TEST_WORKFLOW_2 = {
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blacklist_node():
|
||||
config = Configuration(blacklist_custom_nodes=['issue_46'])
|
||||
config = default_configuration()
|
||||
config.blacklist_custom_nodes = ['issue_46']
|
||||
# for finding the custom nodes
|
||||
config.base_paths = [str(files(__package__))]
|
||||
|
||||
with context_configuration(config):
|
||||
from comfy.nodes_context import get_nodes
|
||||
nodes = get_nodes()
|
||||
assert "ShouldNotExist" not in nodes.NODE_CLASS_MAPPINGS
|
||||
async with Comfy(config) as client:
|
||||
from comfy.cmd.execution import validate_prompt
|
||||
res = await validate_prompt("1", prompt=_TEST_WORKFLOW_1, partial_execution_list=[])
|
||||
|
||||
@ -190,6 +190,8 @@ async def test_no_panic_when_disabled_with_executor(executor_cls, executor_kwarg
|
||||
patch('sys.exit') as mock_exit):
|
||||
try:
|
||||
async with Comfy(configuration=config, executor=executor) as client:
|
||||
from comfy.cli_args import args
|
||||
assert len(args.panic_when) == 0
|
||||
# Queue our failing workflow
|
||||
await client.queue_prompt(create_failing_workflow())
|
||||
except SystemExit:
|
||||
|
||||
@ -6,7 +6,7 @@ from pytest_mock import MockerFixture
|
||||
|
||||
from comfy.cli_args import args
|
||||
from comfy.cmd.execution import validate_prompt
|
||||
from comfy.nodes_context import nodes
|
||||
from comfy.nodes_context import get_nodes
|
||||
|
||||
import uuid
|
||||
|
||||
@ -75,6 +75,7 @@ known_models: ContextVar[list[str]] = ContextVar('known_models', default=[])
|
||||
|
||||
@pytest.fixture
|
||||
def mock_nodes(mocker: MockerFixture):
|
||||
nodes = get_nodes()
|
||||
class MockCheckpointLoaderSimple:
|
||||
@staticmethod
|
||||
def INPUT_TYPES():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user