From 27f600cd8bb06bd7a568399d78cbdcf2c7230b32 Mon Sep 17 00:00:00 2001 From: jackdevis Date: Wed, 21 Jan 2026 03:57:38 -0800 Subject: [PATCH] Refactor math nodes to use MatchType and Autogrow for generic operations --- comfy_extras/nodes_math.py | 134 +++++++--- .../comfy_extras_test/nodes_math_test.py | 253 +++++++++++------- 2 files changed, 248 insertions(+), 139 deletions(-) diff --git a/comfy_extras/nodes_math.py b/comfy_extras/nodes_math.py index b5e5a94bd..6b84498b7 100644 --- a/comfy_extras/nodes_math.py +++ b/comfy_extras/nodes_math.py @@ -3,99 +3,149 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, io -class IntAdd(io.ComfyNode): +class Add(io.ComfyNode): @classmethod def define_schema(cls): + template = io.MatchType.Template("add") + autogrow_template = io.Autogrow.TemplatePrefix( + io.MatchType.Input("operand", template=template), + prefix="operand", + min=2, + max=10 + ) return io.Schema( - node_id="IntAdd", - display_name="Int Add", + node_id="Add", + display_name="Add", category="math", inputs=[ - io.Int.Input("a"), - io.Int.Input("b"), + io.Autogrow.Input("operands", template=autogrow_template) + ], + outputs=[ + io.MatchType.Output(template=template) ], - outputs=[io.Int.Output()], ) @classmethod - def execute(cls, a: int, b: int) -> io.NodeOutput: - return io.NodeOutput(a + b) + def execute(cls, operands: io.Autogrow.Type) -> io.NodeOutput: + values = list(operands.values()) + result = values[0] + for value in values[1:]: + result = result + value + return io.NodeOutput(result) -class IntSubtract(io.ComfyNode): +class Subtract(io.ComfyNode): @classmethod def define_schema(cls): + template = io.MatchType.Template("subtract") + autogrow_template = io.Autogrow.TemplatePrefix( + io.MatchType.Input("operand", template=template), + prefix="operand", + min=2, + max=10 + ) return io.Schema( - node_id="IntSubtract", - display_name="Int Subtract", + node_id="Subtract", + display_name="Subtract", category="math", inputs=[ - io.Int.Input("a"), - io.Int.Input("b"), + io.Autogrow.Input("operands", template=autogrow_template) + ], + outputs=[ + io.MatchType.Output(template=template) ], - outputs=[io.Int.Output()], ) @classmethod - def execute(cls, a: int, b: int) -> io.NodeOutput: - return io.NodeOutput(a - b) + def execute(cls, operands: io.Autogrow.Type) -> io.NodeOutput: + values = list(operands.values()) + result = values[0] + for value in values[1:]: + result = result - value + return io.NodeOutput(result) -class IntMultiply(io.ComfyNode): +class Multiply(io.ComfyNode): @classmethod def define_schema(cls): + template = io.MatchType.Template("multiply") + autogrow_template = io.Autogrow.TemplatePrefix( + io.MatchType.Input("operand", template=template), + prefix="operand", + min=2, + max=10 + ) return io.Schema( - node_id="IntMultiply", - display_name="Int Multiply", + node_id="Multiply", + display_name="Multiply", category="math", inputs=[ - io.Int.Input("a"), - io.Int.Input("b"), + io.Autogrow.Input("operands", template=autogrow_template) + ], + outputs=[ + io.MatchType.Output(template=template) ], - outputs=[io.Int.Output()], ) @classmethod - def execute(cls, a: int, b: int) -> io.NodeOutput: - return io.NodeOutput(a * b) + def execute(cls, operands: io.Autogrow.Type) -> io.NodeOutput: + values = list(operands.values()) + result = values[0] + for value in values[1:]: + result = result * value + return io.NodeOutput(result) -class IntDivide(io.ComfyNode): +class Divide(io.ComfyNode): @classmethod def define_schema(cls): + template = io.MatchType.Template("divide") + autogrow_template = io.Autogrow.TemplatePrefix( + io.MatchType.Input("operand", template=template), + prefix="operand", + min=2, + max=10 + ) return io.Schema( - node_id="IntDivide", - display_name="Int Divide", + node_id="Divide", + display_name="Divide", category="math", inputs=[ - io.Int.Input("a"), - io.Int.Input("b"), + io.Autogrow.Input("operands", template=autogrow_template) + ], + outputs=[ + io.MatchType.Output(template=template) ], - outputs=[io.Int.Output()], ) @classmethod - def validate_inputs(cls, a: int, b: int) -> bool: - if b == 0: - return "Division by zero is not allowed" + def validate_inputs(cls, operands: io.Autogrow.Type) -> bool: + values = list(operands.values()) + # Check for division by zero in any divisor (all operands except the first) + for value in values[1:]: + if value == 0: + return "Division by zero is not allowed" return True @classmethod - def execute(cls, a: int, b: int) -> io.NodeOutput: - if b == 0: - raise ValueError("Division by zero is not allowed") - # Integer division (floor division) - return io.NodeOutput(a // b) + def execute(cls, operands: io.Autogrow.Type) -> io.NodeOutput: + values = list(operands.values()) + result = values[0] + for value in values[1:]: + if value == 0: + raise ValueError("Division by zero is not allowed") + result = result / value + return io.NodeOutput(result) class MathExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - IntAdd, - IntSubtract, - IntMultiply, - IntDivide, + Add, + Subtract, + Multiply, + Divide, ] diff --git a/tests-unit/comfy_extras_test/nodes_math_test.py b/tests-unit/comfy_extras_test/nodes_math_test.py index bcc92c1f5..79698afa9 100644 --- a/tests-unit/comfy_extras_test/nodes_math_test.py +++ b/tests-unit/comfy_extras_test/nodes_math_test.py @@ -1,184 +1,243 @@ import pytest -from comfy_extras.nodes_math import IntAdd, IntSubtract, IntMultiply, IntDivide +from comfy_extras.nodes_math import Add, Subtract, Multiply, Divide from comfy_api.latest import io -class TestIntAdd: - """Test suite for IntAdd node.""" +class TestAdd: + """Test suite for Add node.""" - def test_basic_addition(self): - """Test basic addition of two positive integers.""" - result = IntAdd.execute(5, 3) + def test_basic_addition_integers(self): + """Test basic addition of two integers.""" + operands = {"operand0": 5, "operand1": 3} + result = Add.execute(operands) assert result.value == 8 assert isinstance(result, io.NodeOutput) + def test_addition_multiple_integers(self): + """Test addition with multiple integers.""" + operands = {"operand0": 5, "operand1": 3, "operand2": 2} + result = Add.execute(operands) + assert result.value == 10 + def test_addition_with_zero(self): """Test addition with zero.""" - result = IntAdd.execute(10, 0) + operands = {"operand0": 10, "operand1": 0} + result = Add.execute(operands) assert result.value == 10 def test_addition_with_negative(self): """Test addition with negative numbers.""" - result = IntAdd.execute(5, -3) + operands = {"operand0": 5, "operand1": -3} + result = Add.execute(operands) assert result.value == 2 - def test_addition_both_negative(self): - """Test addition of two negative numbers.""" - result = IntAdd.execute(-5, -3) - assert result.value == -8 + def test_addition_floats(self): + """Test addition with floats.""" + operands = {"operand0": 5.5, "operand1": 3.2} + result = Add.execute(operands) + assert abs(result.value - 8.7) < 1e-10 - def test_addition_large_numbers(self): - """Test addition with large numbers.""" - result = IntAdd.execute(1000000, 2000000) - assert result.value == 3000000 + def test_addition_strings(self): + """Test addition with strings (concatenation).""" + operands = {"operand0": "Hello", "operand1": " World"} + result = Add.execute(operands) + assert result.value == "Hello World" + + def test_addition_lists(self): + """Test addition with lists (like conditionings).""" + operands = {"operand0": [1, 2, 3], "operand1": [4, 5]} + result = Add.execute(operands) + assert result.value == [1, 2, 3, 4, 5] -class TestIntSubtract: - """Test suite for IntSubtract node.""" +class TestSubtract: + """Test suite for Subtract node.""" - def test_basic_subtraction(self): - """Test basic subtraction of two positive integers.""" - result = IntSubtract.execute(10, 3) + def test_basic_subtraction_integers(self): + """Test basic subtraction of two integers.""" + operands = {"operand0": 10, "operand1": 3} + result = Subtract.execute(operands) assert result.value == 7 assert isinstance(result, io.NodeOutput) + def test_subtraction_multiple_integers(self): + """Test subtraction with multiple integers.""" + operands = {"operand0": 20, "operand1": 5, "operand2": 3} + result = Subtract.execute(operands) + assert result.value == 12 + def test_subtraction_with_zero(self): """Test subtraction with zero.""" - result = IntSubtract.execute(10, 0) + operands = {"operand0": 10, "operand1": 0} + result = Subtract.execute(operands) assert result.value == 10 def test_subtraction_resulting_negative(self): """Test subtraction that results in negative number.""" - result = IntSubtract.execute(3, 10) + operands = {"operand0": 3, "operand1": 10} + result = Subtract.execute(operands) assert result.value == -7 def test_subtraction_with_negative(self): """Test subtraction with negative numbers.""" - result = IntSubtract.execute(5, -3) + operands = {"operand0": 5, "operand1": -3} + result = Subtract.execute(operands) assert result.value == 8 - def test_subtraction_both_negative(self): - """Test subtraction of two negative numbers.""" - result = IntSubtract.execute(-5, -3) - assert result.value == -2 + def test_subtraction_floats(self): + """Test subtraction with floats.""" + operands = {"operand0": 10.5, "operand1": 3.2} + result = Subtract.execute(operands) + assert abs(result.value - 7.3) < 1e-10 -class TestIntMultiply: - """Test suite for IntMultiply node.""" +class TestMultiply: + """Test suite for Multiply node.""" - def test_basic_multiplication(self): - """Test basic multiplication of two positive integers.""" - result = IntMultiply.execute(5, 3) + def test_basic_multiplication_integers(self): + """Test basic multiplication of two integers.""" + operands = {"operand0": 5, "operand1": 3} + result = Multiply.execute(operands) assert result.value == 15 assert isinstance(result, io.NodeOutput) + def test_multiplication_multiple_integers(self): + """Test multiplication with multiple integers.""" + operands = {"operand0": 2, "operand1": 3, "operand2": 4} + result = Multiply.execute(operands) + assert result.value == 24 + def test_multiplication_with_zero(self): """Test multiplication with zero.""" - result = IntMultiply.execute(10, 0) + operands = {"operand0": 10, "operand1": 0} + result = Multiply.execute(operands) assert result.value == 0 def test_multiplication_with_one(self): """Test multiplication with one.""" - result = IntMultiply.execute(10, 1) + operands = {"operand0": 10, "operand1": 1} + result = Multiply.execute(operands) assert result.value == 10 def test_multiplication_with_negative(self): """Test multiplication with negative numbers.""" - result = IntMultiply.execute(5, -3) + operands = {"operand0": 5, "operand1": -3} + result = Multiply.execute(operands) assert result.value == -15 - def test_multiplication_both_negative(self): - """Test multiplication of two negative numbers.""" - result = IntMultiply.execute(-5, -3) - assert result.value == 15 + def test_multiplication_floats(self): + """Test multiplication with floats.""" + operands = {"operand0": 2.5, "operand1": 4.0} + result = Multiply.execute(operands) + assert abs(result.value - 10.0) < 1e-10 - def test_multiplication_large_numbers(self): - """Test multiplication with large numbers.""" - result = IntMultiply.execute(1000, 2000) - assert result.value == 2000000 + def test_multiplication_string_repetition(self): + """Test multiplication with string (repetition).""" + operands = {"operand0": "Hello", "operand1": 3} + result = Multiply.execute(operands) + assert result.value == "HelloHelloHello" -class TestIntDivide: - """Test suite for IntDivide node.""" +class TestDivide: + """Test suite for Divide node.""" - def test_basic_division(self): - """Test basic division of two positive integers.""" - result = IntDivide.execute(10, 2) - assert result.value == 5 + def test_basic_division_integers(self): + """Test basic division of two integers.""" + operands = {"operand0": 10, "operand1": 2} + result = Divide.execute(operands) + assert result.value == 5.0 # Division returns float assert isinstance(result, io.NodeOutput) + def test_division_multiple_integers(self): + """Test division with multiple integers.""" + operands = {"operand0": 100, "operand1": 2, "operand2": 5} + result = Divide.execute(operands) + assert result.value == 10.0 + def test_division_with_remainder(self): - """Test division that results in floor division.""" - result = IntDivide.execute(10, 3) - assert result.value == 3 # Floor division: 10 // 3 = 3 + """Test division that results in float.""" + operands = {"operand0": 10, "operand1": 3} + result = Divide.execute(operands) + assert abs(result.value - 3.3333333333333335) < 1e-10 def test_division_with_negative(self): """Test division with negative numbers.""" - result = IntDivide.execute(10, -3) - assert result.value == -4 # Floor division: 10 // -3 = -4 - - def test_division_both_negative(self): - """Test division of two negative numbers.""" - result = IntDivide.execute(-10, -3) - assert result.value == 3 # Floor division: -10 // -3 = 3 + operands = {"operand0": 10, "operand1": -2} + result = Divide.execute(operands) + assert result.value == -5.0 def test_division_by_one(self): """Test division by one.""" - result = IntDivide.execute(10, 1) - assert result.value == 10 + operands = {"operand0": 10, "operand1": 1} + result = Divide.execute(operands) + assert result.value == 10.0 + + def test_division_floats(self): + """Test division with floats.""" + operands = {"operand0": 10.5, "operand1": 2.5} + result = Divide.execute(operands) + assert abs(result.value - 4.2) < 1e-10 def test_validate_inputs_division_by_zero(self): """Test that division by zero is caught in validation.""" - validation_result = IntDivide.validate_inputs(10, 0) + operands = {"operand0": 10, "operand1": 0} + validation_result = Divide.validate_inputs(operands) + assert validation_result == "Division by zero is not allowed" + + def test_validate_inputs_division_by_zero_multiple(self): + """Test that division by zero in later operands is caught.""" + operands = {"operand0": 10, "operand1": 2, "operand2": 0} + validation_result = Divide.validate_inputs(operands) assert validation_result == "Division by zero is not allowed" def test_validate_inputs_normal_division(self): """Test that normal division passes validation.""" - validation_result = IntDivide.validate_inputs(10, 2) + operands = {"operand0": 10, "operand1": 2} + validation_result = Divide.validate_inputs(operands) assert validation_result is True def test_execute_division_by_zero_raises_error(self): """Test that division by zero raises ValueError in execute.""" + operands = {"operand0": 10, "operand1": 0} with pytest.raises(ValueError, match="Division by zero is not allowed"): - IntDivide.execute(10, 0) + Divide.execute(operands) - def test_division_large_numbers(self): - """Test division with large numbers.""" - result = IntDivide.execute(2000000, 1000) - assert result.value == 2000 + def test_execute_division_by_zero_multiple_raises_error(self): + """Test that division by zero in later operands raises error.""" + operands = {"operand0": 100, "operand1": 2, "operand2": 0} + with pytest.raises(ValueError, match="Division by zero is not allowed"): + Divide.execute(operands) class TestMathNodesSchema: """Test suite for node schemas.""" - def test_int_add_schema(self): - """Test IntAdd schema definition.""" - schema = IntAdd.define_schema() - assert schema.node_id == "IntAdd" - assert schema.display_name == "Int Add" - assert schema.category == "math" - assert len(schema.inputs) == 2 - assert len(schema.outputs) == 1 - - def test_int_subtract_schema(self): - """Test IntSubtract schema definition.""" - schema = IntSubtract.define_schema() - assert schema.node_id == "IntSubtract" - assert schema.display_name == "Int Subtract" + def test_add_schema(self): + """Test Add schema definition.""" + schema = Add.define_schema() + assert schema.node_id == "Add" + assert schema.display_name == "Add" assert schema.category == "math" - def test_int_multiply_schema(self): - """Test IntMultiply schema definition.""" - schema = IntMultiply.define_schema() - assert schema.node_id == "IntMultiply" - assert schema.display_name == "Int Multiply" + def test_subtract_schema(self): + """Test Subtract schema definition.""" + schema = Subtract.define_schema() + assert schema.node_id == "Subtract" + assert schema.display_name == "Subtract" assert schema.category == "math" - def test_int_divide_schema(self): - """Test IntDivide schema definition.""" - schema = IntDivide.define_schema() - assert schema.node_id == "IntDivide" - assert schema.display_name == "Int Divide" + def test_multiply_schema(self): + """Test Multiply schema definition.""" + schema = Multiply.define_schema() + assert schema.node_id == "Multiply" + assert schema.display_name == "Multiply" + assert schema.category == "math" + + def test_divide_schema(self): + """Test Divide schema definition.""" + schema = Divide.define_schema() + assert schema.node_id == "Divide" + assert schema.display_name == "Divide" assert schema.category == "math" @@ -194,7 +253,7 @@ class TestMathExtension: node_list = await extension.get_node_list() assert len(node_list) == 4 - assert IntAdd in node_list - assert IntSubtract in node_list - assert IntMultiply in node_list - assert IntDivide in node_list + assert Add in node_list + assert Subtract in node_list + assert Multiply in node_list + assert Divide in node_list