From 3a56b15bc22a906976b2a7a2395711eb6166a9e4 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Thu, 24 Aug 2023 22:23:12 -0700 Subject: [PATCH] use comfyui.custom_nodes as the plugin entrypoint and fix the protocol --- comfy/nodes/package.py | 17 +++++++++++++---- comfy/nodes/package_typing.py | 8 ++++---- setup.py | 1 - 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py index 90c5008f5..118a29087 100644 --- a/comfy/nodes/package.py +++ b/comfy/nodes/package.py @@ -15,7 +15,7 @@ except: custom_nodes = None from .package_typing import ExportedNodes from functools import reduce -from pkg_resources import resource_filename +from pkg_resources import resource_filename, iter_entry_points _comfy_nodes = ExportedNodes() @@ -85,10 +85,19 @@ def import_all_nodes_in_workspace() -> ExportedNodes: ExportedNodes()) custom_nodes_mappings = ExportedNodes() if custom_nodes is not None: - custom_nodes_mappings = _import_and_enumerate_nodes_in_module(custom_nodes, print_import_times=True) + custom_nodes_mappings.update(_import_and_enumerate_nodes_in_module(custom_nodes, print_import_times=True)) - # don't allow custom nodes to overwrite base nodes - custom_nodes_mappings -= base_and_extra + # load from entrypoints + for entry_point in iter_entry_points(group='comfyui.custom_nodes'): + # Load the module associated with the current entry point + module = entry_point.load() + + # Ensure that what we've loaded is indeed a module + if isinstance(module, types.ModuleType): + custom_nodes_mappings.update( + _import_and_enumerate_nodes_in_module(module, print_import_times=True)) + # don't allow custom nodes to overwrite base nodes + custom_nodes_mappings -= base_and_extra _comfy_nodes.update(base_and_extra + custom_nodes_mappings) return _comfy_nodes diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index 6cb8d212f..449750f6b 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -10,12 +10,12 @@ class CustomNode(Protocol): def INPUT_TYPES(cls) -> dict: ... RETURN_TYPES: ClassVar[typing.Sequence[str]] - RETURN_NAMES: ClassVar[Tuple[str]] = None - OUTPUT_IS_LIST: ClassVar[Tuple[bool]] = None - INPUT_IS_LIST: ClassVar[bool] = None + RETURN_NAMES: typing.Optional[ClassVar[Tuple[str]]] + OUTPUT_IS_LIST: typing.Optional[ClassVar[typing.Sequence[bool]]] + INPUT_IS_LIST: typing.Optional[ClassVar[bool]] FUNCTION: ClassVar[str] CATEGORY: ClassVar[str] - OUTPUT_NODE: ClassVar[bool] = None + OUTPUT_NODE: typing.Optional[ClassVar[bool]] @dataclass diff --git a/setup.py b/setup.py index 4ff004cc1..ff4a3d055 100644 --- a/setup.py +++ b/setup.py @@ -156,7 +156,6 @@ package_data = ['sd1_tokenizer/*', '**/*.json', '**/*.yaml'] if not is_editable: package_data.append('web/**/*') setup( - # "comfy" name=package_name, description="", author="",