mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Generators
This commit is contained in:
parent
0cfde0ad6d
commit
42f75cadc0
124
comfy_extras/nodes/nodes_generators.py
Normal file
124
comfy_extras/nodes/nodes_generators.py
Normal file
@ -0,0 +1,124 @@
|
||||
from decimal import Decimal
|
||||
|
||||
from comfy.node_helpers import export_custom_nodes
|
||||
from comfy.nodes.package_typing import CustomNode, InputTypes
|
||||
|
||||
|
||||
class IntRange(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"start": ("INT", {"default": 0, "step": 0.001}),
|
||||
"end": ("INT", {"default": 1, "step": 0.001}),
|
||||
"step": ("INT", {"default": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "generators"
|
||||
OUTPUT_IS_LIST = (True,)
|
||||
RETURN_TYPES = ("INT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, start: int, end: int, step: int) -> tuple[list[int]]:
|
||||
return list(range(start, end, step)),
|
||||
|
||||
|
||||
class FloatRange1(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"start": ("FLOAT", {"default": 0.0, "step": 0.001}),
|
||||
"end": ("FLOAT", {"default": 1.0, "step": 0.001}),
|
||||
"step": ("FLOAT", {"default": 1}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "generators"
|
||||
OUTPUT_IS_LIST = (True,)
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, start: float, end: float, step: float) -> tuple[list[float]]:
|
||||
if step == 0:
|
||||
return [],
|
||||
|
||||
steps = (Decimal(str(end)) - Decimal(str(start))) / Decimal(str(step))
|
||||
return [start + i * step for i in range(int(steps))],
|
||||
|
||||
|
||||
class FloatRange2(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"start": ("FLOAT", {"default": 0.0, "step": 0.001}),
|
||||
"end": ("FLOAT", {"default": 1.0, "step": 0.001}),
|
||||
"fence_posts": ("INT", {"default": 2, "min": 0}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "generators"
|
||||
OUTPUT_IS_LIST = (True,)
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, start: float, end: float, fence_posts: int) -> tuple[list[float]]:
|
||||
if fence_posts == 0:
|
||||
return [],
|
||||
elif fence_posts == 1:
|
||||
return [start],
|
||||
elif fence_posts == 2:
|
||||
return [start, end],
|
||||
|
||||
step = (end - start) / (fence_posts - 1)
|
||||
return [start + i * step for i in range(fence_posts)],
|
||||
|
||||
|
||||
class FloatRange3(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"start": ("FLOAT", {"default": 0.0, "step": 0.001}),
|
||||
"end": ("FLOAT", {"default": 1.0, "step": 0.001}),
|
||||
"spans": ("INT", {"default": 1, "min": 0}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "generators"
|
||||
OUTPUT_IS_LIST = (True,)
|
||||
RETURN_TYPES = ("FLOAT",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, start: float, end: float, spans: int) -> tuple[list[float]]:
|
||||
if spans == 0:
|
||||
return [],
|
||||
elif spans == 1:
|
||||
return [start],
|
||||
|
||||
span_width = (end - start) / spans
|
||||
return [start + i * span_width for i in range(spans)],
|
||||
|
||||
|
||||
class StringSplit(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"value": ("STRING", {}),
|
||||
"delimiter": ("STRING", {"default": ","}),
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "generators"
|
||||
OUTPUT_IS_LIST = (True,)
|
||||
RETURN_TYPES = ("STRING",)
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, value: str = "", delimiter: str = ",") -> tuple[list[str]]:
|
||||
return value.split(delimiter),
|
||||
|
||||
|
||||
export_custom_nodes()
|
||||
@ -29,7 +29,9 @@ from natsort import natsorted
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.cmd import folder_paths
|
||||
from comfy.comfy_types import IO
|
||||
from comfy.digest import digest
|
||||
from comfy.node_helpers import export_custom_nodes
|
||||
from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables, SaveNodeResult, \
|
||||
InputTypeSpec, ValidatedNodeResult
|
||||
|
||||
@ -301,6 +303,25 @@ class StringJoin(CustomNode):
|
||||
return (separator.join([kwargs[key] for key in sorted_keys if kwargs[key] != ""]),)
|
||||
|
||||
|
||||
class StringJoin1(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
optional = {f"value{i}": (IO.ANY, {}) for i in range(5)}
|
||||
optional["separator"] = (IO.STRING, {"default": "_"})
|
||||
return {
|
||||
"required": {},
|
||||
"optional": optional
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
CATEGORY = "api/openapi"
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, separator: str = "_", *args: str, **kwargs) -> ValidatedNodeResult:
|
||||
sorted_keys = natsorted(kwargs.keys())
|
||||
return (separator.join([str(kwargs[key]) for key in sorted_keys if kwargs[key] is not None]),)
|
||||
|
||||
|
||||
class StringToUri(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
@ -711,25 +732,4 @@ class ImageRequestParameter(CustomNode):
|
||||
return (output_image,)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
for cls in (
|
||||
IntRequestParameter,
|
||||
FloatRequestParameter,
|
||||
StringRequestParameter,
|
||||
StringEnumRequestParameter,
|
||||
BooleanRequestParameter,
|
||||
HashImage,
|
||||
StringPosixPathJoin,
|
||||
LegacyOutputURIs,
|
||||
DevNullUris,
|
||||
StringJoin,
|
||||
StringToUri,
|
||||
UriFormat,
|
||||
ImageExif,
|
||||
ImageExifMerge,
|
||||
ImageExifUncommon,
|
||||
ImageExifCreationDateAndBatchNumber,
|
||||
SaveImagesResponse,
|
||||
ImageRequestParameter
|
||||
):
|
||||
NODE_CLASS_MAPPINGS[cls.__name__] = cls
|
||||
export_custom_nodes()
|
||||
|
||||
266
tests/unit/test_generators.py
Normal file
266
tests/unit/test_generators.py
Normal file
@ -0,0 +1,266 @@
|
||||
import pytest
|
||||
|
||||
from comfy_extras.nodes.nodes_generators import IntRange, FloatRange1, FloatRange2, FloatRange3, StringSplit
|
||||
|
||||
|
||||
def test_int_range():
|
||||
node = IntRange()
|
||||
|
||||
# Basic range
|
||||
result = node.execute(0, 5, 1)
|
||||
assert result == ([0, 1, 2, 3, 4],)
|
||||
|
||||
# Negative step
|
||||
result = node.execute(5, 0, -1)
|
||||
assert result == ([5, 4, 3, 2, 1],)
|
||||
|
||||
# Empty range (step in wrong direction)
|
||||
result = node.execute(0, 5, -1)
|
||||
assert result == ([],)
|
||||
|
||||
# Single item range
|
||||
result = node.execute(0, 1, 1)
|
||||
assert result == ([0],)
|
||||
|
||||
# Step size > 1
|
||||
result = node.execute(0, 10, 2)
|
||||
assert result == ([0, 2, 4, 6, 8],)
|
||||
|
||||
|
||||
def test_float_range1():
|
||||
node = FloatRange1()
|
||||
|
||||
# Basic range
|
||||
result = node.execute(0.0, 1.0, 0.25)
|
||||
assert result == ([0.0, 0.25, 0.5, 0.75],)
|
||||
|
||||
# Step size of 1
|
||||
result = node.execute(0.0, 3.0, 1.0)
|
||||
assert result == ([0.0, 1.0, 2.0],)
|
||||
|
||||
# Zero step
|
||||
result = node.execute(0.0, 1.0, 0.0)
|
||||
assert result == ([],)
|
||||
|
||||
# Negative step
|
||||
result = node.execute(1.0, 0.0, -0.25)
|
||||
assert result == ([1.0, 0.75, 0.5, 0.25],)
|
||||
|
||||
# Test floating point precision
|
||||
result = node.execute(0.0, 0.3, 0.1)
|
||||
assert len(result[0]) == 3
|
||||
assert all(abs(a - b) < 1e-10 for a, b in zip(result[0], [0.0, 0.1, 0.2]))
|
||||
|
||||
|
||||
def test_float_range2():
|
||||
node = FloatRange2()
|
||||
|
||||
# Basic range with 5 points
|
||||
result = node.execute(0.0, 1.0, 5)
|
||||
assert result == ([0.0, 0.25, 0.5, 0.75, 1.0],)
|
||||
|
||||
# Zero points
|
||||
result = node.execute(0.0, 1.0, 0)
|
||||
assert result == ([],)
|
||||
|
||||
# One point
|
||||
result = node.execute(0.0, 1.0, 1)
|
||||
assert result == ([0.0],)
|
||||
|
||||
# Two points
|
||||
result = node.execute(0.0, 1.0, 2)
|
||||
assert result == ([0.0, 1.0],)
|
||||
|
||||
# Test negative range
|
||||
result = node.execute(1.0, -1.0, 3)
|
||||
assert result == ([1.0, 0.0, -1.0],)
|
||||
|
||||
# Test floating point precision
|
||||
result = node.execute(0.0, 0.2, 3)
|
||||
assert len(result[0]) == 3
|
||||
assert all(abs(a - b) < 1e-10 for a, b in zip(result[0], [0.0, 0.1, 0.2]))
|
||||
|
||||
|
||||
def test_float_range3():
|
||||
node = FloatRange3()
|
||||
|
||||
# Basic range with 4 spans
|
||||
result = node.execute(0.0, 1.0, 4)
|
||||
expected = [0.0, 0.25, 0.5, 0.75] # Note: doesn't include end point
|
||||
assert result == (expected,)
|
||||
|
||||
# Zero spans
|
||||
result = node.execute(0.0, 1.0, 0)
|
||||
assert result == ([],)
|
||||
|
||||
# One span
|
||||
result = node.execute(0.0, 1.0, 1)
|
||||
assert result == ([0.0],)
|
||||
|
||||
# Test negative range
|
||||
result = node.execute(1.0, -1.0, 2)
|
||||
assert result == ([1.0, 0.0],)
|
||||
|
||||
# Test floating point precision
|
||||
result = node.execute(0.0, 0.3, 3)
|
||||
assert len(result[0]) == 3
|
||||
assert all(abs(a - b) < 1e-10 for a, b in zip(result[0], [0.0, 0.1, 0.2]))
|
||||
|
||||
|
||||
def test_output_types():
|
||||
"""Test that all nodes return correct types and list outputs"""
|
||||
# IntRange
|
||||
result = IntRange().execute(0, 5, 1)
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], list)
|
||||
assert all(isinstance(x, int) for x in result[0])
|
||||
|
||||
# FloatRange1
|
||||
result = FloatRange1().execute(0.0, 1.0, 0.5)
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], list)
|
||||
assert all(isinstance(x, float) for x in result[0])
|
||||
|
||||
# FloatRange2
|
||||
result = FloatRange2().execute(0.0, 1.0, 3)
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], list)
|
||||
assert all(isinstance(x, float) for x in result[0])
|
||||
|
||||
# FloatRange3
|
||||
result = FloatRange3().execute(0.0, 1.0, 2)
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], list)
|
||||
assert all(isinstance(x, float) for x in result[0])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("node_class", [IntRange, FloatRange1, FloatRange2, FloatRange3])
|
||||
def test_node_metadata(node_class):
|
||||
"""Test that all nodes have correct metadata"""
|
||||
node = node_class()
|
||||
|
||||
# Check required class attributes
|
||||
assert hasattr(node, 'CATEGORY')
|
||||
assert hasattr(node, 'OUTPUT_IS_LIST')
|
||||
assert hasattr(node, 'RETURN_TYPES')
|
||||
assert hasattr(node, 'FUNCTION')
|
||||
|
||||
# Verify OUTPUT_IS_LIST matches number of outputs
|
||||
assert len(node.OUTPUT_IS_LIST) == len(node.RETURN_TYPES)
|
||||
assert node.OUTPUT_IS_LIST == (True,)
|
||||
|
||||
# Verify input types exist
|
||||
input_types = node.INPUT_TYPES()
|
||||
assert 'required' in input_types
|
||||
|
||||
# All nodes should have appropriate numeric types
|
||||
if node_class == IntRange:
|
||||
assert all(v[0] == 'INT' for v in input_types['required'].values())
|
||||
else:
|
||||
assert input_types['required']['start'][0] == 'FLOAT'
|
||||
assert input_types['required']['end'][0] == 'FLOAT'
|
||||
|
||||
|
||||
def test_string_split_basic():
|
||||
node = StringSplit()
|
||||
|
||||
# Basic comma split
|
||||
result = node.execute("a,b,c")
|
||||
assert result == (["a", "b", "c"],)
|
||||
|
||||
# Custom delimiter
|
||||
result = node.execute("a|b|c", delimiter="|")
|
||||
assert result == (["a", "b", "c"],)
|
||||
|
||||
# Empty string
|
||||
result = node.execute("")
|
||||
assert result == ([""],)
|
||||
|
||||
# Single value (no delimiters)
|
||||
result = node.execute("abc")
|
||||
assert result == (["abc"],)
|
||||
|
||||
|
||||
def test_string_split_edge_cases():
|
||||
node = StringSplit()
|
||||
|
||||
# Multiple consecutive delimiters
|
||||
result = node.execute("a,,b,,c")
|
||||
assert result == (["a", "", "b", "", "c"],)
|
||||
|
||||
# Leading/trailing delimiters
|
||||
result = node.execute(",a,b,c,")
|
||||
assert result == (["", "a", "b", "c", ""],)
|
||||
|
||||
# Multi-character delimiter
|
||||
result = node.execute("a<->b<->c", delimiter="<->")
|
||||
assert result == (["a", "b", "c"],)
|
||||
|
||||
# Whitespace handling
|
||||
result = node.execute(" a , b , c ")
|
||||
assert result == ([" a ", " b ", " c "],)
|
||||
|
||||
# Split on whitespace
|
||||
result = node.execute("a b c", delimiter=" ")
|
||||
assert result == (["a", "b", "c"],)
|
||||
|
||||
|
||||
def test_string_split_special_chars():
|
||||
node = StringSplit()
|
||||
|
||||
# Split on newline
|
||||
result = node.execute("a\nb\nc", delimiter="\n")
|
||||
assert result == (["a", "b", "c"],)
|
||||
|
||||
# Split on tab
|
||||
result = node.execute("a\tb\tc", delimiter="\t")
|
||||
assert result == (["a", "b", "c"],)
|
||||
|
||||
# Regex special characters as delimiters
|
||||
result = node.execute("a.b.c", delimiter=".")
|
||||
assert result == (["a", "b", "c"],)
|
||||
|
||||
# Unicode delimiters
|
||||
result = node.execute("a→b→c", delimiter="→")
|
||||
assert result == (["a", "b", "c"],)
|
||||
|
||||
|
||||
def test_string_split_metadata():
|
||||
node = StringSplit()
|
||||
|
||||
# Check required class attributes
|
||||
assert hasattr(node, 'CATEGORY')
|
||||
assert hasattr(node, 'OUTPUT_IS_LIST')
|
||||
assert hasattr(node, 'RETURN_TYPES')
|
||||
assert hasattr(node, 'FUNCTION')
|
||||
|
||||
# Verify OUTPUT_IS_LIST matches RETURN_TYPES
|
||||
assert len(node.OUTPUT_IS_LIST) == len(node.RETURN_TYPES)
|
||||
assert node.OUTPUT_IS_LIST == (True,)
|
||||
assert node.RETURN_TYPES == ("STRING",)
|
||||
|
||||
# Check input types
|
||||
input_types = node.INPUT_TYPES()
|
||||
assert 'required' in input_types
|
||||
assert 'value' in input_types['required']
|
||||
assert 'delimiter' in input_types['required']
|
||||
assert input_types['required']['value'][0] == 'STRING'
|
||||
assert input_types['required']['delimiter'][0] == 'STRING'
|
||||
|
||||
# Check delimiter default
|
||||
assert input_types['required']['delimiter'][1]['default'] == ','
|
||||
|
||||
|
||||
def test_string_split_return_type():
|
||||
node = StringSplit()
|
||||
|
||||
# Verify return structure
|
||||
result = node.execute("a,b,c")
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], list)
|
||||
assert all(isinstance(x, str) for x in result[0])
|
||||
Loading…
Reference in New Issue
Block a user