Quality of life improvements

- export_custom_nodes() finds all the classes that inherit from
   CustomNode and exports them correctly for custom node discovery to
   find
 - regular expressions
 - additional string formatting and parsing nodes
This commit is contained in:
doctorpangloss 2025-02-12 14:12:10 -08:00
parent cf08b11132
commit 31b6b53236
8 changed files with 447 additions and 58 deletions

View File

@ -1,10 +1,14 @@
import hashlib
from .cli_args import args
from PIL import ImageFile, UnidentifiedImageError
def conditioning_set_values(conditioning, values={}):
from .cli_args import args
from .nodes.package_typing import CustomNode
def conditioning_set_values(conditioning, values: dict = None):
if values is None:
values = {}
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
@ -14,11 +18,12 @@ def conditioning_set_values(conditioning, values={}):
return c
def pillow(fn, arg):
prev_value = None
try:
x = fn(arg)
except (OSError, UnidentifiedImageError, ValueError): #PIL issues #4472 and #2445, also fixes ComfyUI issue #3416
except (OSError, UnidentifiedImageError, ValueError): # PIL issues #4472 and #2445, also fixes ComfyUI issue #3416
prev_value = ImageFile.LOAD_TRUNCATED_IMAGES
ImageFile.LOAD_TRUNCATED_IMAGES = True
x = fn(arg)
@ -27,6 +32,7 @@ def pillow(fn, arg):
ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
return x
def hasher():
hashfuncs = {
"md5": hashlib.md5,
@ -35,3 +41,35 @@ def hasher():
"sha512": hashlib.sha512
}
return hashfuncs[args.default_hashing_function]
def export_custom_nodes():
"""
Finds all classes in the current module that extend CustomNode and creates
a NODE_CLASS_MAPPINGS dictionary mapping class names to class objects.
Must be called from within the module where the CustomNode classes are defined.
"""
import inspect
# Get the calling module
frame = inspect.currentframe()
try:
module = inspect.getmodule(frame.f_back)
custom_nodes = {}
for name, obj in inspect.getmembers(module):
if (inspect.isclass(obj) and
CustomNode in obj.__mro__ and
obj != CustomNode):
custom_nodes[name] = obj
if hasattr(module, 'NODE_CLASS_MAPPINGS'):
node_class_mappings: dict = getattr(module, 'NODE_CLASS_MAPPINGS')
node_class_mappings.update(custom_nodes)
else:
setattr(module, 'NODE_CLASS_MAPPINGS', custom_nodes)
finally:
# Clean up circular reference
del frame
return custom_nodes

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
Callable, List, Type
Callable, List, Type, runtime_checkable
from typing_extensions import TypedDict, NotRequired
@ -64,7 +64,7 @@ NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]
# numpy seeds must be between 0 and 2**32 - 1
Seed = ("INT", {"default": 0, "min": 0, "max": 2**32 - 1})
Seed = ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1})
Seed64 = ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff})
SeedSpec = tuple[Literal["INT"], TypedDict("SeedSpecOptions", {"default": Literal[0], "min": Literal[0], "max": Literal[4294967295]})]

View File

@ -1,6 +1,9 @@
from functools import reduce
from operator import add, mul, pow
from typing import Any
from comfy.comfy_types import IO
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes
@ -527,32 +530,38 @@ class FloatToInt(CustomNode):
return int(value),
NODE_CLASS_MAPPINGS = {}
for cls in (
FloatAdd,
FloatSubtract,
FloatMultiply,
FloatDivide,
FloatPower,
FloatMin,
FloatMax,
FloatAbs,
FloatAverage,
FloatLerp,
FloatInverseLerp,
FloatClamp,
IntAdd,
IntSubtract,
IntMultiply,
IntDivide,
IntMod,
IntPower,
IntMin,
IntMax,
IntAbs,
IntAverage,
IntLerp,
IntInverseLerp,
IntClamp,
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls
class StringToInt(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {}),
}
}
CATEGORY = "arithmetic"
RETURN_TYPES = ("INT",)
FUNCTION = "execute"
def execute(self, value: str = "0"):
return int(value),
class StringToFloat(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {}),
}
}
CATEGORY = "arithmetic"
RETURN_TYPES = ("FLOAT",)
FUNCTION = "execute"
def execute(self, value: str = "0"):
return float(value),
export_custom_nodes()

View File

