export_custom_nodes now handles abstract base classes better

This commit is contained in:
doctorpangloss 2025-02-14 15:36:51 -08:00
parent f4e65590b8
commit 0ca30c3c87
2 changed files with 12 additions and 3 deletions

View File

@ -45,11 +45,13 @@ def hasher():
def export_custom_nodes(): def export_custom_nodes():
""" """
Finds all classes in the current module that extend CustomNode and creates Finds all non-abstract classes in the current module that extend CustomNode and creates
a NODE_CLASS_MAPPINGS dictionary mapping class names to class objects. a NODE_CLASS_MAPPINGS dictionary mapping class names to class objects.
Must be called from within the module where the CustomNode classes are defined. Must be called from within the module where the CustomNode classes are defined.
""" """
import inspect import inspect
from abc import ABC
from comfy.nodes.package_typing import CustomNode
# Get the calling module # Get the calling module
frame = inspect.currentframe() frame = inspect.currentframe()
@ -60,7 +62,8 @@ def export_custom_nodes():
for name, obj in inspect.getmembers(module): for name, obj in inspect.getmembers(module):
if (inspect.isclass(obj) and if (inspect.isclass(obj) and
CustomNode in obj.__mro__ and CustomNode in obj.__mro__ and
obj != CustomNode): obj != CustomNode and
not inspect.isabstract(obj)):
custom_nodes[name] = obj custom_nodes[name] = obj
if hasattr(module, 'NODE_CLASS_MAPPINGS'): if hasattr(module, 'NODE_CLASS_MAPPINGS'):
node_class_mappings: dict = getattr(module, 'NODE_CLASS_MAPPINGS') node_class_mappings: dict = getattr(module, 'NODE_CLASS_MAPPINGS')

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import operator import operator
import os.path import os.path
from abc import ABC, abstractmethod
from functools import reduce from functools import reduce
from typing import Optional, List from typing import Optional, List
@ -26,12 +27,17 @@ from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResu
_AUTO_CHAT_TEMPLATE = "default" _AUTO_CHAT_TEMPLATE = "default"
class TransformerSamplerBase(CustomNode): class TransformerSamplerBase(CustomNode, ABC):
RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME, RETURN_TYPES = GENERATION_KWARGS_TYPE_NAME,
RETURN_NAMES = "GENERATION ARGS", RETURN_NAMES = "GENERATION ARGS",
FUNCTION = "execute" FUNCTION = "execute"
CATEGORY = "language/samplers" CATEGORY = "language/samplers"
@classmethod
@abstractmethod
def INPUT_TYPES(cls) -> InputTypes:
return ...
@property @property
def do_sample(self): def do_sample(self):
return True return True