From 31b6b53236c78beb714ad77f57169aa27ed4f011 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Wed, 12 Feb 2025 14:12:10 -0800 Subject: [PATCH] Quality of life improvements - export_custom_nodes() finds all the classes that inherit from CustomNode and exports them correctly for custom node discovery to find - regular expressions - additional string formatting and parsing nodes --- comfy/node_helpers.py | 46 ++++++++- comfy/nodes/package_typing.py | 4 +- comfy_extras/nodes/nodes_arithmetic.py | 67 ++++++------ comfy_extras/nodes/nodes_language.py | 27 +---- comfy_extras/nodes/nodes_regexp.py | 107 ++++++++++++++++++++ comfy_extras/nodes/nodes_strings.py | 67 ++++++++++++ tests/unit/test_regular_expression_nodes.py | 94 +++++++++++++++++ tests/unit/test_string_nodes.py | 93 +++++++++++++++++ 8 files changed, 447 insertions(+), 58 deletions(-) create mode 100644 comfy_extras/nodes/nodes_regexp.py create mode 100644 comfy_extras/nodes/nodes_strings.py create mode 100644 tests/unit/test_regular_expression_nodes.py create mode 100644 tests/unit/test_string_nodes.py diff --git a/comfy/node_helpers.py b/comfy/node_helpers.py index 13d6878a3..d98d8adef 100644 --- a/comfy/node_helpers.py +++ b/comfy/node_helpers.py @@ -1,10 +1,14 @@ import hashlib -from .cli_args import args - from PIL import ImageFile, UnidentifiedImageError -def conditioning_set_values(conditioning, values={}): +from .cli_args import args +from .nodes.package_typing import CustomNode + + +def conditioning_set_values(conditioning, values: dict = None): + if values is None: + values = {} c = [] for t in conditioning: n = [t[0], t[1].copy()] @@ -14,11 +18,12 @@ def conditioning_set_values(conditioning, values={}): return c + def pillow(fn, arg): prev_value = None try: x = fn(arg) - except (OSError, UnidentifiedImageError, ValueError): #PIL issues #4472 and #2445, also fixes ComfyUI issue #3416 + except (OSError, UnidentifiedImageError, ValueError): # PIL issues #4472 and #2445, also fixes ComfyUI issue #3416 prev_value = ImageFile.LOAD_TRUNCATED_IMAGES ImageFile.LOAD_TRUNCATED_IMAGES = True x = fn(arg) @@ -27,6 +32,7 @@ def pillow(fn, arg): ImageFile.LOAD_TRUNCATED_IMAGES = prev_value return x + def hasher(): hashfuncs = { "md5": hashlib.md5, @@ -35,3 +41,35 @@ def hasher(): "sha512": hashlib.sha512 } return hashfuncs[args.default_hashing_function] + + +def export_custom_nodes(): + """ + Finds all classes in the current module that extend CustomNode and creates + a NODE_CLASS_MAPPINGS dictionary mapping class names to class objects. + Must be called from within the module where the CustomNode classes are defined. + """ + import inspect + + # Get the calling module + frame = inspect.currentframe() + try: + module = inspect.getmodule(frame.f_back) + + custom_nodes = {} + for name, obj in inspect.getmembers(module): + if (inspect.isclass(obj) and + CustomNode in obj.__mro__ and + obj != CustomNode): + custom_nodes[name] = obj + if hasattr(module, 'NODE_CLASS_MAPPINGS'): + node_class_mappings: dict = getattr(module, 'NODE_CLASS_MAPPINGS') + node_class_mappings.update(custom_nodes) + else: + setattr(module, 'NODE_CLASS_MAPPINGS', custom_nodes) + + finally: + # Clean up circular reference + del frame + + return custom_nodes diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index ee7c975bb..d83c37830 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -2,7 +2,7 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \ - Callable, List, Type + Callable, List, Type, runtime_checkable from typing_extensions import TypedDict, NotRequired @@ -64,7 +64,7 @@ NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any] InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec] # numpy seeds must be between 0 and 2**32 - 1 -Seed = ("INT", {"default": 0, "min": 0, "max": 2**32 - 1}) +Seed = ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1}) Seed64 = ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}) SeedSpec = tuple[Literal["INT"], TypedDict("SeedSpecOptions", {"default": Literal[0], "min": Literal[0], "max": Literal[4294967295]})] diff --git a/comfy_extras/nodes/nodes_arithmetic.py b/comfy_extras/nodes/nodes_arithmetic.py index fd0418dd2..adef08472 100644 --- a/comfy_extras/nodes/nodes_arithmetic.py +++ b/comfy_extras/nodes/nodes_arithmetic.py @@ -1,6 +1,9 @@ from functools import reduce from operator import add, mul, pow +from typing import Any +from comfy.comfy_types import IO +from comfy.node_helpers import export_custom_nodes from comfy.nodes.package_typing import CustomNode, InputTypes @@ -527,32 +530,38 @@ class FloatToInt(CustomNode): return int(value), -NODE_CLASS_MAPPINGS = {} -for cls in ( - FloatAdd, - FloatSubtract, - FloatMultiply, - FloatDivide, - FloatPower, - FloatMin, - FloatMax, - FloatAbs, - FloatAverage, - FloatLerp, - FloatInverseLerp, - FloatClamp, - IntAdd, - IntSubtract, - IntMultiply, - IntDivide, - IntMod, - IntPower, - IntMin, - IntMax, - IntAbs, - IntAverage, - IntLerp, - IntInverseLerp, - IntClamp, -): - NODE_CLASS_MAPPINGS[cls.__name__] = cls +class StringToInt(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": ("STRING", {}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, value: str = "0"): + return int(value), + + +class StringToFloat(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": ("STRING", {}), + } + } + + CATEGORY = "arithmetic" + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, value: str = "0"): + return float(value), + + +export_custom_nodes() diff --git a/comfy_extras/nodes/nodes_language.py b/comfy_extras/nodes/nodes_language.py index 591a01aaf..f045dc127 100644 --- a/comfy_extras/nodes/nodes_language.py +++ b/comfy_extras/nodes/nodes_language.py @@ -20,6 +20,7 @@ from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWA from comfy.language.transformers_model_management import TransformersManagedModel from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device +from comfy.node_helpers import export_custom_nodes from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult, Seed _AUTO_CHAT_TEMPLATE = "default" @@ -370,7 +371,7 @@ class PreviewString(CustomNode): } } - CATEGORY = "language" + CATEGORY = "strings" FUNCTION = "execute" RETURN_TYPES = ("STRING",) OUTPUT_NODE = True @@ -392,7 +393,7 @@ class SaveString(CustomNode): } } - CATEGORY = "language" + CATEGORY = "strings" FUNCTION = "execute" OUTPUT_NODE = True RETURN_TYPES = () @@ -412,24 +413,4 @@ class SaveString(CustomNode): return {"ui": {"string": value}} -NODE_CLASS_MAPPINGS = {} -for cls in ( - TransformerTopKSampler, - TransformerTopPSampler, - TransformerTemperatureSampler, - TransformerGreedySampler, - TransformerContrastiveSearchSampler, - TransformerBeamSearchSampler, - TransformerMergeSamplers, - TransformersLoader, - TransformersImageProcessorLoader, - TransformersGenerate, - OneShotInstructTokenize, - TransformersM2M100LanguageCodes, - TransformersTokenize, - TransformersFlores200LanguageCodes, - TransformersTranslationTokenize, - PreviewString, - SaveString, -): - NODE_CLASS_MAPPINGS[cls.__name__] = cls +export_custom_nodes() diff --git a/comfy_extras/nodes/nodes_regexp.py b/comfy_extras/nodes/nodes_regexp.py new file mode 100644 index 000000000..0c6dba43a --- /dev/null +++ b/comfy_extras/nodes/nodes_regexp.py @@ -0,0 +1,107 @@ +import re + +from comfy.node_helpers import export_custom_nodes +from comfy.nodes.package_typing import CustomNode, InputTypes + +MATCH_TYPE_NAME = "MATCH" + + +class RegexFlags(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + flags_ = {flag_name: ("BOOLEAN", {}) for flag_name in ["ASCII", "IGNORECASE", "LOCALE", "MULTILINE", "DOTALL", "VERBOSE", + "UNICODE", "NOFLAG"]} + return { + "required": flags_, + } + + CATEGORY = "regular_expressions" + FUNCTION = "execute" + RETURN_TYPES = ("INT",) + + def execute(self, **kwargs) -> tuple[int]: + flags = re.RegexFlag.NOFLAG + + for name, on in kwargs.items(): + if on: + flags |= re.RegexFlag[name] + + return int(flags), + + +class Regex(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "pattern": ("STRING", {}), + "string": ("STRING", {}), + }, + "optional": { + "flags": ("INT", {"min": 0, "default": 0}) + } + } + + CATEGORY = "regular_expressions" + FUNCTION = "execute" + RETURN_TYPES = (MATCH_TYPE_NAME,) + + def execute(self, pattern: str = "", string: str = "", flags: int = 0) -> tuple[re.Match]: + return re.match(pattern=pattern, string=string, flags=flags), + + +class RegexMatchGroupByIndex(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "match": (MATCH_TYPE_NAME, {}), + "index": ("INT", {}), + }, + } + + CATEGORY = "regular_expressions" + FUNCTION = "execute" + RETURN_TYPES = (MATCH_TYPE_NAME,) + + def execute(self, match: re.Match, index: int = 0) -> tuple[str]: + return match.group(index), + + +class RegexMatchGroupByName(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "match": (MATCH_TYPE_NAME, {}), + "name": ("STRING", {}), + }, + } + + CATEGORY = "regular_expressions" + FUNCTION = "execute" + RETURN_TYPES = (MATCH_TYPE_NAME,) + + def execute(self, match: re.Match, name: str = "") -> tuple[str]: + return match.group(name), + + +class RegexMatchExpand(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "match": (MATCH_TYPE_NAME, {}), + "template": ("STRING", {}), + }, + } + + CATEGORY = "regular_expressions" + FUNCTION = "execute" + RETURN_TYPES = (MATCH_TYPE_NAME,) + + def execute(self, match: re.Match, template: str = "") -> tuple[str]: + return match.expand(template), + + +export_custom_nodes() diff --git a/comfy_extras/nodes/nodes_strings.py b/comfy_extras/nodes/nodes_strings.py new file mode 100644 index 000000000..c28f8a3b7 --- /dev/null +++ b/comfy_extras/nodes/nodes_strings.py @@ -0,0 +1,67 @@ +from typing import Any + +import torch +from natsort import natsorted + +from comfy.comfy_types import IO +from comfy.node_helpers import export_custom_nodes +from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult + + +def format_value(value: Any) -> str | Any: + """Helper function to format values for string formatting.""" + if value is None: + return "None" + elif isinstance(value, torch.Tensor): + if value.numel() > 10: # For large tensors + shape_str = 'x'.join(str(x) for x in value.shape) + return f"" + else: + return str(value.tolist()) + return value + + +class StringFormat(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + optional = {f"value{i}": (IO.ANY, {"default": "", "forceInput": True}) for i in range(5)} + optional["format"] = (IO.STRING, {"default": "{}", "multiline": True}) + return { + "required": {}, + "optional": optional + } + + RETURN_TYPES = ("STRING",) + CATEGORY = "strings" + FUNCTION = "execute" + + def execute(self, format: str = "{}", *args: Any, **kwargs) -> ValidatedNodeResult: + for k, v in kwargs.items(): + kwargs[k] = format_value(v) + try: + if any(f'{{{k}' in format for k in kwargs.keys()): + return (format.format(**kwargs),) + else: + return (format.format(*[kwargs[k] for k in natsorted(kwargs.keys())]),) + except (IndexError, KeyError) as e: + return ("Format error: " + str(e),) + + +class ToString(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": (IO.ANY, {}), + } + } + + CATEGORY = "strings" + RETURN_TYPES = ("STRING",) + FUNCTION = "execute" + + def execute(self, value: Any): + return str(value), + + +export_custom_nodes() diff --git a/tests/unit/test_regular_expression_nodes.py b/tests/unit/test_regular_expression_nodes.py new file mode 100644 index 000000000..bacd2801b --- /dev/null +++ b/tests/unit/test_regular_expression_nodes.py @@ -0,0 +1,94 @@ +import re + +import pytest + +from comfy_extras.nodes.nodes_regexp import RegexFlags, Regex, RegexMatchGroupByIndex, RegexMatchGroupByName, \ + RegexMatchExpand + + +def test_regex_flags(): + n = RegexFlags() + + # Test with no flags + flags, = n.execute(ASCII=False, IGNORECASE=False, LOCALE=False, MULTILINE=False, + DOTALL=False, VERBOSE=False, UNICODE=False, NOFLAG=True) + assert flags == 0 + + # Test single flag + flags, = n.execute(ASCII=True, IGNORECASE=False, LOCALE=False, MULTILINE=False, + DOTALL=False, VERBOSE=False, UNICODE=False, NOFLAG=False) + assert flags == re.ASCII + + # Test multiple flags + flags, = n.execute(ASCII=True, IGNORECASE=True, LOCALE=False, MULTILINE=False, + DOTALL=False, VERBOSE=False, UNICODE=False, NOFLAG=False) + assert flags == (re.ASCII | re.IGNORECASE) + + +def test_regex(): + n = Regex() + + # Basic match test + match, = n.execute(pattern=r"hello", string="hello world") + assert match is not None + assert match.group(0) == "hello" + + # Test with flags + match, = n.execute(pattern=r"HELLO", string="hello world", flags=re.IGNORECASE) + assert match is not None + assert match.group(0) == "hello" + + # Test no match + match, = n.execute(pattern=r"python", string="hello world") + assert match is None + + +def test_regex_match_group_by_index(): + n = RegexMatchGroupByIndex() + regex = Regex() + + # Test basic group + match, = regex.execute(pattern=r"(hello) (world)", string="hello world") + group, = n.execute(match=match, index=0) + assert group == "hello world" + + group, = n.execute(match=match, index=1) + assert group == "hello" + + group, = n.execute(match=match, index=2) + assert group == "world" + + +def test_regex_match_group_by_name(): + n = RegexMatchGroupByName() + regex = Regex() + + # Test named group + match, = regex.execute(pattern=r"(?Phello) (?Pworld)", + string="hello world") + + group, = n.execute(match=match, name="greeting") + assert group == "hello" + + group, = n.execute(match=match, name="subject") + assert group == "world" + + # Test non-existent group name + with pytest.raises(IndexError): + n.execute(match=match, name="nonexistent") + + +def test_regex_match_expand(): + n = RegexMatchExpand() + regex = Regex() + + # Test basic expansion + match, = regex.execute(pattern=r"(hello) (world)", string="hello world") + result, = n.execute(match=match, template=r"\2, \1!") + assert result == "world, hello!" + + # Test named group expansion + match, = regex.execute(pattern=r"(?Phello) (?Pworld)", + string="hello world") + result, = n.execute(match=match, template=r"\g, \g!") + assert result == "world, hello!" diff --git a/tests/unit/test_string_nodes.py b/tests/unit/test_string_nodes.py new file mode 100644 index 000000000..ef4a9fe72 --- /dev/null +++ b/tests/unit/test_string_nodes.py @@ -0,0 +1,93 @@ +import torch + +from comfy_extras.nodes.nodes_strings import StringFormat + + +def test_string_format_basic(): + n = StringFormat() + + # Test basic string formatting + result, = n.execute(format="Hello, {}!", value0="World") + assert result == "Hello, World!" + + # Test multiple values + result, = n.execute(format="{} plus {} equals {}", value0=2, value1=2, value2=4) + assert result == "2 plus 2 equals 4" + + +def test_string_format_types(): + n = StringFormat() + + # Test with different types + result, = n.execute(format="Float: {:.2f}, Int: {}, Bool: {}", + value0=3.14159, value1=42, value2=True) + assert result == "Float: 3.14, Int: 42, Bool: True" + + # Test None values + result, = n.execute(format="{}, {}, {}", value0=None, value1="test", value2=None) + assert result == "None, test, None" + + +def test_string_format_tensors(): + n = StringFormat() + + # Test small tensor + small_tensor = torch.tensor([1, 2, 3]) + result, = n.execute(format="Tensor: {}", value0=small_tensor) + assert result == "Tensor: [1, 2, 3]" + + # Test large tensor + large_tensor = torch.randn(100, 100) + result, = n.execute(format="Large tensor: {}", value0=large_tensor) + assert result == "Large tensor: " + + # Test mixed tensor sizes + small_tensor = torch.tensor([1, 2]) + large_tensor = torch.randn(50, 50) + result, = n.execute(format="{} and {}", value0=small_tensor, value1=large_tensor) + assert "and " in result + assert "[1, 2]" in result + + +def test_string_format_edge_cases(): + n = StringFormat() + + # Test with missing values + result, = n.execute(format="{} {} {}", value0="a", value1="b") + assert result.startswith("Format error: ") + + # Test with empty format string + result, = n.execute(format="", value0="ignored") + assert result == "" + + # Test with no placeholders + result, = n.execute(format="Hello World", value0="ignored") + assert result == "Hello World" + + # Test with named placeholders + result, = n.execute(format="X: {value0}, Y: {value1}", value0=10, value1=20) + assert result == "X: 10, Y: 20" + + # Test mixing None, tensors and regular values + tensor = torch.tensor([1, 2, 3]) + result, = n.execute(format="{}, {}, {}", value0=None, value1=tensor, value2="test") + assert result == "None, [1, 2, 3], test" + + +def test_string_format_tensor_edge_cases(): + n = StringFormat() + + # Test empty tensor + empty_tensor = torch.tensor([]) + result, = n.execute(format="Empty tensor: {}", value0=empty_tensor) + assert result == "Empty tensor: []" + + # Test scalar tensor + scalar_tensor = torch.tensor(5) + result, = n.execute(format="Scalar tensor: {}", value0=scalar_tensor) + assert result == "Scalar tensor: 5" + + # Test multi-dimensional small tensor + small_2d_tensor = torch.tensor([[1, 2], [3, 4]]) + result, = n.execute(format="2D tensor: {}", value0=small_2d_tensor) + assert result == "2D tensor: [[1, 2], [3, 4]]"