@ -20,6 +20,7 @@ from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWA
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult, Seed
_AUTO_CHAT_TEMPLATE = "default"
@ -370,7 +371,7 @@ class PreviewString(CustomNode):
}
}
CATEGORY = "language"
CATEGORY = "strings"
FUNCTION = "execute"
RETURN_TYPES = ("STRING",)
OUTPUT_NODE = True
@ -392,7 +393,7 @@ class SaveString(CustomNode):
}
}
CATEGORY = "language"
CATEGORY = "strings"
FUNCTION = "execute"
OUTPUT_NODE = True
RETURN_TYPES = ()
@ -412,24 +413,4 @@ class SaveString(CustomNode):
return {"ui": {"string": value}}
NODE_CLASS_MAPPINGS = {}
for cls in (
TransformerTopKSampler,
TransformerTopPSampler,
TransformerTemperatureSampler,
TransformerGreedySampler,
TransformerContrastiveSearchSampler,
TransformerBeamSearchSampler,
TransformerMergeSamplers,
TransformersLoader,
TransformersImageProcessorLoader,
TransformersGenerate,
OneShotInstructTokenize,
TransformersM2M100LanguageCodes,
TransformersTokenize,
TransformersFlores200LanguageCodes,
TransformersTranslationTokenize,
PreviewString,
SaveString,
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls
export_custom_nodes()

View File

@ -0,0 +1,107 @@
import re
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes
MATCH_TYPE_NAME = "MATCH"
class RegexFlags(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
flags_ = {flag_name: ("BOOLEAN", {}) for flag_name in ["ASCII", "IGNORECASE", "LOCALE", "MULTILINE", "DOTALL", "VERBOSE",
"UNICODE", "NOFLAG"]}
return {
"required": flags_,
}
CATEGORY = "regular_expressions"
FUNCTION = "execute"
RETURN_TYPES = ("INT",)
def execute(self, **kwargs) -> tuple[int]:
flags = re.RegexFlag.NOFLAG
for name, on in kwargs.items():
if on:
flags |= re.RegexFlag[name]
return int(flags),
class Regex(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"pattern": ("STRING", {}),
"string": ("STRING", {}),
},
"optional": {
"flags": ("INT", {"min": 0, "default": 0})
}
}
CATEGORY = "regular_expressions"
FUNCTION = "execute"
RETURN_TYPES = (MATCH_TYPE_NAME,)
def execute(self, pattern: str = "", string: str = "", flags: int = 0) -> tuple[re.Match]:
return re.match(pattern=pattern, string=string, flags=flags),
class RegexMatchGroupByIndex(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"match": (MATCH_TYPE_NAME, {}),
"index": ("INT", {}),
},
}
CATEGORY = "regular_expressions"
FUNCTION = "execute"
RETURN_TYPES = (MATCH_TYPE_NAME,)
def execute(self, match: re.Match, index: int = 0) -> tuple[str]:
return match.group(index),
class RegexMatchGroupByName(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"match": (MATCH_TYPE_NAME, {}),
"name": ("STRING", {}),
},
}
CATEGORY = "regular_expressions"
FUNCTION = "execute"
RETURN_TYPES = (MATCH_TYPE_NAME,)
def execute(self, match: re.Match, name: str = "") -> tuple[str]:
return match.group(name),
class RegexMatchExpand(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"match": (MATCH_TYPE_NAME, {}),
"template": ("STRING", {}),
},
}
CATEGORY = "regular_expressions"
FUNCTION = "execute"
RETURN_TYPES = (MATCH_TYPE_NAME,)
def execute(self, match: re.Match, template: str = "") -> tuple[str]:
return match.expand(template),
export_custom_nodes()

View File

@ -0,0 +1,67 @@
from typing import Any
import torch
from natsort import natsorted
from comfy.comfy_types import IO
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
def format_value(value: Any) -> str | Any:
"""Helper function to format values for string formatting."""
if value is None:
return "None"
elif isinstance(value, torch.Tensor):
if value.numel() > 10: # For large tensors
shape_str = 'x'.join(str(x) for x in value.shape)
return f"<Tensor shape={shape_str}>"
else:
return str(value.tolist())
return value
class StringFormat(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
optional = {f"value{i}": (IO.ANY, {"default": "", "forceInput": True}) for i in range(5)}
optional["format"] = (IO.STRING, {"default": "{}", "multiline": True})
return {
"required": {},
"optional": optional
}
RETURN_TYPES = ("STRING",)
CATEGORY = "strings"
FUNCTION = "execute"
def execute(self, format: str = "{}", *args: Any, **kwargs) -> ValidatedNodeResult:
for k, v in kwargs.items():
kwargs[k] = format_value(v)
try:
if any(f'{{{k}' in format for k in kwargs.keys()):
return (format.format(**kwargs),)
else:
return (format.format(*[kwargs[k] for k in natsorted(kwargs.keys())]),)
except (IndexError, KeyError) as e:
return ("Format error: " + str(e),)
class ToString(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": (IO.ANY, {}),
}
}
CATEGORY = "strings"
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
def execute(self, value: Any):
return str(value),
export_custom_nodes()

View File

