Refactor math nodes to use MatchType and Autogrow for generic operations

This commit is contained in:
jackdevis 2026-01-21 03:57:38 -08:00
parent c6f0f31d44
commit 27f600cd8b
2 changed files with 248 additions and 139 deletions

View File

@ -3,99 +3,149 @@ from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
class IntAdd(io.ComfyNode): class Add(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
template = io.MatchType.Template("add")
autogrow_template = io.Autogrow.TemplatePrefix(
io.MatchType.Input("operand", template=template),
prefix="operand",
min=2,
max=10
)
return io.Schema( return io.Schema(
node_id="IntAdd", node_id="Add",
display_name="Int Add", display_name="Add",
category="math", category="math",
inputs=[ inputs=[
io.Int.Input("a"), io.Autogrow.Input("operands", template=autogrow_template)
io.Int.Input("b"), ],
outputs=[
io.MatchType.Output(template=template)
], ],
outputs=[io.Int.Output()],
) )
@classmethod @classmethod
def execute(cls, a: int, b: int) -> io.NodeOutput: def execute(cls, operands: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(a + b) values = list(operands.values())
result = values[0]
for value in values[1:]:
result = result + value
return io.NodeOutput(result)
class IntSubtract(io.ComfyNode): class Subtract(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
template = io.MatchType.Template("subtract")
autogrow_template = io.Autogrow.TemplatePrefix(
io.MatchType.Input("operand", template=template),
prefix="operand",
min=2,
max=10
)
return io.Schema( return io.Schema(
node_id="IntSubtract", node_id="Subtract",
display_name="Int Subtract", display_name="Subtract",
category="math", category="math",
inputs=[ inputs=[
io.Int.Input("a"), io.Autogrow.Input("operands", template=autogrow_template)
io.Int.Input("b"), ],
outputs=[
io.MatchType.Output(template=template)
], ],
outputs=[io.Int.Output()],
) )
@classmethod @classmethod
def execute(cls, a: int, b: int) -> io.NodeOutput: def execute(cls, operands: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(a - b) values = list(operands.values())
result = values[0]
for value in values[1:]:
result = result - value
return io.NodeOutput(result)
class IntMultiply(io.ComfyNode): class Multiply(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
template = io.MatchType.Template("multiply")
autogrow_template = io.Autogrow.TemplatePrefix(
io.MatchType.Input("operand", template=template),
prefix="operand",
min=2,
max=10
)
return io.Schema( return io.Schema(
node_id="IntMultiply", node_id="Multiply",
display_name="Int Multiply", display_name="Multiply",
category="math", category="math",
inputs=[ inputs=[
io.Int.Input("a"), io.Autogrow.Input("operands", template=autogrow_template)
io.Int.Input("b"), ],
outputs=[
io.MatchType.Output(template=template)
], ],
outputs=[io.Int.Output()],
) )
@classmethod @classmethod
def execute(cls, a: int, b: int) -> io.NodeOutput: def execute(cls, operands: io.Autogrow.Type) -> io.NodeOutput:
return io.NodeOutput(a * b) values = list(operands.values())
result = values[0]
for value in values[1:]:
result = result * value
return io.NodeOutput(result)
class IntDivide(io.ComfyNode): class Divide(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
template = io.MatchType.Template("divide")
autogrow_template = io.Autogrow.TemplatePrefix(
io.MatchType.Input("operand", template=template),
prefix="operand",
min=2,
max=10
)
return io.Schema( return io.Schema(
node_id="IntDivide", node_id="Divide",
display_name="Int Divide", display_name="Divide",
category="math", category="math",
inputs=[ inputs=[
io.Int.Input("a"), io.Autogrow.Input("operands", template=autogrow_template)
io.Int.Input("b"), ],
outputs=[
io.MatchType.Output(template=template)
], ],
outputs=[io.Int.Output()],
) )
@classmethod @classmethod
def validate_inputs(cls, a: int, b: int) -> bool: def validate_inputs(cls, operands: io.Autogrow.Type) -> bool:
if b == 0: values = list(operands.values())
return "Division by zero is not allowed" # Check for division by zero in any divisor (all operands except the first)
for value in values[1:]:
if value == 0:
return "Division by zero is not allowed"
return True return True
@classmethod @classmethod
def execute(cls, a: int, b: int) -> io.NodeOutput: def execute(cls, operands: io.Autogrow.Type) -> io.NodeOutput:
if b == 0: values = list(operands.values())
raise ValueError("Division by zero is not allowed") result = values[0]
# Integer division (floor division) for value in values[1:]:
return io.NodeOutput(a // b) if value == 0:
raise ValueError("Division by zero is not allowed")
result = result / value
return io.NodeOutput(result)
class MathExtension(ComfyExtension): class MathExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [ return [
IntAdd, Add,
IntSubtract, Subtract,
IntMultiply, Multiply,
IntDivide, Divide,
] ]

View File

@ -1,184 +1,243 @@
import pytest import pytest
from comfy_extras.nodes_math import IntAdd, IntSubtract, IntMultiply, IntDivide from comfy_extras.nodes_math import Add, Subtract, Multiply, Divide
from comfy_api.latest import io from comfy_api.latest import io
class TestIntAdd: class TestAdd:
"""Test suite for IntAdd node.""" """Test suite for Add node."""
def test_basic_addition(self): def test_basic_addition_integers(self):
"""Test basic addition of two positive integers.""" """Test basic addition of two integers."""
result = IntAdd.execute(5, 3) operands = {"operand0": 5, "operand1": 3}
result = Add.execute(operands)
assert result.value == 8 assert result.value == 8
assert isinstance(result, io.NodeOutput) assert isinstance(result, io.NodeOutput)
def test_addition_multiple_integers(self):
"""Test addition with multiple integers."""
operands = {"operand0": 5, "operand1": 3, "operand2": 2}
result = Add.execute(operands)
assert result.value == 10
def test_addition_with_zero(self): def test_addition_with_zero(self):
"""Test addition with zero.""" """Test addition with zero."""
result = IntAdd.execute(10, 0) operands = {"operand0": 10, "operand1": 0}
result = Add.execute(operands)
assert result.value == 10 assert result.value == 10
def test_addition_with_negative(self): def test_addition_with_negative(self):
"""Test addition with negative numbers.""" """Test addition with negative numbers."""
result = IntAdd.execute(5, -3) operands = {"operand0": 5, "operand1": -3}
result = Add.execute(operands)
assert result.value == 2 assert result.value == 2
def test_addition_both_negative(self): def test_addition_floats(self):
"""Test addition of two negative numbers.""" """Test addition with floats."""
result = IntAdd.execute(-5, -3) operands = {"operand0": 5.5, "operand1": 3.2}
assert result.value == -8 result = Add.execute(operands)
assert abs(result.value - 8.7) < 1e-10
def test_addition_large_numbers(self): def test_addition_strings(self):
"""Test addition with large numbers.""" """Test addition with strings (concatenation)."""
result = IntAdd.execute(1000000, 2000000) operands = {"operand0": "Hello", "operand1": " World"}
assert result.value == 3000000 result = Add.execute(operands)
assert result.value == "Hello World"
def test_addition_lists(self):
"""Test addition with lists (like conditionings)."""
operands = {"operand0": [1, 2, 3], "operand1": [4, 5]}
result = Add.execute(operands)
assert result.value == [1, 2, 3, 4, 5]
class TestIntSubtract: class TestSubtract:
"""Test suite for IntSubtract node.""" """Test suite for Subtract node."""
def test_basic_subtraction(self): def test_basic_subtraction_integers(self):
"""Test basic subtraction of two positive integers.""" """Test basic subtraction of two integers."""
result = IntSubtract.execute(10, 3) operands = {"operand0": 10, "operand1": 3}
result = Subtract.execute(operands)
assert result.value == 7 assert result.value == 7
assert isinstance(result, io.NodeOutput) assert isinstance(result, io.NodeOutput)
def test_subtraction_multiple_integers(self):
"""Test subtraction with multiple integers."""
operands = {"operand0": 20, "operand1": 5, "operand2": 3}
result = Subtract.execute(operands)
assert result.value == 12
def test_subtraction_with_zero(self): def test_subtraction_with_zero(self):
"""Test subtraction with zero.""" """Test subtraction with zero."""
result = IntSubtract.execute(10, 0) operands = {"operand0": 10, "operand1": 0}
result = Subtract.execute(operands)
assert result.value == 10 assert result.value == 10
def test_subtraction_resulting_negative(self): def test_subtraction_resulting_negative(self):
"""Test subtraction that results in negative number.""" """Test subtraction that results in negative number."""
result = IntSubtract.execute(3, 10) operands = {"operand0": 3, "operand1": 10}
result = Subtract.execute(operands)
assert result.value == -7 assert result.value == -7
def test_subtraction_with_negative(self): def test_subtraction_with_negative(self):
"""Test subtraction with negative numbers.""" """Test subtraction with negative numbers."""
result = IntSubtract.execute(5, -3) operands = {"operand0": 5, "operand1": -3}
result = Subtract.execute(operands)
assert result.value == 8 assert result.value == 8
def test_subtraction_both_negative(self): def test_subtraction_floats(self):
"""Test subtraction of two negative numbers.""" """Test subtraction with floats."""
result = IntSubtract.execute(-5, -3) operands = {"operand0": 10.5, "operand1": 3.2}
assert result.value == -2 result = Subtract.execute(operands)
assert abs(result.value - 7.3) < 1e-10
class TestIntMultiply: class TestMultiply:
"""Test suite for IntMultiply node.""" """Test suite for Multiply node."""
def test_basic_multiplication(self): def test_basic_multiplication_integers(self):
"""Test basic multiplication of two positive integers.""" """Test basic multiplication of two integers."""
result = IntMultiply.execute(5, 3) operands = {"operand0": 5, "operand1": 3}
result = Multiply.execute(operands)
assert result.value == 15 assert result.value == 15
assert isinstance(result, io.NodeOutput) assert isinstance(result, io.NodeOutput)
def test_multiplication_multiple_integers(self):
"""Test multiplication with multiple integers."""
operands = {"operand0": 2, "operand1": 3, "operand2": 4}
result = Multiply.execute(operands)
assert result.value == 24
def test_multiplication_with_zero(self): def test_multiplication_with_zero(self):
"""Test multiplication with zero.""" """Test multiplication with zero."""
result = IntMultiply.execute(10, 0) operands = {"operand0": 10, "operand1": 0}
result = Multiply.execute(operands)
assert result.value == 0 assert result.value == 0
def test_multiplication_with_one(self): def test_multiplication_with_one(self):
"""Test multiplication with one.""" """Test multiplication with one."""
result = IntMultiply.execute(10, 1) operands = {"operand0": 10, "operand1": 1}
result = Multiply.execute(operands)
assert result.value == 10 assert result.value == 10
def test_multiplication_with_negative(self): def test_multiplication_with_negative(self):
"""Test multiplication with negative numbers.""" """Test multiplication with negative numbers."""
result = IntMultiply.execute(5, -3) operands = {"operand0": 5, "operand1": -3}
result = Multiply.execute(operands)
assert result.value == -15 assert result.value == -15
def test_multiplication_both_negative(self): def test_multiplication_floats(self):
"""Test multiplication of two negative numbers.""" """Test multiplication with floats."""
result = IntMultiply.execute(-5, -3) operands = {"operand0": 2.5, "operand1": 4.0}
assert result.value == 15 result = Multiply.execute(operands)
assert abs(result.value - 10.0) < 1e-10
def test_multiplication_large_numbers(self): def test_multiplication_string_repetition(self):
"""Test multiplication with large numbers.""" """Test multiplication with string (repetition)."""
result = IntMultiply.execute(1000, 2000) operands = {"operand0": "Hello", "operand1": 3}
assert result.value == 2000000 result = Multiply.execute(operands)
assert result.value == "HelloHelloHello"
class TestIntDivide: class TestDivide:
"""Test suite for IntDivide node.""" """Test suite for Divide node."""
def test_basic_division(self): def test_basic_division_integers(self):
"""Test basic division of two positive integers.""" """Test basic division of two integers."""
result = IntDivide.execute(10, 2) operands = {"operand0": 10, "operand1": 2}
assert result.value == 5 result = Divide.execute(operands)
assert result.value == 5.0 # Division returns float
assert isinstance(result, io.NodeOutput) assert isinstance(result, io.NodeOutput)
def test_division_multiple_integers(self):
"""Test division with multiple integers."""
operands = {"operand0": 100, "operand1": 2, "operand2": 5}
result = Divide.execute(operands)
assert result.value == 10.0
def test_division_with_remainder(self): def test_division_with_remainder(self):
"""Test division that results in floor division.""" """Test division that results in float."""
result = IntDivide.execute(10, 3) operands = {"operand0": 10, "operand1": 3}
assert result.value == 3 # Floor division: 10 // 3 = 3 result = Divide.execute(operands)
assert abs(result.value - 3.3333333333333335) < 1e-10
def test_division_with_negative(self): def test_division_with_negative(self):
"""Test division with negative numbers.""" """Test division with negative numbers."""
result = IntDivide.execute(10, -3) operands = {"operand0": 10, "operand1": -2}
assert result.value == -4 # Floor division: 10 // -3 = -4 result = Divide.execute(operands)
assert result.value == -5.0
def test_division_both_negative(self):
"""Test division of two negative numbers."""
result = IntDivide.execute(-10, -3)
assert result.value == 3 # Floor division: -10 // -3 = 3
def test_division_by_one(self): def test_division_by_one(self):
"""Test division by one.""" """Test division by one."""
result = IntDivide.execute(10, 1) operands = {"operand0": 10, "operand1": 1}
assert result.value == 10 result = Divide.execute(operands)
assert result.value == 10.0
def test_division_floats(self):
"""Test division with floats."""
operands = {"operand0": 10.5, "operand1": 2.5}
result = Divide.execute(operands)
assert abs(result.value - 4.2) < 1e-10
def test_validate_inputs_division_by_zero(self): def test_validate_inputs_division_by_zero(self):
"""Test that division by zero is caught in validation.""" """Test that division by zero is caught in validation."""
validation_result = IntDivide.validate_inputs(10, 0) operands = {"operand0": 10, "operand1": 0}
validation_result = Divide.validate_inputs(operands)
assert validation_result == "Division by zero is not allowed"
def test_validate_inputs_division_by_zero_multiple(self):
"""Test that division by zero in later operands is caught."""
operands = {"operand0": 10, "operand1": 2, "operand2": 0}
validation_result = Divide.validate_inputs(operands)
assert validation_result == "Division by zero is not allowed" assert validation_result == "Division by zero is not allowed"
def test_validate_inputs_normal_division(self): def test_validate_inputs_normal_division(self):
"""Test that normal division passes validation.""" """Test that normal division passes validation."""
validation_result = IntDivide.validate_inputs(10, 2) operands = {"operand0": 10, "operand1": 2}
validation_result = Divide.validate_inputs(operands)
assert validation_result is True assert validation_result is True
def test_execute_division_by_zero_raises_error(self): def test_execute_division_by_zero_raises_error(self):
"""Test that division by zero raises ValueError in execute.""" """Test that division by zero raises ValueError in execute."""
operands = {"operand0": 10, "operand1": 0}
with pytest.raises(ValueError, match="Division by zero is not allowed"): with pytest.raises(ValueError, match="Division by zero is not allowed"):
IntDivide.execute(10, 0) Divide.execute(operands)
def test_division_large_numbers(self): def test_execute_division_by_zero_multiple_raises_error(self):
"""Test division with large numbers.""" """Test that division by zero in later operands raises error."""
result = IntDivide.execute(2000000, 1000) operands = {"operand0": 100, "operand1": 2, "operand2": 0}
assert result.value == 2000 with pytest.raises(ValueError, match="Division by zero is not allowed"):
Divide.execute(operands)
class TestMathNodesSchema: class TestMathNodesSchema:
"""Test suite for node schemas.""" """Test suite for node schemas."""
def test_int_add_schema(self): def test_add_schema(self):
"""Test IntAdd schema definition.""" """Test Add schema definition."""
schema = IntAdd.define_schema() schema = Add.define_schema()
assert schema.node_id == "IntAdd" assert schema.node_id == "Add"
assert schema.display_name == "Int Add" assert schema.display_name == "Add"
assert schema.category == "math"
assert len(schema.inputs) == 2
assert len(schema.outputs) == 1
def test_int_subtract_schema(self):
"""Test IntSubtract schema definition."""
schema = IntSubtract.define_schema()
assert schema.node_id == "IntSubtract"
assert schema.display_name == "Int Subtract"
assert schema.category == "math" assert schema.category == "math"
def test_int_multiply_schema(self): def test_subtract_schema(self):
"""Test IntMultiply schema definition.""" """Test Subtract schema definition."""
schema = IntMultiply.define_schema() schema = Subtract.define_schema()
assert schema.node_id == "IntMultiply" assert schema.node_id == "Subtract"
assert schema.display_name == "Int Multiply" assert schema.display_name == "Subtract"
assert schema.category == "math" assert schema.category == "math"
def test_int_divide_schema(self): def test_multiply_schema(self):
"""Test IntDivide schema definition.""" """Test Multiply schema definition."""
schema = IntDivide.define_schema() schema = Multiply.define_schema()
assert schema.node_id == "IntDivide" assert schema.node_id == "Multiply"
assert schema.display_name == "Int Divide" assert schema.display_name == "Multiply"
assert schema.category == "math"
def test_divide_schema(self):
"""Test Divide schema definition."""
schema = Divide.define_schema()
assert schema.node_id == "Divide"
assert schema.display_name == "Divide"
assert schema.category == "math" assert schema.category == "math"
@ -194,7 +253,7 @@ class TestMathExtension:
node_list = await extension.get_node_list() node_list = await extension.get_node_list()
assert len(node_list) == 4 assert len(node_list) == 4
assert IntAdd in node_list assert Add in node_list
assert IntSubtract in node_list assert Subtract in node_list
assert IntMultiply in node_list assert Multiply in node_list
assert IntDivide in node_list assert Divide in node_list