mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 14:20:49 +08:00
Quality of life improvements
- export_custom_nodes() finds all the classes that inherit from CustomNode and exports them correctly for custom node discovery to find - regular expressions - additional string formatting and parsing nodes
This commit is contained in:
parent
cf08b11132
commit
31b6b53236
@ -1,10 +1,14 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
from .cli_args import args
|
|
||||||
|
|
||||||
from PIL import ImageFile, UnidentifiedImageError
|
from PIL import ImageFile, UnidentifiedImageError
|
||||||
|
|
||||||
def conditioning_set_values(conditioning, values={}):
|
from .cli_args import args
|
||||||
|
from .nodes.package_typing import CustomNode
|
||||||
|
|
||||||
|
|
||||||
|
def conditioning_set_values(conditioning, values: dict = None):
|
||||||
|
if values is None:
|
||||||
|
values = {}
|
||||||
c = []
|
c = []
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
n = [t[0], t[1].copy()]
|
n = [t[0], t[1].copy()]
|
||||||
@ -14,11 +18,12 @@ def conditioning_set_values(conditioning, values={}):
|
|||||||
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
def pillow(fn, arg):
|
def pillow(fn, arg):
|
||||||
prev_value = None
|
prev_value = None
|
||||||
try:
|
try:
|
||||||
x = fn(arg)
|
x = fn(arg)
|
||||||
except (OSError, UnidentifiedImageError, ValueError): #PIL issues #4472 and #2445, also fixes ComfyUI issue #3416
|
except (OSError, UnidentifiedImageError, ValueError): # PIL issues #4472 and #2445, also fixes ComfyUI issue #3416
|
||||||
prev_value = ImageFile.LOAD_TRUNCATED_IMAGES
|
prev_value = ImageFile.LOAD_TRUNCATED_IMAGES
|
||||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
x = fn(arg)
|
x = fn(arg)
|
||||||
@ -27,6 +32,7 @@ def pillow(fn, arg):
|
|||||||
ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
|
ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def hasher():
|
def hasher():
|
||||||
hashfuncs = {
|
hashfuncs = {
|
||||||
"md5": hashlib.md5,
|
"md5": hashlib.md5,
|
||||||
@ -35,3 +41,35 @@ def hasher():
|
|||||||
"sha512": hashlib.sha512
|
"sha512": hashlib.sha512
|
||||||
}
|
}
|
||||||
return hashfuncs[args.default_hashing_function]
|
return hashfuncs[args.default_hashing_function]
|
||||||
|
|
||||||
|
|
||||||
|
def export_custom_nodes():
|
||||||
|
"""
|
||||||
|
Finds all classes in the current module that extend CustomNode and creates
|
||||||
|
a NODE_CLASS_MAPPINGS dictionary mapping class names to class objects.
|
||||||
|
Must be called from within the module where the CustomNode classes are defined.
|
||||||
|
"""
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
# Get the calling module
|
||||||
|
frame = inspect.currentframe()
|
||||||
|
try:
|
||||||
|
module = inspect.getmodule(frame.f_back)
|
||||||
|
|
||||||
|
custom_nodes = {}
|
||||||
|
for name, obj in inspect.getmembers(module):
|
||||||
|
if (inspect.isclass(obj) and
|
||||||
|
CustomNode in obj.__mro__ and
|
||||||
|
obj != CustomNode):
|
||||||
|
custom_nodes[name] = obj
|
||||||
|
if hasattr(module, 'NODE_CLASS_MAPPINGS'):
|
||||||
|
node_class_mappings: dict = getattr(module, 'NODE_CLASS_MAPPINGS')
|
||||||
|
node_class_mappings.update(custom_nodes)
|
||||||
|
else:
|
||||||
|
setattr(module, 'NODE_CLASS_MAPPINGS', custom_nodes)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up circular reference
|
||||||
|
del frame
|
||||||
|
|
||||||
|
return custom_nodes
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
|
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
|
||||||
Callable, List, Type
|
Callable, List, Type, runtime_checkable
|
||||||
|
|
||||||
from typing_extensions import TypedDict, NotRequired
|
from typing_extensions import TypedDict, NotRequired
|
||||||
|
|
||||||
@ -64,7 +64,7 @@ NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
|
|||||||
InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]
|
InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]
|
||||||
|
|
||||||
# numpy seeds must be between 0 and 2**32 - 1
|
# numpy seeds must be between 0 and 2**32 - 1
|
||||||
Seed = ("INT", {"default": 0, "min": 0, "max": 2**32 - 1})
|
Seed = ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1})
|
||||||
Seed64 = ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff})
|
Seed64 = ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff})
|
||||||
SeedSpec = tuple[Literal["INT"], TypedDict("SeedSpecOptions", {"default": Literal[0], "min": Literal[0], "max": Literal[4294967295]})]
|
SeedSpec = tuple[Literal["INT"], TypedDict("SeedSpecOptions", {"default": Literal[0], "min": Literal[0], "max": Literal[4294967295]})]
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,9 @@
|
|||||||
from functools import reduce
|
from functools import reduce
|
||||||
from operator import add, mul, pow
|
from operator import add, mul, pow
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from comfy.comfy_types import IO
|
||||||
|
from comfy.node_helpers import export_custom_nodes
|
||||||
from comfy.nodes.package_typing import CustomNode, InputTypes
|
from comfy.nodes.package_typing import CustomNode, InputTypes
|
||||||
|
|
||||||
|
|
||||||
@ -527,32 +530,38 @@ class FloatToInt(CustomNode):
|
|||||||
return int(value),
|
return int(value),
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {}
|
class StringToInt(CustomNode):
|
||||||
for cls in (
|
@classmethod
|
||||||
FloatAdd,
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
FloatSubtract,
|
return {
|
||||||
FloatMultiply,
|
"required": {
|
||||||
FloatDivide,
|
"value": ("STRING", {}),
|
||||||
FloatPower,
|
}
|
||||||
FloatMin,
|
}
|
||||||
FloatMax,
|
|
||||||
FloatAbs,
|
CATEGORY = "arithmetic"
|
||||||
FloatAverage,
|
RETURN_TYPES = ("INT",)
|
||||||
FloatLerp,
|
FUNCTION = "execute"
|
||||||
FloatInverseLerp,
|
|
||||||
FloatClamp,
|
def execute(self, value: str = "0"):
|
||||||
IntAdd,
|
return int(value),
|
||||||
IntSubtract,
|
|
||||||
IntMultiply,
|
|
||||||
IntDivide,
|
class StringToFloat(CustomNode):
|
||||||
IntMod,
|
@classmethod
|
||||||
IntPower,
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
IntMin,
|
return {
|
||||||
IntMax,
|
"required": {
|
||||||
IntAbs,
|
"value": ("STRING", {}),
|
||||||
IntAverage,
|
}
|
||||||
IntLerp,
|
}
|
||||||
IntInverseLerp,
|
|
||||||
IntClamp,
|
CATEGORY = "arithmetic"
|
||||||
):
|
RETURN_TYPES = ("FLOAT",)
|
||||||
NODE_CLASS_MAPPINGS[cls.__name__] = cls
|
FUNCTION = "execute"
|
||||||
|
|
||||||
|
def execute(self, value: str = "0"):
|
||||||
|
return float(value),
|
||||||
|
|
||||||
|
|
||||||
|
export_custom_nodes()
|
||||||
|
|||||||
@ -20,6 +20,7 @@ from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWA
|
|||||||
from comfy.language.transformers_model_management import TransformersManagedModel
|
from comfy.language.transformers_model_management import TransformersManagedModel
|
||||||
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
|
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
|
||||||
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device
|
from comfy.model_management import get_torch_device_name, unet_dtype, unet_offload_device
|
||||||
|
from comfy.node_helpers import export_custom_nodes
|
||||||
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult, Seed
|
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult, Seed
|
||||||
|
|
||||||
_AUTO_CHAT_TEMPLATE = "default"
|
_AUTO_CHAT_TEMPLATE = "default"
|
||||||
@ -370,7 +371,7 @@ class PreviewString(CustomNode):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CATEGORY = "language"
|
CATEGORY = "strings"
|
||||||
FUNCTION = "execute"
|
FUNCTION = "execute"
|
||||||
RETURN_TYPES = ("STRING",)
|
RETURN_TYPES = ("STRING",)
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
@ -392,7 +393,7 @@ class SaveString(CustomNode):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CATEGORY = "language"
|
CATEGORY = "strings"
|
||||||
FUNCTION = "execute"
|
FUNCTION = "execute"
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
RETURN_TYPES = ()
|
RETURN_TYPES = ()
|
||||||
@ -412,24 +413,4 @@ class SaveString(CustomNode):
|
|||||||
return {"ui": {"string": value}}
|
return {"ui": {"string": value}}
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {}
|
export_custom_nodes()
|
||||||
for cls in (
|
|
||||||
TransformerTopKSampler,
|
|
||||||
TransformerTopPSampler,
|
|
||||||
TransformerTemperatureSampler,
|
|
||||||
TransformerGreedySampler,
|
|
||||||
TransformerContrastiveSearchSampler,
|
|
||||||
TransformerBeamSearchSampler,
|
|
||||||
TransformerMergeSamplers,
|
|
||||||
TransformersLoader,
|
|
||||||
TransformersImageProcessorLoader,
|
|
||||||
TransformersGenerate,
|
|
||||||
OneShotInstructTokenize,
|
|
||||||
TransformersM2M100LanguageCodes,
|
|
||||||
TransformersTokenize,
|
|
||||||
TransformersFlores200LanguageCodes,
|
|
||||||
TransformersTranslationTokenize,
|
|
||||||
PreviewString,
|
|
||||||
SaveString,
|
|
||||||
):
|
|
||||||
NODE_CLASS_MAPPINGS[cls.__name__] = cls
|
|
||||||
|
|||||||
107
comfy_extras/nodes/nodes_regexp.py
Normal file
107
comfy_extras/nodes/nodes_regexp.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
from comfy.node_helpers import export_custom_nodes
|
||||||
|
from comfy.nodes.package_typing import CustomNode, InputTypes
|
||||||
|
|
||||||
|
MATCH_TYPE_NAME = "MATCH"
|
||||||
|
|
||||||
|
|
||||||
|
class RegexFlags(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
flags_ = {flag_name: ("BOOLEAN", {}) for flag_name in ["ASCII", "IGNORECASE", "LOCALE", "MULTILINE", "DOTALL", "VERBOSE",
|
||||||
|
"UNICODE", "NOFLAG"]}
|
||||||
|
return {
|
||||||
|
"required": flags_,
|
||||||
|
}
|
||||||
|
|
||||||
|
CATEGORY = "regular_expressions"
|
||||||
|
FUNCTION = "execute"
|
||||||
|
RETURN_TYPES = ("INT",)
|
||||||
|
|
||||||
|
def execute(self, **kwargs) -> tuple[int]:
|
||||||
|
flags = re.RegexFlag.NOFLAG
|
||||||
|
|
||||||
|
for name, on in kwargs.items():
|
||||||
|
if on:
|
||||||
|
flags |= re.RegexFlag[name]
|
||||||
|
|
||||||
|
return int(flags),
|
||||||
|
|
||||||
|
|
||||||
|
class Regex(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"pattern": ("STRING", {}),
|
||||||
|
"string": ("STRING", {}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"flags": ("INT", {"min": 0, "default": 0})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CATEGORY = "regular_expressions"
|
||||||
|
FUNCTION = "execute"
|
||||||
|
RETURN_TYPES = (MATCH_TYPE_NAME,)
|
||||||
|
|
||||||
|
def execute(self, pattern: str = "", string: str = "", flags: int = 0) -> tuple[re.Match]:
|
||||||
|
return re.match(pattern=pattern, string=string, flags=flags),
|
||||||
|
|
||||||
|
|
||||||
|
class RegexMatchGroupByIndex(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"match": (MATCH_TYPE_NAME, {}),
|
||||||
|
"index": ("INT", {}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
CATEGORY = "regular_expressions"
|
||||||
|
FUNCTION = "execute"
|
||||||
|
RETURN_TYPES = (MATCH_TYPE_NAME,)
|
||||||
|
|
||||||
|
def execute(self, match: re.Match, index: int = 0) -> tuple[str]:
|
||||||
|
return match.group(index),
|
||||||
|
|
||||||
|
|
||||||
|
class RegexMatchGroupByName(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"match": (MATCH_TYPE_NAME, {}),
|
||||||
|
"name": ("STRING", {}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
CATEGORY = "regular_expressions"
|
||||||
|
FUNCTION = "execute"
|
||||||
|
RETURN_TYPES = (MATCH_TYPE_NAME,)
|
||||||
|
|
||||||
|
def execute(self, match: re.Match, name: str = "") -> tuple[str]:
|
||||||
|
return match.group(name),
|
||||||
|
|
||||||
|
|
||||||
|
class RegexMatchExpand(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"match": (MATCH_TYPE_NAME, {}),
|
||||||
|
"template": ("STRING", {}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
CATEGORY = "regular_expressions"
|
||||||
|
FUNCTION = "execute"
|
||||||
|
RETURN_TYPES = (MATCH_TYPE_NAME,)
|
||||||
|
|
||||||
|
def execute(self, match: re.Match, template: str = "") -> tuple[str]:
|
||||||
|
return match.expand(template),
|
||||||
|
|
||||||
|
|
||||||
|
export_custom_nodes()
|
||||||
67
comfy_extras/nodes/nodes_strings.py
Normal file
67
comfy_extras/nodes/nodes_strings.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from natsort import natsorted
|
||||||
|
|
||||||
|
from comfy.comfy_types import IO
|
||||||
|
from comfy.node_helpers import export_custom_nodes
|
||||||
|
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
|
||||||
|
|
||||||
|
|
||||||
|
def format_value(value: Any) -> str | Any:
|
||||||
|
"""Helper function to format values for string formatting."""
|
||||||
|
if value is None:
|
||||||
|
return "None"
|
||||||
|
elif isinstance(value, torch.Tensor):
|
||||||
|
if value.numel() > 10: # For large tensors
|
||||||
|
shape_str = 'x'.join(str(x) for x in value.shape)
|
||||||
|
return f"<Tensor shape={shape_str}>"
|
||||||
|
else:
|
||||||
|
return str(value.tolist())
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class StringFormat(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
optional = {f"value{i}": (IO.ANY, {"default": "", "forceInput": True}) for i in range(5)}
|
||||||
|
optional["format"] = (IO.STRING, {"default": "{}", "multiline": True})
|
||||||
|
return {
|
||||||
|
"required": {},
|
||||||
|
"optional": optional
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
CATEGORY = "strings"
|
||||||
|
FUNCTION = "execute"
|
||||||
|
|
||||||
|
def execute(self, format: str = "{}", *args: Any, **kwargs) -> ValidatedNodeResult:
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
kwargs[k] = format_value(v)
|
||||||
|
try:
|
||||||
|
if any(f'{{{k}' in format for k in kwargs.keys()):
|
||||||
|
return (format.format(**kwargs),)
|
||||||
|
else:
|
||||||
|
return (format.format(*[kwargs[k] for k in natsorted(kwargs.keys())]),)
|
||||||
|
except (IndexError, KeyError) as e:
|
||||||
|
return ("Format error: " + str(e),)
|
||||||
|
|
||||||
|
|
||||||
|
class ToString(CustomNode):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value": (IO.ANY, {}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CATEGORY = "strings"
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
|
||||||
|
def execute(self, value: Any):
|
||||||
|
return str(value),
|
||||||
|
|
||||||
|
|
||||||
|
export_custom_nodes()
|
||||||
94
tests/unit/test_regular_expression_nodes.py
Normal file
94
tests/unit/test_regular_expression_nodes.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from comfy_extras.nodes.nodes_regexp import RegexFlags, Regex, RegexMatchGroupByIndex, RegexMatchGroupByName, \
|
||||||
|
RegexMatchExpand
|
||||||
|
|
||||||
|
|
||||||
|
def test_regex_flags():
|
||||||
|
n = RegexFlags()
|
||||||
|
|
||||||
|
# Test with no flags
|
||||||
|
flags, = n.execute(ASCII=False, IGNORECASE=False, LOCALE=False, MULTILINE=False,
|
||||||
|
DOTALL=False, VERBOSE=False, UNICODE=False, NOFLAG=True)
|
||||||
|
assert flags == 0
|
||||||
|
|
||||||
|
# Test single flag
|
||||||
|
flags, = n.execute(ASCII=True, IGNORECASE=False, LOCALE=False, MULTILINE=False,
|
||||||
|
DOTALL=False, VERBOSE=False, UNICODE=False, NOFLAG=False)
|
||||||
|
assert flags == re.ASCII
|
||||||
|
|
||||||
|
# Test multiple flags
|
||||||
|
flags, = n.execute(ASCII=True, IGNORECASE=True, LOCALE=False, MULTILINE=False,
|
||||||
|
DOTALL=False, VERBOSE=False, UNICODE=False, NOFLAG=False)
|
||||||
|
assert flags == (re.ASCII | re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_regex():
|
||||||
|
n = Regex()
|
||||||
|
|
||||||
|
# Basic match test
|
||||||
|
match, = n.execute(pattern=r"hello", string="hello world")
|
||||||
|
assert match is not None
|
||||||
|
assert match.group(0) == "hello"
|
||||||
|
|
||||||
|
# Test with flags
|
||||||
|
match, = n.execute(pattern=r"HELLO", string="hello world", flags=re.IGNORECASE)
|
||||||
|
assert match is not None
|
||||||
|
assert match.group(0) == "hello"
|
||||||
|
|
||||||
|
# Test no match
|
||||||
|
match, = n.execute(pattern=r"python", string="hello world")
|
||||||
|
assert match is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_regex_match_group_by_index():
|
||||||
|
n = RegexMatchGroupByIndex()
|
||||||
|
regex = Regex()
|
||||||
|
|
||||||
|
# Test basic group
|
||||||
|
match, = regex.execute(pattern=r"(hello) (world)", string="hello world")
|
||||||
|
group, = n.execute(match=match, index=0)
|
||||||
|
assert group == "hello world"
|
||||||
|
|
||||||
|
group, = n.execute(match=match, index=1)
|
||||||
|
assert group == "hello"
|
||||||
|
|
||||||
|
group, = n.execute(match=match, index=2)
|
||||||
|
assert group == "world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_regex_match_group_by_name():
|
||||||
|
n = RegexMatchGroupByName()
|
||||||
|
regex = Regex()
|
||||||
|
|
||||||
|
# Test named group
|
||||||
|
match, = regex.execute(pattern=r"(?P<greeting>hello) (?P<subject>world)",
|
||||||
|
string="hello world")
|
||||||
|
|
||||||
|
group, = n.execute(match=match, name="greeting")
|
||||||
|
assert group == "hello"
|
||||||
|
|
||||||
|
group, = n.execute(match=match, name="subject")
|
||||||
|
assert group == "world"
|
||||||
|
|
||||||
|
# Test non-existent group name
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
n.execute(match=match, name="nonexistent")
|
||||||
|
|
||||||
|
|
||||||
|
def test_regex_match_expand():
|
||||||
|
n = RegexMatchExpand()
|
||||||
|
regex = Regex()
|
||||||
|
|
||||||
|
# Test basic expansion
|
||||||
|
match, = regex.execute(pattern=r"(hello) (world)", string="hello world")
|
||||||
|
result, = n.execute(match=match, template=r"\2, \1!")
|
||||||
|
assert result == "world, hello!"
|
||||||
|
|
||||||
|
# Test named group expansion
|
||||||
|
match, = regex.execute(pattern=r"(?P<greeting>hello) (?P<subject>world)",
|
||||||
|
string="hello world")
|
||||||
|
result, = n.execute(match=match, template=r"\g<subject>, \g<greeting>!")
|
||||||
|
assert result == "world, hello!"
|
||||||
93
tests/unit/test_string_nodes.py
Normal file
93
tests/unit/test_string_nodes.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy_extras.nodes.nodes_strings import StringFormat
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_format_basic():
|
||||||
|
n = StringFormat()
|
||||||
|
|
||||||
|
# Test basic string formatting
|
||||||
|
result, = n.execute(format="Hello, {}!", value0="World")
|
||||||
|
assert result == "Hello, World!"
|
||||||
|
|
||||||
|
# Test multiple values
|
||||||
|
result, = n.execute(format="{} plus {} equals {}", value0=2, value1=2, value2=4)
|
||||||
|
assert result == "2 plus 2 equals 4"
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_format_types():
|
||||||
|
n = StringFormat()
|
||||||
|
|
||||||
|
# Test with different types
|
||||||
|
result, = n.execute(format="Float: {:.2f}, Int: {}, Bool: {}",
|
||||||
|
value0=3.14159, value1=42, value2=True)
|
||||||
|
assert result == "Float: 3.14, Int: 42, Bool: True"
|
||||||
|
|
||||||
|
# Test None values
|
||||||
|
result, = n.execute(format="{}, {}, {}", value0=None, value1="test", value2=None)
|
||||||
|
assert result == "None, test, None"
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_format_tensors():
|
||||||
|
n = StringFormat()
|
||||||
|
|
||||||
|
# Test small tensor
|
||||||
|
small_tensor = torch.tensor([1, 2, 3])
|
||||||
|
result, = n.execute(format="Tensor: {}", value0=small_tensor)
|
||||||
|
assert result == "Tensor: [1, 2, 3]"
|
||||||
|
|
||||||
|
# Test large tensor
|
||||||
|
large_tensor = torch.randn(100, 100)
|
||||||
|
result, = n.execute(format="Large tensor: {}", value0=large_tensor)
|
||||||
|
assert result == "Large tensor: <Tensor shape=100x100>"
|
||||||
|
|
||||||
|
# Test mixed tensor sizes
|
||||||
|
small_tensor = torch.tensor([1, 2])
|
||||||
|
large_tensor = torch.randn(50, 50)
|
||||||
|
result, = n.execute(format="{} and {}", value0=small_tensor, value1=large_tensor)
|
||||||
|
assert "and <Tensor shape=50x50>" in result
|
||||||
|
assert "[1, 2]" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_format_edge_cases():
|
||||||
|
n = StringFormat()
|
||||||
|
|
||||||
|
# Test with missing values
|
||||||
|
result, = n.execute(format="{} {} {}", value0="a", value1="b")
|
||||||
|
assert result.startswith("Format error: ")
|
||||||
|
|
||||||
|
# Test with empty format string
|
||||||
|
result, = n.execute(format="", value0="ignored")
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
# Test with no placeholders
|
||||||
|
result, = n.execute(format="Hello World", value0="ignored")
|
||||||
|
assert result == "Hello World"
|
||||||
|
|
||||||
|
# Test with named placeholders
|
||||||
|
result, = n.execute(format="X: {value0}, Y: {value1}", value0=10, value1=20)
|
||||||
|
assert result == "X: 10, Y: 20"
|
||||||
|
|
||||||
|
# Test mixing None, tensors and regular values
|
||||||
|
tensor = torch.tensor([1, 2, 3])
|
||||||
|
result, = n.execute(format="{}, {}, {}", value0=None, value1=tensor, value2="test")
|
||||||
|
assert result == "None, [1, 2, 3], test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_format_tensor_edge_cases():
|
||||||
|
n = StringFormat()
|
||||||
|
|
||||||
|
# Test empty tensor
|
||||||
|
empty_tensor = torch.tensor([])
|
||||||
|
result, = n.execute(format="Empty tensor: {}", value0=empty_tensor)
|
||||||
|
assert result == "Empty tensor: []"
|
||||||
|
|
||||||
|
# Test scalar tensor
|
||||||
|
scalar_tensor = torch.tensor(5)
|
||||||
|
result, = n.execute(format="Scalar tensor: {}", value0=scalar_tensor)
|
||||||
|
assert result == "Scalar tensor: 5"
|
||||||
|
|
||||||
|
# Test multi-dimensional small tensor
|
||||||
|
small_2d_tensor = torch.tensor([[1, 2], [3, 4]])
|
||||||
|
result, = n.execute(format="2D tensor: {}", value0=small_2d_tensor)
|
||||||
|
assert result == "2D tensor: [[1, 2], [3, 4]]"
|
||||||
Loading…
Reference in New Issue
Block a user