mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-01 01:00:20 +08:00
Add integer math nodes (Add, Subtract, Multiply, Divide) - Fixes #11995
This commit is contained in:
parent
0fc15700be
commit
c6f0f31d44
103
comfy_extras/nodes_math.py
Normal file
103
comfy_extras/nodes_math.py
Normal file
@ -0,0 +1,103 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class IntAdd(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="IntAdd",
|
||||
display_name="Int Add",
|
||||
category="math",
|
||||
inputs=[
|
||||
io.Int.Input("a"),
|
||||
io.Int.Input("b"),
|
||||
],
|
||||
outputs=[io.Int.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, a: int, b: int) -> io.NodeOutput:
|
||||
return io.NodeOutput(a + b)
|
||||
|
||||
|
||||
class IntSubtract(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="IntSubtract",
|
||||
display_name="Int Subtract",
|
||||
category="math",
|
||||
inputs=[
|
||||
io.Int.Input("a"),
|
||||
io.Int.Input("b"),
|
||||
],
|
||||
outputs=[io.Int.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, a: int, b: int) -> io.NodeOutput:
|
||||
return io.NodeOutput(a - b)
|
||||
|
||||
|
||||
class IntMultiply(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="IntMultiply",
|
||||
display_name="Int Multiply",
|
||||
category="math",
|
||||
inputs=[
|
||||
io.Int.Input("a"),
|
||||
io.Int.Input("b"),
|
||||
],
|
||||
outputs=[io.Int.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, a: int, b: int) -> io.NodeOutput:
|
||||
return io.NodeOutput(a * b)
|
||||
|
||||
|
||||
class IntDivide(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="IntDivide",
|
||||
display_name="Int Divide",
|
||||
category="math",
|
||||
inputs=[
|
||||
io.Int.Input("a"),
|
||||
io.Int.Input("b"),
|
||||
],
|
||||
outputs=[io.Int.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, a: int, b: int) -> bool:
|
||||
if b == 0:
|
||||
return "Division by zero is not allowed"
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def execute(cls, a: int, b: int) -> io.NodeOutput:
|
||||
if b == 0:
|
||||
raise ValueError("Division by zero is not allowed")
|
||||
# Integer division (floor division)
|
||||
return io.NodeOutput(a // b)
|
||||
|
||||
|
||||
class MathExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
IntAdd,
|
||||
IntSubtract,
|
||||
IntMultiply,
|
||||
IntDivide,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MathExtension:
|
||||
return MathExtension()
|
||||
1
nodes.py
1
nodes.py
@ -2369,6 +2369,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_audio_encoder.py",
|
||||
"nodes_rope.py",
|
||||
"nodes_logic.py",
|
||||
"nodes_math.py",
|
||||
"nodes_nop.py",
|
||||
"nodes_kandinsky5.py",
|
||||
"nodes_wanmove.py",
|
||||
|
||||
200
tests-unit/comfy_extras_test/nodes_math_test.py
Normal file
200
tests-unit/comfy_extras_test/nodes_math_test.py
Normal file
@ -0,0 +1,200 @@
|
||||
import pytest
|
||||
from comfy_extras.nodes_math import IntAdd, IntSubtract, IntMultiply, IntDivide
|
||||
from comfy_api.latest import io
|
||||
|
||||
|
||||
class TestIntAdd:
|
||||
"""Test suite for IntAdd node."""
|
||||
|
||||
def test_basic_addition(self):
|
||||
"""Test basic addition of two positive integers."""
|
||||
result = IntAdd.execute(5, 3)
|
||||
assert result.value == 8
|
||||
assert isinstance(result, io.NodeOutput)
|
||||
|
||||
def test_addition_with_zero(self):
|
||||
"""Test addition with zero."""
|
||||
result = IntAdd.execute(10, 0)
|
||||
assert result.value == 10
|
||||
|
||||
def test_addition_with_negative(self):
|
||||
"""Test addition with negative numbers."""
|
||||
result = IntAdd.execute(5, -3)
|
||||
assert result.value == 2
|
||||
|
||||
def test_addition_both_negative(self):
|
||||
"""Test addition of two negative numbers."""
|
||||
result = IntAdd.execute(-5, -3)
|
||||
assert result.value == -8
|
||||
|
||||
def test_addition_large_numbers(self):
|
||||
"""Test addition with large numbers."""
|
||||
result = IntAdd.execute(1000000, 2000000)
|
||||
assert result.value == 3000000
|
||||
|
||||
|
||||
class TestIntSubtract:
|
||||
"""Test suite for IntSubtract node."""
|
||||
|
||||
def test_basic_subtraction(self):
|
||||
"""Test basic subtraction of two positive integers."""
|
||||
result = IntSubtract.execute(10, 3)
|
||||
assert result.value == 7
|
||||
assert isinstance(result, io.NodeOutput)
|
||||
|
||||
def test_subtraction_with_zero(self):
|
||||
"""Test subtraction with zero."""
|
||||
result = IntSubtract.execute(10, 0)
|
||||
assert result.value == 10
|
||||
|
||||
def test_subtraction_resulting_negative(self):
|
||||
"""Test subtraction that results in negative number."""
|
||||
result = IntSubtract.execute(3, 10)
|
||||
assert result.value == -7
|
||||
|
||||
def test_subtraction_with_negative(self):
|
||||
"""Test subtraction with negative numbers."""
|
||||
result = IntSubtract.execute(5, -3)
|
||||
assert result.value == 8
|
||||
|
||||
def test_subtraction_both_negative(self):
|
||||
"""Test subtraction of two negative numbers."""
|
||||
result = IntSubtract.execute(-5, -3)
|
||||
assert result.value == -2
|
||||
|
||||
|
||||
class TestIntMultiply:
|
||||
"""Test suite for IntMultiply node."""
|
||||
|
||||
def test_basic_multiplication(self):
|
||||
"""Test basic multiplication of two positive integers."""
|
||||
result = IntMultiply.execute(5, 3)
|
||||
assert result.value == 15
|
||||
assert isinstance(result, io.NodeOutput)
|
||||
|
||||
def test_multiplication_with_zero(self):
|
||||
"""Test multiplication with zero."""
|
||||
result = IntMultiply.execute(10, 0)
|
||||
assert result.value == 0
|
||||
|
||||
def test_multiplication_with_one(self):
|
||||
"""Test multiplication with one."""
|
||||
result = IntMultiply.execute(10, 1)
|
||||
assert result.value == 10
|
||||
|
||||
def test_multiplication_with_negative(self):
|
||||
"""Test multiplication with negative numbers."""
|
||||
result = IntMultiply.execute(5, -3)
|
||||
assert result.value == -15
|
||||
|
||||
def test_multiplication_both_negative(self):
|
||||
"""Test multiplication of two negative numbers."""
|
||||
result = IntMultiply.execute(-5, -3)
|
||||
assert result.value == 15
|
||||
|
||||
def test_multiplication_large_numbers(self):
|
||||
"""Test multiplication with large numbers."""
|
||||
result = IntMultiply.execute(1000, 2000)
|
||||
assert result.value == 2000000
|
||||
|
||||
|
||||
class TestIntDivide:
|
||||
"""Test suite for IntDivide node."""
|
||||
|
||||
def test_basic_division(self):
|
||||
"""Test basic division of two positive integers."""
|
||||
result = IntDivide.execute(10, 2)
|
||||
assert result.value == 5
|
||||
assert isinstance(result, io.NodeOutput)
|
||||
|
||||
def test_division_with_remainder(self):
|
||||
"""Test division that results in floor division."""
|
||||
result = IntDivide.execute(10, 3)
|
||||
assert result.value == 3 # Floor division: 10 // 3 = 3
|
||||
|
||||
def test_division_with_negative(self):
|
||||
"""Test division with negative numbers."""
|
||||
result = IntDivide.execute(10, -3)
|
||||
assert result.value == -4 # Floor division: 10 // -3 = -4
|
||||
|
||||
def test_division_both_negative(self):
|
||||
"""Test division of two negative numbers."""
|
||||
result = IntDivide.execute(-10, -3)
|
||||
assert result.value == 3 # Floor division: -10 // -3 = 3
|
||||
|
||||
def test_division_by_one(self):
|
||||
"""Test division by one."""
|
||||
result = IntDivide.execute(10, 1)
|
||||
assert result.value == 10
|
||||
|
||||
def test_validate_inputs_division_by_zero(self):
|
||||
"""Test that division by zero is caught in validation."""
|
||||
validation_result = IntDivide.validate_inputs(10, 0)
|
||||
assert validation_result == "Division by zero is not allowed"
|
||||
|
||||
def test_validate_inputs_normal_division(self):
|
||||
"""Test that normal division passes validation."""
|
||||
validation_result = IntDivide.validate_inputs(10, 2)
|
||||
assert validation_result is True
|
||||
|
||||
def test_execute_division_by_zero_raises_error(self):
|
||||
"""Test that division by zero raises ValueError in execute."""
|
||||
with pytest.raises(ValueError, match="Division by zero is not allowed"):
|
||||
IntDivide.execute(10, 0)
|
||||
|
||||
def test_division_large_numbers(self):
|
||||
"""Test division with large numbers."""
|
||||
result = IntDivide.execute(2000000, 1000)
|
||||
assert result.value == 2000
|
||||
|
||||
|
||||
class TestMathNodesSchema:
|
||||
"""Test suite for node schemas."""
|
||||
|
||||
def test_int_add_schema(self):
|
||||
"""Test IntAdd schema definition."""
|
||||
schema = IntAdd.define_schema()
|
||||
assert schema.node_id == "IntAdd"
|
||||
assert schema.display_name == "Int Add"
|
||||
assert schema.category == "math"
|
||||
assert len(schema.inputs) == 2
|
||||
assert len(schema.outputs) == 1
|
||||
|
||||
def test_int_subtract_schema(self):
|
||||
"""Test IntSubtract schema definition."""
|
||||
schema = IntSubtract.define_schema()
|
||||
assert schema.node_id == "IntSubtract"
|
||||
assert schema.display_name == "Int Subtract"
|
||||
assert schema.category == "math"
|
||||
|
||||
def test_int_multiply_schema(self):
|
||||
"""Test IntMultiply schema definition."""
|
||||
schema = IntMultiply.define_schema()
|
||||
assert schema.node_id == "IntMultiply"
|
||||
assert schema.display_name == "Int Multiply"
|
||||
assert schema.category == "math"
|
||||
|
||||
def test_int_divide_schema(self):
|
||||
"""Test IntDivide schema definition."""
|
||||
schema = IntDivide.define_schema()
|
||||
assert schema.node_id == "IntDivide"
|
||||
assert schema.display_name == "Int Divide"
|
||||
assert schema.category == "math"
|
||||
|
||||
|
||||
class TestMathExtension:
|
||||
"""Test suite for MathExtension."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extension_node_list(self):
|
||||
"""Test that MathExtension returns all math nodes."""
|
||||
from comfy_extras.nodes_math import MathExtension
|
||||
|
||||
extension = MathExtension()
|
||||
node_list = await extension.get_node_list()
|
||||
|
||||
assert len(node_list) == 4
|
||||
assert IntAdd in node_list
|
||||
assert IntSubtract in node_list
|
||||
assert IntMultiply in node_list
|
||||
assert IntDivide in node_list
|
||||
Loading…
Reference in New Issue
Block a user