diff --git a/comfy_extras/nodes/nodes_generators.py b/comfy_extras/nodes/nodes_generators.py new file mode 100644 index 000000000..1b538b076 --- /dev/null +++ b/comfy_extras/nodes/nodes_generators.py @@ -0,0 +1,124 @@ +from decimal import Decimal + +from comfy.node_helpers import export_custom_nodes +from comfy.nodes.package_typing import CustomNode, InputTypes + + +class IntRange(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "start": ("INT", {"default": 0, "step": 0.001}), + "end": ("INT", {"default": 1, "step": 0.001}), + "step": ("INT", {"default": 1}), + } + } + + CATEGORY = "generators" + OUTPUT_IS_LIST = (True,) + RETURN_TYPES = ("INT",) + FUNCTION = "execute" + + def execute(self, start: int, end: int, step: int) -> tuple[list[int]]: + return list(range(start, end, step)), + + +class FloatRange1(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "start": ("FLOAT", {"default": 0.0, "step": 0.001}), + "end": ("FLOAT", {"default": 1.0, "step": 0.001}), + "step": ("FLOAT", {"default": 1}), + } + } + + CATEGORY = "generators" + OUTPUT_IS_LIST = (True,) + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, start: float, end: float, step: float) -> tuple[list[float]]: + if step == 0: + return [], + + steps = (Decimal(str(end)) - Decimal(str(start))) / Decimal(str(step)) + return [start + i * step for i in range(int(steps))], + + +class FloatRange2(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "start": ("FLOAT", {"default": 0.0, "step": 0.001}), + "end": ("FLOAT", {"default": 1.0, "step": 0.001}), + "fence_posts": ("INT", {"default": 2, "min": 0}), + } + } + + CATEGORY = "generators" + OUTPUT_IS_LIST = (True,) + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, start: float, end: float, fence_posts: int) -> tuple[list[float]]: + if fence_posts == 0: + return [], + elif fence_posts == 1: + return [start], + elif fence_posts == 2: + return [start, end], + + step = (end - start) / (fence_posts - 1) + return [start + i * step for i in range(fence_posts)], + + +class FloatRange3(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "start": ("FLOAT", {"default": 0.0, "step": 0.001}), + "end": ("FLOAT", {"default": 1.0, "step": 0.001}), + "spans": ("INT", {"default": 1, "min": 0}), + } + } + + CATEGORY = "generators" + OUTPUT_IS_LIST = (True,) + RETURN_TYPES = ("FLOAT",) + FUNCTION = "execute" + + def execute(self, start: float, end: float, spans: int) -> tuple[list[float]]: + if spans == 0: + return [], + elif spans == 1: + return [start], + + span_width = (end - start) / spans + return [start + i * span_width for i in range(spans)], + + +class StringSplit(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "value": ("STRING", {}), + "delimiter": ("STRING", {"default": ","}), + } + } + + CATEGORY = "generators" + OUTPUT_IS_LIST = (True,) + RETURN_TYPES = ("STRING",) + FUNCTION = "execute" + + def execute(self, value: str = "", delimiter: str = ",") -> tuple[list[str]]: + return value.split(delimiter), + + +export_custom_nodes() diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py index 2c165dc66..ab3eb52fe 100644 --- a/comfy_extras/nodes/nodes_open_api.py +++ b/comfy_extras/nodes/nodes_open_api.py @@ -29,7 +29,9 @@ from natsort import natsorted from torch import Tensor from comfy.cmd import folder_paths +from comfy.comfy_types import IO from comfy.digest import digest +from comfy.node_helpers import export_custom_nodes from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables, SaveNodeResult, \ InputTypeSpec, ValidatedNodeResult @@ -301,6 +303,25 @@ class StringJoin(CustomNode): return (separator.join([kwargs[key] for key in sorted_keys if kwargs[key] != ""]),) +class StringJoin1(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + optional = {f"value{i}": (IO.ANY, {}) for i in range(5)} + optional["separator"] = (IO.STRING, {"default": "_"}) + return { + "required": {}, + "optional": optional + } + + RETURN_TYPES = ("STRING",) + CATEGORY = "api/openapi" + FUNCTION = "execute" + + def execute(self, separator: str = "_", *args: str, **kwargs) -> ValidatedNodeResult: + sorted_keys = natsorted(kwargs.keys()) + return (separator.join([str(kwargs[key]) for key in sorted_keys if kwargs[key] is not None]),) + + class StringToUri(CustomNode): @classmethod def INPUT_TYPES(cls) -> InputTypes: @@ -711,25 +732,4 @@ class ImageRequestParameter(CustomNode): return (output_image,) -NODE_CLASS_MAPPINGS = {} -for cls in ( - IntRequestParameter, - FloatRequestParameter, - StringRequestParameter, - StringEnumRequestParameter, - BooleanRequestParameter, - HashImage, - StringPosixPathJoin, - LegacyOutputURIs, - DevNullUris, - StringJoin, - StringToUri, - UriFormat, - ImageExif, - ImageExifMerge, - ImageExifUncommon, - ImageExifCreationDateAndBatchNumber, - SaveImagesResponse, - ImageRequestParameter -): - NODE_CLASS_MAPPINGS[cls.__name__] = cls +export_custom_nodes() diff --git a/tests/unit/test_generators.py b/tests/unit/test_generators.py new file mode 100644 index 000000000..d4a301492 --- /dev/null +++ b/tests/unit/test_generators.py @@ -0,0 +1,266 @@ +import pytest + +from comfy_extras.nodes.nodes_generators import IntRange, FloatRange1, FloatRange2, FloatRange3, StringSplit + + +def test_int_range(): + node = IntRange() + + # Basic range + result = node.execute(0, 5, 1) + assert result == ([0, 1, 2, 3, 4],) + + # Negative step + result = node.execute(5, 0, -1) + assert result == ([5, 4, 3, 2, 1],) + + # Empty range (step in wrong direction) + result = node.execute(0, 5, -1) + assert result == ([],) + + # Single item range + result = node.execute(0, 1, 1) + assert result == ([0],) + + # Step size > 1 + result = node.execute(0, 10, 2) + assert result == ([0, 2, 4, 6, 8],) + + +def test_float_range1(): + node = FloatRange1() + + # Basic range + result = node.execute(0.0, 1.0, 0.25) + assert result == ([0.0, 0.25, 0.5, 0.75],) + + # Step size of 1 + result = node.execute(0.0, 3.0, 1.0) + assert result == ([0.0, 1.0, 2.0],) + + # Zero step + result = node.execute(0.0, 1.0, 0.0) + assert result == ([],) + + # Negative step + result = node.execute(1.0, 0.0, -0.25) + assert result == ([1.0, 0.75, 0.5, 0.25],) + + # Test floating point precision + result = node.execute(0.0, 0.3, 0.1) + assert len(result[0]) == 3 + assert all(abs(a - b) < 1e-10 for a, b in zip(result[0], [0.0, 0.1, 0.2])) + + +def test_float_range2(): + node = FloatRange2() + + # Basic range with 5 points + result = node.execute(0.0, 1.0, 5) + assert result == ([0.0, 0.25, 0.5, 0.75, 1.0],) + + # Zero points + result = node.execute(0.0, 1.0, 0) + assert result == ([],) + + # One point + result = node.execute(0.0, 1.0, 1) + assert result == ([0.0],) + + # Two points + result = node.execute(0.0, 1.0, 2) + assert result == ([0.0, 1.0],) + + # Test negative range + result = node.execute(1.0, -1.0, 3) + assert result == ([1.0, 0.0, -1.0],) + + # Test floating point precision + result = node.execute(0.0, 0.2, 3) + assert len(result[0]) == 3 + assert all(abs(a - b) < 1e-10 for a, b in zip(result[0], [0.0, 0.1, 0.2])) + + +def test_float_range3(): + node = FloatRange3() + + # Basic range with 4 spans + result = node.execute(0.0, 1.0, 4) + expected = [0.0, 0.25, 0.5, 0.75] # Note: doesn't include end point + assert result == (expected,) + + # Zero spans + result = node.execute(0.0, 1.0, 0) + assert result == ([],) + + # One span + result = node.execute(0.0, 1.0, 1) + assert result == ([0.0],) + + # Test negative range + result = node.execute(1.0, -1.0, 2) + assert result == ([1.0, 0.0],) + + # Test floating point precision + result = node.execute(0.0, 0.3, 3) + assert len(result[0]) == 3 + assert all(abs(a - b) < 1e-10 for a, b in zip(result[0], [0.0, 0.1, 0.2])) + + +def test_output_types(): + """Test that all nodes return correct types and list outputs""" + # IntRange + result = IntRange().execute(0, 5, 1) + assert isinstance(result, tuple) + assert len(result) == 1 + assert isinstance(result[0], list) + assert all(isinstance(x, int) for x in result[0]) + + # FloatRange1 + result = FloatRange1().execute(0.0, 1.0, 0.5) + assert isinstance(result, tuple) + assert len(result) == 1 + assert isinstance(result[0], list) + assert all(isinstance(x, float) for x in result[0]) + + # FloatRange2 + result = FloatRange2().execute(0.0, 1.0, 3) + assert isinstance(result, tuple) + assert len(result) == 1 + assert isinstance(result[0], list) + assert all(isinstance(x, float) for x in result[0]) + + # FloatRange3 + result = FloatRange3().execute(0.0, 1.0, 2) + assert isinstance(result, tuple) + assert len(result) == 1 + assert isinstance(result[0], list) + assert all(isinstance(x, float) for x in result[0]) + + +@pytest.mark.parametrize("node_class", [IntRange, FloatRange1, FloatRange2, FloatRange3]) +def test_node_metadata(node_class): + """Test that all nodes have correct metadata""" + node = node_class() + + # Check required class attributes + assert hasattr(node, 'CATEGORY') + assert hasattr(node, 'OUTPUT_IS_LIST') + assert hasattr(node, 'RETURN_TYPES') + assert hasattr(node, 'FUNCTION') + + # Verify OUTPUT_IS_LIST matches number of outputs + assert len(node.OUTPUT_IS_LIST) == len(node.RETURN_TYPES) + assert node.OUTPUT_IS_LIST == (True,) + + # Verify input types exist + input_types = node.INPUT_TYPES() + assert 'required' in input_types + + # All nodes should have appropriate numeric types + if node_class == IntRange: + assert all(v[0] == 'INT' for v in input_types['required'].values()) + else: + assert input_types['required']['start'][0] == 'FLOAT' + assert input_types['required']['end'][0] == 'FLOAT' + + +def test_string_split_basic(): + node = StringSplit() + + # Basic comma split + result = node.execute("a,b,c") + assert result == (["a", "b", "c"],) + + # Custom delimiter + result = node.execute("a|b|c", delimiter="|") + assert result == (["a", "b", "c"],) + + # Empty string + result = node.execute("") + assert result == ([""],) + + # Single value (no delimiters) + result = node.execute("abc") + assert result == (["abc"],) + + +def test_string_split_edge_cases(): + node = StringSplit() + + # Multiple consecutive delimiters + result = node.execute("a,,b,,c") + assert result == (["a", "", "b", "", "c"],) + + # Leading/trailing delimiters + result = node.execute(",a,b,c,") + assert result == (["", "a", "b", "c", ""],) + + # Multi-character delimiter + result = node.execute("a<->b<->c", delimiter="<->") + assert result == (["a", "b", "c"],) + + # Whitespace handling + result = node.execute(" a , b , c ") + assert result == ([" a ", " b ", " c "],) + + # Split on whitespace + result = node.execute("a b c", delimiter=" ") + assert result == (["a", "b", "c"],) + + +def test_string_split_special_chars(): + node = StringSplit() + + # Split on newline + result = node.execute("a\nb\nc", delimiter="\n") + assert result == (["a", "b", "c"],) + + # Split on tab + result = node.execute("a\tb\tc", delimiter="\t") + assert result == (["a", "b", "c"],) + + # Regex special characters as delimiters + result = node.execute("a.b.c", delimiter=".") + assert result == (["a", "b", "c"],) + + # Unicode delimiters + result = node.execute("a→b→c", delimiter="→") + assert result == (["a", "b", "c"],) + + +def test_string_split_metadata(): + node = StringSplit() + + # Check required class attributes + assert hasattr(node, 'CATEGORY') + assert hasattr(node, 'OUTPUT_IS_LIST') + assert hasattr(node, 'RETURN_TYPES') + assert hasattr(node, 'FUNCTION') + + # Verify OUTPUT_IS_LIST matches RETURN_TYPES + assert len(node.OUTPUT_IS_LIST) == len(node.RETURN_TYPES) + assert node.OUTPUT_IS_LIST == (True,) + assert node.RETURN_TYPES == ("STRING",) + + # Check input types + input_types = node.INPUT_TYPES() + assert 'required' in input_types + assert 'value' in input_types['required'] + assert 'delimiter' in input_types['required'] + assert input_types['required']['value'][0] == 'STRING' + assert input_types['required']['delimiter'][0] == 'STRING' + + # Check delimiter default + assert input_types['required']['delimiter'][1]['default'] == ',' + + +def test_string_split_return_type(): + node = StringSplit() + + # Verify return structure + result = node.execute("a,b,c") + assert isinstance(result, tuple) + assert len(result) == 1 + assert isinstance(result[0], list) + assert all(isinstance(x, str) for x in result[0])