@ -0,0 +1,94 @@
import re
import pytest
from comfy_extras.nodes.nodes_regexp import RegexFlags, Regex, RegexMatchGroupByIndex, RegexMatchGroupByName, \
RegexMatchExpand
def test_regex_flags():
n = RegexFlags()
# Test with no flags
flags, = n.execute(ASCII=False, IGNORECASE=False, LOCALE=False, MULTILINE=False,
DOTALL=False, VERBOSE=False, UNICODE=False, NOFLAG=True)
assert flags == 0
# Test single flag
flags, = n.execute(ASCII=True, IGNORECASE=False, LOCALE=False, MULTILINE=False,
DOTALL=False, VERBOSE=False, UNICODE=False, NOFLAG=False)
assert flags == re.ASCII
# Test multiple flags
flags, = n.execute(ASCII=True, IGNORECASE=True, LOCALE=False, MULTILINE=False,
DOTALL=False, VERBOSE=False, UNICODE=False, NOFLAG=False)
assert flags == (re.ASCII | re.IGNORECASE)
def test_regex():
n = Regex()
# Basic match test
match, = n.execute(pattern=r"hello", string="hello world")
assert match is not None
assert match.group(0) == "hello"
# Test with flags
match, = n.execute(pattern=r"HELLO", string="hello world", flags=re.IGNORECASE)
assert match is not None
assert match.group(0) == "hello"
# Test no match
match, = n.execute(pattern=r"python", string="hello world")
assert match is None
def test_regex_match_group_by_index():
n = RegexMatchGroupByIndex()
regex = Regex()
# Test basic group
match, = regex.execute(pattern=r"(hello) (world)", string="hello world")
group, = n.execute(match=match, index=0)
assert group == "hello world"
group, = n.execute(match=match, index=1)
assert group == "hello"
group, = n.execute(match=match, index=2)
assert group == "world"
def test_regex_match_group_by_name():
n = RegexMatchGroupByName()
regex = Regex()
# Test named group
match, = regex.execute(pattern=r"(?P<greeting>hello) (?P<subject>world)",
string="hello world")
group, = n.execute(match=match, name="greeting")
assert group == "hello"
group, = n.execute(match=match, name="subject")
assert group == "world"
# Test non-existent group name
with pytest.raises(IndexError):
n.execute(match=match, name="nonexistent")
def test_regex_match_expand():
n = RegexMatchExpand()
regex = Regex()
# Test basic expansion
match, = regex.execute(pattern=r"(hello) (world)", string="hello world")
result, = n.execute(match=match, template=r"\2, \1!")
assert result == "world, hello!"
# Test named group expansion
match, = regex.execute(pattern=r"(?P<greeting>hello) (?P<subject>world)",
string="hello world")
result, = n.execute(match=match, template=r"\g<subject>, \g<greeting>!")
assert result == "world, hello!"

View File

@ -0,0 +1,93 @@
import torch
from comfy_extras.nodes.nodes_strings import StringFormat
def test_string_format_basic():
n = StringFormat()
# Test basic string formatting
result, = n.execute(format="Hello, {}!", value0="World")
assert result == "Hello, World!"
# Test multiple values
result, = n.execute(format="{} plus {} equals {}", value0=2, value1=2, value2=4)
assert result == "2 plus 2 equals 4"
def test_string_format_types():
n = StringFormat()
# Test with different types
result, = n.execute(format="Float: {:.2f}, Int: {}, Bool: {}",
value0=3.14159, value1=42, value2=True)
assert result == "Float: 3.14, Int: 42, Bool: True"
# Test None values
result, = n.execute(format="{}, {}, {}", value0=None, value1="test", value2=None)
assert result == "None, test, None"
def test_string_format_tensors():
n = StringFormat()
# Test small tensor
small_tensor = torch.tensor([1, 2, 3])
result, = n.execute(format="Tensor: {}", value0=small_tensor)
assert result == "Tensor: [1, 2, 3]"
# Test large tensor
large_tensor = torch.randn(100, 100)
result, = n.execute(format="Large tensor: {}", value0=large_tensor)
assert result == "Large tensor: <Tensor shape=100x100>"
# Test mixed tensor sizes
small_tensor = torch.tensor([1, 2])
large_tensor = torch.randn(50, 50)
result, = n.execute(format="{} and {}", value0=small_tensor, value1=large_tensor)
assert "and <Tensor shape=50x50>" in result
assert "[1, 2]" in result
def test_string_format_edge_cases():
n = StringFormat()
# Test with missing values
result, = n.execute(format="{} {} {}", value0="a", value1="b")
assert result.startswith("Format error: ")
# Test with empty format string
result, = n.execute(format="", value0="ignored")
assert result == ""
# Test with no placeholders
result, = n.execute(format="Hello World", value0="ignored")
assert result == "Hello World"
# Test with named placeholders
result, = n.execute(format="X: {value0}, Y: {value1}", value0=10, value1=20)
assert result == "X: 10, Y: 20"
# Test mixing None, tensors and regular values
tensor = torch.tensor([1, 2, 3])
result, = n.execute(format="{}, {}, {}", value0=None, value1=tensor, value2="test")
assert result == "None, [1, 2, 3], test"
def test_string_format_tensor_edge_cases():
n = StringFormat()
# Test empty tensor
empty_tensor = torch.tensor([])
result, = n.execute(format="Empty tensor: {}", value0=empty_tensor)
assert result == "Empty tensor: []"
# Test scalar tensor
scalar_tensor = torch.tensor(5)
result, = n.execute(format="Scalar tensor: {}", value0=scalar_tensor)
assert result == "Scalar tensor: 5"
# Test multi-dimensional small tensor
small_2d_tensor = torch.tensor([[1, 2], [3, 4]])
result, = n.execute(format="2D tensor: {}", value0=small_2d_tensor)
assert result == "2D tensor: [[1, 2], [3, 4]]"