From 0ca30c3c87a72a8ed396357accafe7b0b131690f Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Fri, 14 Feb 2025 15:36:51 -0800 Subject: [PATCH] export_custom_nodes now handles abstract base classes better --- comfy/node_helpers.py | 7 +++++-- comfy_extras/nodes/nodes_language.py | 8 +++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/comfy/node_helpers.py b/comfy/node_helpers.py index d98d8adef..ca6e67266 100644 --- a/comfy/node_helpers.py +++ b/comfy/node_helpers.py @@ -45,11 +45,13 @@ def hasher(): def export_custom_nodes(): """ - Finds all classes in the current module that extend CustomNode and creates + Finds all non-abstract classes in the current module that extend CustomNode and creates a NODE_CLASS_MAPPINGS dictionary mapping class names to class objects. Must be called from within the module where the CustomNode classes are defined. """ import inspect + from abc import ABC + from comfy.nodes.package_typing import CustomNode # Get the calling module frame = inspect.currentframe() @@ -60,7 +62,8 @@ def export_custom_nodes(): for name, obj in inspect.getmembers(module): if (inspect.isclass(obj) and CustomNode in obj.__mro__ and - obj != CustomNode): + obj != CustomNode and + not inspect.isabstract(obj)): custom_nodes[name] = obj if hasattr(module, 'NODE_CLASS_MAPPINGS'): node_class_mappings: dict = getattr(module, 'NODE_CLASS_MAPPINGS') diff --git a/comfy_extras/nodes/nodes_language.py b/comfy_extras/nodes/nodes_language.py index f045dc127..1ea122fbe 100644 --- a/comfy_extras/nodes/nodes_language.py +++ b/comfy_extras/nodes/nodes_language.py @@ -2,6 +2,7 @@ from __future__ import annotations import operator import os.path +from abc import ABC, abstractmethod from functools import reduce from typing import Optional, List @@ -26,12 +27,17 @@ from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResu _AUTO_CHAT_TEMPLATE = "default" -class TransformerSamplerBase(CustomNode): +class TransformerSamplerBase(CustomNode, ABC): RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME, RETURN_NAMES = "GENERATION ARGS", FUNCTION = "execute" CATEGORY = "language/samplers" + @classmethod + @abstractmethod + def INPUT_TYPES(cls) -> InputTypes: + return ... + @property def do_sample(self): return True