Generators

This commit is contained in:
doctorpangloss 2025-02-20 15:00:07 -08:00
parent 0cfde0ad6d
commit 42f75cadc0
3 changed files with 412 additions and 22 deletions

View File

@ -0,0 +1,124 @@
from decimal import Decimal
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes
class IntRange(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"start": ("INT", {"default": 0, "step": 0.001}),
"end": ("INT", {"default": 1, "step": 0.001}),
"step": ("INT", {"default": 1}),
}
}
CATEGORY = "generators"
OUTPUT_IS_LIST = (True,)
RETURN_TYPES = ("INT",)
FUNCTION = "execute"
def execute(self, start: int, end: int, step: int) -> tuple[list[int]]:
return list(range(start, end, step)),
class FloatRange1(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"start": ("FLOAT", {"default": 0.0, "step": 0.001}),
"end": ("FLOAT", {"default": 1.0, "step": 0.001}),
"step": ("FLOAT", {"default": 1}),
}
}
CATEGORY = "generators"
OUTPUT_IS_LIST = (True,)
RETURN_TYPES = ("FLOAT",)
FUNCTION = "execute"
def execute(self, start: float, end: float, step: float) -> tuple[list[float]]:
if step == 0:
return [],
steps = (Decimal(str(end)) - Decimal(str(start))) / Decimal(str(step))
return [start + i * step for i in range(int(steps))],
class FloatRange2(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"start": ("FLOAT", {"default": 0.0, "step": 0.001}),
"end": ("FLOAT", {"default": 1.0, "step": 0.001}),
"fence_posts": ("INT", {"default": 2, "min": 0}),
}
}
CATEGORY = "generators"
OUTPUT_IS_LIST = (True,)
RETURN_TYPES = ("FLOAT",)
FUNCTION = "execute"
def execute(self, start: float, end: float, fence_posts: int) -> tuple[list[float]]:
if fence_posts == 0:
return [],
elif fence_posts == 1:
return [start],
elif fence_posts == 2:
return [start, end],
step = (end - start) / (fence_posts - 1)
return [start + i * step for i in range(fence_posts)],
class FloatRange3(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"start": ("FLOAT", {"default": 0.0, "step": 0.001}),
"end": ("FLOAT", {"default": 1.0, "step": 0.001}),
"spans": ("INT", {"default": 1, "min": 0}),
}
}
CATEGORY = "generators"
OUTPUT_IS_LIST = (True,)
RETURN_TYPES = ("FLOAT",)
FUNCTION = "execute"
def execute(self, start: float, end: float, spans: int) -> tuple[list[float]]:
if spans == 0:
return [],
elif spans == 1:
return [start],
span_width = (end - start) / spans
return [start + i * span_width for i in range(spans)],
class StringSplit(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {}),
"delimiter": ("STRING", {"default": ","}),
}
}
CATEGORY = "generators"
OUTPUT_IS_LIST = (True,)
RETURN_TYPES = ("STRING",)
FUNCTION = "execute"
def execute(self, value: str = "", delimiter: str = ",") -> tuple[list[str]]:
return value.split(delimiter),
export_custom_nodes()

View File

