diff --git a/README.md b/README.md index b49b09f3c..3af7d5acc 100644 --- a/README.md +++ b/README.md @@ -147,10 +147,6 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from ```shell comfyui ``` - To generate python OpenAPI models: - ```shell - comfyui-openapi-gen - ``` To run tests: ```shell pytest tests/inference diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index f9a0fa2d6..e39073d66 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -10,21 +10,18 @@ import traceback import typing from typing import List, Optional, Tuple, Union from typing_extensions import TypedDict - -# Suppress warnings during import -import warnings -warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.") - import torch from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus from ..component_model.executor_types import ExecutorToClientProgress +from .. import model_management +from ..nodes.package_typing import ExportedNodes from ..nodes.package import import_all_nodes_in_workspace -nodes = import_all_nodes_in_workspace() -from .. import model_management # type: ignore - +# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the +# various functions that are declared here. It should have been a context in the first place. +nodes: ExportedNodes = import_all_nodes_in_workspace() def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None): if extra_data is None: @@ -88,7 +85,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): results.append(getattr(obj, func)(**input_data_all)) elif max_len_input == 0: if allow_interrupt: - nodes.before_node_execution() + model_management.throw_exception_if_processing_interrupted() results.append(getattr(obj, func)()) else: for i in range(max_len_input): @@ -141,7 +138,14 @@ def format_value(x): return str(x) -def recursive_execute(server: ExecutorToClientProgress, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, +def recursive_execute(server: ExecutorToClientProgress, + prompt, + outputs, + current_item, + extra_data, + executed, + prompt_id, + outputs_ui, object_storage): unique_id = current_item inputs = prompt[unique_id]['inputs'] diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index 1385c8ae7..d84f715ff 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -1,5 +1,8 @@ from .. import options +# Suppress warnings during import +import warnings +warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.") options.enable_args_parsing() import os @@ -9,6 +12,8 @@ from ..cmd import cuda_malloc from ..cmd import folder_paths from .extra_model_paths import load_extra_path_config from ..analytics.analytics import initialize_event_tracking +from ..nodes.package import import_all_nodes_in_workspace + import time @@ -77,7 +82,6 @@ if args.deterministic: from .. import utils -from ..cmd import execution from ..cmd import server as server_module from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.queue_types import BinaryEventTypes, ExecutionStatus @@ -88,7 +92,9 @@ from ..distributed.server_stub import ServerStub def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer): - e = execution.PromptExecutor(_server) + from ..cmd.execution import PromptExecutor + + e = PromptExecutor(_server) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 @@ -202,10 +208,27 @@ async def main(): except: pass + # configure extra model paths earlier + try: + extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") + if os.path.isfile(extra_model_paths_config_path): + load_extra_path_config(extra_model_paths_config_path) + except NameError: + pass + + if args.extra_model_paths_config: + for config_path in itertools.chain(*args.extra_model_paths_config): + load_extra_path_config(config_path) + loop = asyncio.get_event_loop() server = server_module.PromptServer(loop) if args.external_address is not None: server.external_address = args.external_address + + # at this stage, it's safe to import nodes + server.nodes = import_all_nodes_in_workspace() + # as a side effect, this also populates the nodes for execution + if args.distributed_queue_connection_uri is not None: distributed = True q = DistributedPromptQueue( @@ -219,19 +242,10 @@ async def main(): await q.init() else: distributed = False - q = execution.PromptQueue(server) + from execution import PromptQueue + q = PromptQueue(server) server.prompt_queue = q - try: - extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") - if os.path.isfile(extra_model_paths_config_path): - load_extra_path_config(extra_model_paths_config_path) - except NameError: - pass - - if args.extra_model_paths_config: - for config_path in itertools.chain(*args.extra_model_paths_config): - load_extra_path_config(config_path) server.add_routes() hijack_progress(server) diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index adf3d6790..3eaa64353 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -34,10 +34,8 @@ from .. import model_management from ..component_model.executor_types import ExecutorToClientProgress from ..component_model.file_output_path import file_output_path from ..nodes.package import import_all_nodes_in_workspace +from ..nodes.package_typing import ExportedNodes from ..vendor.appdirs import user_data_dir - -nodes = import_all_nodes_in_workspace() - from ..app.user_manager import UserManager @@ -114,6 +112,7 @@ class PromptServer(ExecutorToClientProgress): self.client_id = None self.on_prompt_handlers = [] + self.nodes: ExportedNodes = ExportedNodes() @routes.get('/ws') async def websocket_handler(request): @@ -157,7 +156,7 @@ class PromptServer(ExecutorToClientProgress): glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True) extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)) - for name, dir in nodes.EXTENSION_WEB_DIRS.items(): + for name, dir in self.nodes.EXTENSION_WEB_DIRS.items(): files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True) extensions.extend(list(map(lambda f: "/extensions/" + quote( name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files))) @@ -398,7 +397,7 @@ class PromptServer(ExecutorToClientProgress): return web.json_response(self.get_queue_info()) def node_info(node_class): - obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] + obj_class = self.nodes.NODE_CLASS_MAPPINGS[node_class] info = {} info['input'] = obj_class.INPUT_TYPES() info['output'] = obj_class.RETURN_TYPES @@ -407,8 +406,8 @@ class PromptServer(ExecutorToClientProgress): obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] info['name'] = node_class - info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[ - node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class + info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[ + node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class info['description'] = obj_class.DESCRIPTION if hasattr(obj_class, 'DESCRIPTION') else '' info['category'] = 'sd' if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: @@ -423,7 +422,7 @@ class PromptServer(ExecutorToClientProgress): @routes.get("/object_info") async def get_object_info(request): out = {} - for x in nodes.NODE_CLASS_MAPPINGS: + for x in self.nodes.NODE_CLASS_MAPPINGS: try: out[x] = node_info(x) except Exception as e: @@ -436,7 +435,7 @@ class PromptServer(ExecutorToClientProgress): async def get_object_info_node(request): node_class = request.match_info.get("node_class", None) out = {} - if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): + if (node_class is not None) and (node_class in self.nodes.NODE_CLASS_MAPPINGS): out[node_class] = node_info(node_class) return web.json_response(out) @@ -710,7 +709,7 @@ class PromptServer(ExecutorToClientProgress): self.user_manager.add_routes(self.routes) self.app.add_routes(self.routes) - for name, dir in nodes.EXTENSION_WEB_DIRS.items(): + for name, dir in self.nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([ web.static('/extensions/' + quote(name), dir), ]) diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py index a7e6e5fb2..75ab5b853 100644 --- a/comfy/nodes/package.py +++ b/comfy/nodes/package.py @@ -16,7 +16,7 @@ from . import base_nodes from .package_typing import ExportedNodes from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes -_comfy_nodes = ExportedNodes() +_comfy_nodes: ExportedNodes = ExportedNodes() def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType): @@ -77,7 +77,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes: - global _comfy_nodes + # only load these nodes once if len(_comfy_nodes) == 0: base_and_extra = reduce(lambda x, y: x.update(y), map(_import_and_enumerate_nodes_in_module, [ @@ -88,9 +88,6 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes: ExportedNodes()) custom_nodes_mappings = ExportedNodes() - if vanilla_custom_nodes: - custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes() - # load from entrypoints for entry_point in entry_points().select(group='comfyui.custom_nodes'): # Load the module associated with the current entry point @@ -100,6 +97,11 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes: if isinstance(module, types.ModuleType): custom_nodes_mappings.update( _import_and_enumerate_nodes_in_module(module, print_import_times=True)) + + # load the vanilla custom nodes last + if vanilla_custom_nodes: + custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes() + # don't allow custom nodes to overwrite base nodes custom_nodes_mappings -= base_and_extra diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index 2313322dd..3414fca89 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -124,12 +124,12 @@ class ExportedNodes: def __sub__(self, other: ExportedNodes): exported_nodes = ExportedNodes().update(self) - for self_key in exported_nodes.NODE_CLASS_MAPPINGS: + for self_key in frozenset(exported_nodes.NODE_CLASS_MAPPINGS): if self_key in other.NODE_CLASS_MAPPINGS: exported_nodes.NODE_CLASS_MAPPINGS.pop(self_key) if self_key in other.NODE_DISPLAY_NAME_MAPPINGS: exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.pop(self_key) - for self_key in exported_nodes.EXTENSION_WEB_DIRS: + for self_key in frozenset(exported_nodes.EXTENSION_WEB_DIRS): if self_key in other.EXTENSION_WEB_DIRS: exported_nodes.EXTENSION_WEB_DIRS.pop(self_key) return exported_nodes diff --git a/comfy/nodes/vanilla_node_importing.py b/comfy/nodes/vanilla_node_importing.py index 94be1679e..911546bbb 100644 --- a/comfy/nodes/vanilla_node_importing.py +++ b/comfy/nodes/vanilla_node_importing.py @@ -89,17 +89,31 @@ def _vanilla_load_custom_nodes_2() -> ExportedNodes: def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes: # only vanilla custom nodes will ever go into the custom_nodes directory - # this mitigation puts files that custom nodes expects are at the root of the repository back where they should be found - from ..cmd import cuda_malloc, folder_paths, execution, server, latent_preview - for module in (cuda_malloc, folder_paths, execution, server, latent_preview): + # this mitigation puts files that custom nodes expects are at the root of the repository back where they should be + # found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things, + # the way community custom nodes is pretty radioactive + from ..cmd import cuda_malloc, folder_paths, latent_preview + for module in (cuda_malloc, folder_paths, latent_preview): module_short_name = module.__name__.split(".")[-1] sys.modules[module_short_name] = module sys.modules['nodes'] = base_nodes + from ..cmd import execution, server + for module in (execution, server): + module_short_name = module.__name__.split(".")[-1] + sys.modules[module_short_name] = module + + # Impact Pack wants to find model_patcher + from .. import model_patcher + sys.modules['model_patcher'] = model_patcher + comfy_extras_mitigation: Dict[str, types.ModuleType] = {} + + import comfy_extras for module_name, module in sys.modules.items(): if not module_name.startswith("comfy_extras.nodes"): continue module_short_name = module_name.split(".")[-1] + setattr(comfy_extras, module_short_name, module) comfy_extras_mitigation[f'comfy_extras.{module_short_name}'] = module sys.modules.update(comfy_extras_mitigation) vanilla_custom_nodes = _vanilla_load_custom_nodes_2() diff --git a/setup.py b/setup.py index 89a03bbcc..4b50080ed 100644 --- a/setup.py +++ b/setup.py @@ -181,7 +181,6 @@ setup( setup_requires=["pip", "wheel"], entry_points={ 'console_scripts': [ - 'comfyui-openapi-gen = comfy.cmd.openapi_gen:main', 'comfyui = comfy.cmd.main:entrypoint', 'comfyui-worker = comfy.cmd.worker:entrypoint' ],