@ -29,7 +29,9 @@ from natsort import natsorted
from torch import Tensor
from comfy.cmd import folder_paths
from comfy.comfy_types import IO
from comfy.digest import digest
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables, SaveNodeResult, \
InputTypeSpec, ValidatedNodeResult
@ -301,6 +303,25 @@ class StringJoin(CustomNode):
return (separator.join([kwargs[key] for key in sorted_keys if kwargs[key] != ""]),)
class StringJoin1(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
optional = {f"value{i}": (IO.ANY, {}) for i in range(5)}
optional["separator"] = (IO.STRING, {"default": "_"})
return {
"required": {},
"optional": optional
}
RETURN_TYPES = ("STRING",)
CATEGORY = "api/openapi"
FUNCTION = "execute"
def execute(self, separator: str = "_", *args: str, **kwargs) -> ValidatedNodeResult:
sorted_keys = natsorted(kwargs.keys())
return (separator.join([str(kwargs[key]) for key in sorted_keys if kwargs[key] is not None]),)
class StringToUri(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
@ -711,25 +732,4 @@ class ImageRequestParameter(CustomNode):
return (output_image,)
NODE_CLASS_MAPPINGS = {}
for cls in (
IntRequestParameter,
FloatRequestParameter,
StringRequestParameter,
StringEnumRequestParameter,
BooleanRequestParameter,
HashImage,
StringPosixPathJoin,
LegacyOutputURIs,
DevNullUris,
StringJoin,
StringToUri,
UriFormat,
ImageExif,
ImageExifMerge,
ImageExifUncommon,
ImageExifCreationDateAndBatchNumber,
SaveImagesResponse,
ImageRequestParameter
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls
export_custom_nodes()

View File

@ -0,0 +1,266 @@
import pytest
from comfy_extras.nodes.nodes_generators import IntRange, FloatRange1, FloatRange2, FloatRange3, StringSplit
def test_int_range():
node = IntRange()
# Basic range
result = node.execute(0, 5, 1)
assert result == ([0, 1, 2, 3, 4],)
# Negative step
result = node.execute(5, 0, -1)
assert result == ([5, 4, 3, 2, 1],)
# Empty range (step in wrong direction)
result = node.execute(0, 5, -1)
assert result == ([],)
# Single item range
result = node.execute(0, 1, 1)
assert result == ([0],)
# Step size > 1
result = node.execute(0, 10, 2)
assert result == ([0, 2, 4, 6, 8],)
def test_float_range1():
node = FloatRange1()
# Basic range
result = node.execute(0.0, 1.0, 0.25)
assert result == ([0.0, 0.25, 0.5, 0.75],)
# Step size of 1
result = node.execute(0.0, 3.0, 1.0)
assert result == ([0.0, 1.0, 2.0],)
# Zero step
result = node.execute(0.0, 1.0, 0.0)
assert result == ([],)
# Negative step
result = node.execute(1.0, 0.0, -0.25)
assert result == ([1.0, 0.75, 0.5, 0.25],)
# Test floating point precision
result = node.execute(0.0, 0.3, 0.1)
assert len(result[0]) == 3
assert all(abs(a - b) < 1e-10 for a, b in zip(result[0], [0.0, 0.1, 0.2]))
def test_float_range2():
node = FloatRange2()
# Basic range with 5 points
result = node.execute(0.0, 1.0, 5)
assert result == ([0.0, 0.25, 0.5, 0.75, 1.0],)
# Zero points
result = node.execute(0.0, 1.0, 0)
assert result == ([],)
# One point
result = node.execute(0.0, 1.0, 1)
assert result == ([0.0],)
# Two points
result = node.execute(0.0, 1.0, 2)
assert result == ([0.0, 1.0],)
# Test negative range
result = node.execute(1.0, -1.0, 3)
assert result == ([1.0, 0.0, -1.0],)
# Test floating point precision
result = node.execute(0.0, 0.2, 3)
assert len(result[0]) == 3
assert all(abs(a - b) < 1e-10 for a, b in zip(result[0], [0.0, 0.1, 0.2]))
def test_float_range3():
node = FloatRange3()
# Basic range with 4 spans
result = node.execute(0.0, 1.0, 4)
expected = [0.0, 0.25, 0.5, 0.75] # Note: doesn't include end point
assert result == (expected,)
# Zero spans
result = node.execute(0.0, 1.0, 0)
assert result == ([],)
# One span
result = node.execute(0.0, 1.0, 1)
assert result == ([0.0],)
# Test negative range
result = node.execute(1.0, -1.0, 2)
assert result == ([1.0, 0.0],)
# Test floating point precision
result = node.execute(0.0, 0.3, 3)
assert len(result[0]) == 3
assert all(abs(a - b) < 1e-10 for a, b in zip(result[0], [0.0, 0.1, 0.2]))
def test_output_types():
"""Test that all nodes return correct types and list outputs"""
# IntRange
result = IntRange().execute(0, 5, 1)
assert isinstance(result, tuple)
assert len(result) == 1
assert isinstance(result[0], list)
assert all(isinstance(x, int) for x in result[0])
# FloatRange1
result = FloatRange1().execute(0.0, 1.0, 0.5)
assert isinstance(result, tuple)
assert len(result) == 1
assert isinstance(result[0], list)
assert all(isinstance(x, float) for x in result[0])
# FloatRange2
result = FloatRange2().execute(0.0, 1.0, 3)
assert isinstance(result, tuple)
assert len(result) == 1
assert isinstance(result[0], list)
assert all(isinstance(x, float) for x in result[0])
# FloatRange3
result = FloatRange3().execute(0.0, 1.0, 2)
assert isinstance(result, tuple)
assert len(result) == 1
assert isinstance(result[0], list)
assert all(isinstance(x, float) for x in result[0])
@pytest.mark.parametrize("node_class", [IntRange, FloatRange1, FloatRange2, FloatRange3])
def test_node_metadata(node_class):
"""Test that all nodes have correct metadata"""
node = node_class()
# Check required class attributes
assert hasattr(node, 'CATEGORY')
assert hasattr(node, 'OUTPUT_IS_LIST')
assert hasattr(node, 'RETURN_TYPES')
assert hasattr(node, 'FUNCTION')
# Verify OUTPUT_IS_LIST matches number of outputs
assert len(node.OUTPUT_IS_LIST) == len(node.RETURN_TYPES)
assert node.OUTPUT_IS_LIST == (True,)
# Verify input types exist
input_types = node.INPUT_TYPES()
assert 'required' in input_types
# All nodes should have appropriate numeric types
if node_class == IntRange:
assert all(v[0] == 'INT' for v in input_types['required'].values())
else:
assert input_types['required']['start'][0] == 'FLOAT'
assert input_types['required']['end'][0] == 'FLOAT'
def test_string_split_basic():
node = StringSplit()
# Basic comma split
result = node.execute("a,b,c")
assert result == (["a", "b", "c"],)
# Custom delimiter
result = node.execute("a|b|c", delimiter="|")
assert result == (["a", "b", "c"],)
# Empty string
result = node.execute("")
assert result == ([""],)
# Single value (no delimiters)
result = node.execute("abc")
assert result == (["abc"],)
def test_string_split_edge_cases():
node = StringSplit()
# Multiple consecutive delimiters
result = node.execute("a,,b,,c")
assert result == (["a", "", "b", "", "c"],)
# Leading/trailing delimiters
result = node.execute(",a,b,c,")
assert result == (["", "a", "b", "c", ""],)
# Multi-character delimiter
result = node.execute("a<->b<->c", delimiter="<->")
assert result == (["a", "b", "c"],)
# Whitespace handling
result = node.execute(" a , b , c ")
assert result == ([" a ", " b ", " c "],)
# Split on whitespace
result = node.execute("a b c", delimiter=" ")
assert result == (["a", "b", "c"],)
def test_string_split_special_chars():
node = StringSplit()
# Split on newline
result = node.execute("a\nb\nc", delimiter="\n")
assert result == (["a", "b", "c"],)
# Split on tab
result = node.execute("a\tb\tc", delimiter="\t")
assert result == (["a", "b", "c"],)
# Regex special characters as delimiters
result = node.execute("a.b.c", delimiter=".")
assert result == (["a", "b", "c"],)
# Unicode delimiters
result = node.execute("a→b→c", delimiter="")
assert result == (["a", "b", "c"],)
def test_string_split_metadata():
node = StringSplit()
# Check required class attributes
assert hasattr(node, 'CATEGORY')
assert hasattr(node, 'OUTPUT_IS_LIST')
assert hasattr(node, 'RETURN_TYPES')
assert hasattr(node, 'FUNCTION')
# Verify OUTPUT_IS_LIST matches RETURN_TYPES
assert len(node.OUTPUT_IS_LIST) == len(node.RETURN_TYPES)
assert node.OUTPUT_IS_LIST == (True,)
assert node.RETURN_TYPES == ("STRING",)
# Check input types
input_types = node.INPUT_TYPES()
assert 'required' in input_types
assert 'value' in input_types['required']
assert 'delimiter' in input_types['required']
assert input_types['required']['value'][0] == 'STRING'
assert input_types['required']['delimiter'][0] == 'STRING'
# Check delimiter default
assert input_types['required']['delimiter'][1]['default'] == ','
def test_string_split_return_type():
node = StringSplit()
# Verify return structure
result = node.execute("a,b,c")
assert isinstance(result, tuple)
assert len(result) == 1
assert isinstance(result[0], list)
assert all(isinstance(x, str) for x in result[0])