mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-09 13:02:31 +08:00
Compare commits
4 Commits
8f2d5fe32a
...
59e8ddcf70
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
59e8ddcf70 | ||
|
|
dcff27fe3f | ||
|
|
27f600cd8b | ||
|
|
c6f0f31d44 |
@ -236,6 +236,8 @@ class ComfyNodeABC(ABC):
|
|||||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||||
DEPRECATED: bool
|
DEPRECATED: bool
|
||||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||||
|
DEV_ONLY: bool
|
||||||
|
"""Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled."""
|
||||||
API_NODE: Optional[bool]
|
API_NODE: Optional[bool]
|
||||||
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
||||||
|
|
||||||
|
|||||||
@ -1247,6 +1247,7 @@ class NodeInfoV1:
|
|||||||
output_node: bool=None
|
output_node: bool=None
|
||||||
deprecated: bool=None
|
deprecated: bool=None
|
||||||
experimental: bool=None
|
experimental: bool=None
|
||||||
|
dev_only: bool=None
|
||||||
api_node: bool=None
|
api_node: bool=None
|
||||||
price_badge: dict | None = None
|
price_badge: dict | None = None
|
||||||
search_aliases: list[str]=None
|
search_aliases: list[str]=None
|
||||||
@ -1264,6 +1265,7 @@ class NodeInfoV3:
|
|||||||
output_node: bool=None
|
output_node: bool=None
|
||||||
deprecated: bool=None
|
deprecated: bool=None
|
||||||
experimental: bool=None
|
experimental: bool=None
|
||||||
|
dev_only: bool=None
|
||||||
api_node: bool=None
|
api_node: bool=None
|
||||||
price_badge: dict | None = None
|
price_badge: dict | None = None
|
||||||
|
|
||||||
@ -1375,6 +1377,8 @@ class Schema:
|
|||||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||||
is_experimental: bool=False
|
is_experimental: bool=False
|
||||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||||
|
is_dev_only: bool=False
|
||||||
|
"""Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled."""
|
||||||
is_api_node: bool=False
|
is_api_node: bool=False
|
||||||
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
|
||||||
price_badge: PriceBadge | None = None
|
price_badge: PriceBadge | None = None
|
||||||
@ -1485,6 +1489,7 @@ class Schema:
|
|||||||
output_node=self.is_output_node,
|
output_node=self.is_output_node,
|
||||||
deprecated=self.is_deprecated,
|
deprecated=self.is_deprecated,
|
||||||
experimental=self.is_experimental,
|
experimental=self.is_experimental,
|
||||||
|
dev_only=self.is_dev_only,
|
||||||
api_node=self.is_api_node,
|
api_node=self.is_api_node,
|
||||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
|
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
|
||||||
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
|
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
|
||||||
@ -1519,6 +1524,7 @@ class Schema:
|
|||||||
output_node=self.is_output_node,
|
output_node=self.is_output_node,
|
||||||
deprecated=self.is_deprecated,
|
deprecated=self.is_deprecated,
|
||||||
experimental=self.is_experimental,
|
experimental=self.is_experimental,
|
||||||
|
dev_only=self.is_dev_only,
|
||||||
api_node=self.is_api_node,
|
api_node=self.is_api_node,
|
||||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
|
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
|
||||||
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
|
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
|
||||||
@ -1791,6 +1797,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
cls.GET_SCHEMA()
|
cls.GET_SCHEMA()
|
||||||
return cls._DEPRECATED
|
return cls._DEPRECATED
|
||||||
|
|
||||||
|
_DEV_ONLY = None
|
||||||
|
@final
|
||||||
|
@classproperty
|
||||||
|
def DEV_ONLY(cls): # noqa
|
||||||
|
if cls._DEV_ONLY is None:
|
||||||
|
cls.GET_SCHEMA()
|
||||||
|
return cls._DEV_ONLY
|
||||||
|
|
||||||
_API_NODE = None
|
_API_NODE = None
|
||||||
@final
|
@final
|
||||||
@classproperty
|
@classproperty
|
||||||
@ -1893,6 +1907,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
|||||||
cls._EXPERIMENTAL = schema.is_experimental
|
cls._EXPERIMENTAL = schema.is_experimental
|
||||||
if cls._DEPRECATED is None:
|
if cls._DEPRECATED is None:
|
||||||
cls._DEPRECATED = schema.is_deprecated
|
cls._DEPRECATED = schema.is_deprecated
|
||||||
|
if cls._DEV_ONLY is None:
|
||||||
|
cls._DEV_ONLY = schema.is_dev_only
|
||||||
if cls._API_NODE is None:
|
if cls._API_NODE is None:
|
||||||
cls._API_NODE = schema.is_api_node
|
cls._API_NODE = schema.is_api_node
|
||||||
if cls._OUTPUT_NODE is None:
|
if cls._OUTPUT_NODE is None:
|
||||||
|
|||||||
153
comfy_extras/nodes_math.py
Normal file
153
comfy_extras/nodes_math.py
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
class Add(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
template = io.MatchType.Template("add")
|
||||||
|
autogrow_template = io.Autogrow.TemplatePrefix(
|
||||||
|
io.MatchType.Input("operand", template=template),
|
||||||
|
prefix="operand",
|
||||||
|
min=2,
|
||||||
|
max=10
|
||||||
|
)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="Add",
|
||||||
|
display_name="Add",
|
||||||
|
category="math",
|
||||||
|
inputs=[
|
||||||
|
io.Autogrow.Input("operands", template=autogrow_template)
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.MatchType.Output(template=template)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, operands: io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
values = list(operands.values())
|
||||||
|
result = values[0]
|
||||||
|
for value in values[1:]:
|
||||||
|
result = result + value
|
||||||
|
return io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
|
class Subtract(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
template = io.MatchType.Template("subtract")
|
||||||
|
autogrow_template = io.Autogrow.TemplatePrefix(
|
||||||
|
io.MatchType.Input("operand", template=template),
|
||||||
|
prefix="operand",
|
||||||
|
min=2,
|
||||||
|
max=10
|
||||||
|
)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="Subtract",
|
||||||
|
display_name="Subtract",
|
||||||
|
category="math",
|
||||||
|
inputs=[
|
||||||
|
io.Autogrow.Input("operands", template=autogrow_template)
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.MatchType.Output(template=template)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, operands: io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
values = list(operands.values())
|
||||||
|
result = values[0]
|
||||||
|
for value in values[1:]:
|
||||||
|
result = result - value
|
||||||
|
return io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
|
class Multiply(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
template = io.MatchType.Template("multiply")
|
||||||
|
autogrow_template = io.Autogrow.TemplatePrefix(
|
||||||
|
io.MatchType.Input("operand", template=template),
|
||||||
|
prefix="operand",
|
||||||
|
min=2,
|
||||||
|
max=10
|
||||||
|
)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="Multiply",
|
||||||
|
display_name="Multiply",
|
||||||
|
category="math",
|
||||||
|
inputs=[
|
||||||
|
io.Autogrow.Input("operands", template=autogrow_template)
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.MatchType.Output(template=template)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, operands: io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
values = list(operands.values())
|
||||||
|
result = values[0]
|
||||||
|
for value in values[1:]:
|
||||||
|
result = result * value
|
||||||
|
return io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
|
class Divide(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
template = io.MatchType.Template("divide")
|
||||||
|
autogrow_template = io.Autogrow.TemplatePrefix(
|
||||||
|
io.MatchType.Input("operand", template=template),
|
||||||
|
prefix="operand",
|
||||||
|
min=2,
|
||||||
|
max=10
|
||||||
|
)
|
||||||
|
return io.Schema(
|
||||||
|
node_id="Divide",
|
||||||
|
display_name="Divide",
|
||||||
|
category="math",
|
||||||
|
inputs=[
|
||||||
|
io.Autogrow.Input("operands", template=autogrow_template)
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.MatchType.Output(template=template)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_inputs(cls, operands: io.Autogrow.Type) -> bool:
|
||||||
|
values = list(operands.values())
|
||||||
|
# Check for division by zero in any divisor (all operands except the first)
|
||||||
|
for value in values[1:]:
|
||||||
|
if value == 0:
|
||||||
|
return "Division by zero is not allowed"
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, operands: io.Autogrow.Type) -> io.NodeOutput:
|
||||||
|
values = list(operands.values())
|
||||||
|
result = values[0]
|
||||||
|
for value in values[1:]:
|
||||||
|
if value == 0:
|
||||||
|
raise ValueError("Division by zero is not allowed")
|
||||||
|
result = result / value
|
||||||
|
return io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
|
class MathExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
Add,
|
||||||
|
Subtract,
|
||||||
|
Multiply,
|
||||||
|
Divide,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> MathExtension:
|
||||||
|
return MathExtension()
|
||||||
1
nodes.py
1
nodes.py
@ -2427,6 +2427,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_audio_encoder.py",
|
"nodes_audio_encoder.py",
|
||||||
"nodes_rope.py",
|
"nodes_rope.py",
|
||||||
"nodes_logic.py",
|
"nodes_logic.py",
|
||||||
|
"nodes_math.py",
|
||||||
"nodes_nop.py",
|
"nodes_nop.py",
|
||||||
"nodes_kandinsky5.py",
|
"nodes_kandinsky5.py",
|
||||||
"nodes_wanmove.py",
|
"nodes_wanmove.py",
|
||||||
|
|||||||
@ -679,6 +679,8 @@ class PromptServer():
|
|||||||
info['deprecated'] = True
|
info['deprecated'] = True
|
||||||
if getattr(obj_class, "EXPERIMENTAL", False):
|
if getattr(obj_class, "EXPERIMENTAL", False):
|
||||||
info['experimental'] = True
|
info['experimental'] = True
|
||||||
|
if getattr(obj_class, "DEV_ONLY", False):
|
||||||
|
info['dev_only'] = True
|
||||||
|
|
||||||
if hasattr(obj_class, 'API_NODE'):
|
if hasattr(obj_class, 'API_NODE'):
|
||||||
info['api_node'] = obj_class.API_NODE
|
info['api_node'] = obj_class.API_NODE
|
||||||
|
|||||||
259
tests-unit/comfy_extras_test/nodes_math_test.py
Normal file
259
tests-unit/comfy_extras_test/nodes_math_test.py
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
import pytest
|
||||||
|
from comfy_extras.nodes_math import Add, Subtract, Multiply, Divide
|
||||||
|
from comfy_api.latest import io
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdd:
|
||||||
|
"""Test suite for Add node."""
|
||||||
|
|
||||||
|
def test_basic_addition_integers(self):
|
||||||
|
"""Test basic addition of two integers."""
|
||||||
|
operands = {"operand0": 5, "operand1": 3}
|
||||||
|
result = Add.execute(operands)
|
||||||
|
assert result.value == 8
|
||||||
|
assert isinstance(result, io.NodeOutput)
|
||||||
|
|
||||||
|
def test_addition_multiple_integers(self):
|
||||||
|
"""Test addition with multiple integers."""
|
||||||
|
operands = {"operand0": 5, "operand1": 3, "operand2": 2}
|
||||||
|
result = Add.execute(operands)
|
||||||
|
assert result.value == 10
|
||||||
|
|
||||||
|
def test_addition_with_zero(self):
|
||||||
|
"""Test addition with zero."""
|
||||||
|
operands = {"operand0": 10, "operand1": 0}
|
||||||
|
result = Add.execute(operands)
|
||||||
|
assert result.value == 10
|
||||||
|
|
||||||
|
def test_addition_with_negative(self):
|
||||||
|
"""Test addition with negative numbers."""
|
||||||
|
operands = {"operand0": 5, "operand1": -3}
|
||||||
|
result = Add.execute(operands)
|
||||||
|
assert result.value == 2
|
||||||
|
|
||||||
|
def test_addition_floats(self):
|
||||||
|
"""Test addition with floats."""
|
||||||
|
operands = {"operand0": 5.5, "operand1": 3.2}
|
||||||
|
result = Add.execute(operands)
|
||||||
|
assert abs(result.value - 8.7) < 1e-10
|
||||||
|
|
||||||
|
def test_addition_strings(self):
|
||||||
|
"""Test addition with strings (concatenation)."""
|
||||||
|
operands = {"operand0": "Hello", "operand1": " World"}
|
||||||
|
result = Add.execute(operands)
|
||||||
|
assert result.value == "Hello World"
|
||||||
|
|
||||||
|
def test_addition_lists(self):
|
||||||
|
"""Test addition with lists (like conditionings)."""
|
||||||
|
operands = {"operand0": [1, 2, 3], "operand1": [4, 5]}
|
||||||
|
result = Add.execute(operands)
|
||||||
|
assert result.value == [1, 2, 3, 4, 5]
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubtract:
|
||||||
|
"""Test suite for Subtract node."""
|
||||||
|
|
||||||
|
def test_basic_subtraction_integers(self):
|
||||||
|
"""Test basic subtraction of two integers."""
|
||||||
|
operands = {"operand0": 10, "operand1": 3}
|
||||||
|
result = Subtract.execute(operands)
|
||||||
|
assert result.value == 7
|
||||||
|
assert isinstance(result, io.NodeOutput)
|
||||||
|
|
||||||
|
def test_subtraction_multiple_integers(self):
|
||||||
|
"""Test subtraction with multiple integers."""
|
||||||
|
operands = {"operand0": 20, "operand1": 5, "operand2": 3}
|
||||||
|
result = Subtract.execute(operands)
|
||||||
|
assert result.value == 12
|
||||||
|
|
||||||
|
def test_subtraction_with_zero(self):
|
||||||
|
"""Test subtraction with zero."""
|
||||||
|
operands = {"operand0": 10, "operand1": 0}
|
||||||
|
result = Subtract.execute(operands)
|
||||||
|
assert result.value == 10
|
||||||
|
|
||||||
|
def test_subtraction_resulting_negative(self):
|
||||||
|
"""Test subtraction that results in negative number."""
|
||||||
|
operands = {"operand0": 3, "operand1": 10}
|
||||||
|
result = Subtract.execute(operands)
|
||||||
|
assert result.value == -7
|
||||||
|
|
||||||
|
def test_subtraction_with_negative(self):
|
||||||
|
"""Test subtraction with negative numbers."""
|
||||||
|
operands = {"operand0": 5, "operand1": -3}
|
||||||
|
result = Subtract.execute(operands)
|
||||||
|
assert result.value == 8
|
||||||
|
|
||||||
|
def test_subtraction_floats(self):
|
||||||
|
"""Test subtraction with floats."""
|
||||||
|
operands = {"operand0": 10.5, "operand1": 3.2}
|
||||||
|
result = Subtract.execute(operands)
|
||||||
|
assert abs(result.value - 7.3) < 1e-10
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiply:
|
||||||
|
"""Test suite for Multiply node."""
|
||||||
|
|
||||||
|
def test_basic_multiplication_integers(self):
|
||||||
|
"""Test basic multiplication of two integers."""
|
||||||
|
operands = {"operand0": 5, "operand1": 3}
|
||||||
|
result = Multiply.execute(operands)
|
||||||
|
assert result.value == 15
|
||||||
|
assert isinstance(result, io.NodeOutput)
|
||||||
|
|
||||||
|
def test_multiplication_multiple_integers(self):
|
||||||
|
"""Test multiplication with multiple integers."""
|
||||||
|
operands = {"operand0": 2, "operand1": 3, "operand2": 4}
|
||||||
|
result = Multiply.execute(operands)
|
||||||
|
assert result.value == 24
|
||||||
|
|
||||||
|
def test_multiplication_with_zero(self):
|
||||||
|
"""Test multiplication with zero."""
|
||||||
|
operands = {"operand0": 10, "operand1": 0}
|
||||||
|
result = Multiply.execute(operands)
|
||||||
|
assert result.value == 0
|
||||||
|
|
||||||
|
def test_multiplication_with_one(self):
|
||||||
|
"""Test multiplication with one."""
|
||||||
|
operands = {"operand0": 10, "operand1": 1}
|
||||||
|
result = Multiply.execute(operands)
|
||||||
|
assert result.value == 10
|
||||||
|
|
||||||
|
def test_multiplication_with_negative(self):
|
||||||
|
"""Test multiplication with negative numbers."""
|
||||||
|
operands = {"operand0": 5, "operand1": -3}
|
||||||
|
result = Multiply.execute(operands)
|
||||||
|
assert result.value == -15
|
||||||
|
|
||||||
|
def test_multiplication_floats(self):
|
||||||
|
"""Test multiplication with floats."""
|
||||||
|
operands = {"operand0": 2.5, "operand1": 4.0}
|
||||||
|
result = Multiply.execute(operands)
|
||||||
|
assert abs(result.value - 10.0) < 1e-10
|
||||||
|
|
||||||
|
def test_multiplication_string_repetition(self):
|
||||||
|
"""Test multiplication with string (repetition)."""
|
||||||
|
operands = {"operand0": "Hello", "operand1": 3}
|
||||||
|
result = Multiply.execute(operands)
|
||||||
|
assert result.value == "HelloHelloHello"
|
||||||
|
|
||||||
|
|
||||||
|
class TestDivide:
|
||||||
|
"""Test suite for Divide node."""
|
||||||
|
|
||||||
|
def test_basic_division_integers(self):
|
||||||
|
"""Test basic division of two integers."""
|
||||||
|
operands = {"operand0": 10, "operand1": 2}
|
||||||
|
result = Divide.execute(operands)
|
||||||
|
assert result.value == 5.0 # Division returns float
|
||||||
|
assert isinstance(result, io.NodeOutput)
|
||||||
|
|
||||||
|
def test_division_multiple_integers(self):
|
||||||
|
"""Test division with multiple integers."""
|
||||||
|
operands = {"operand0": 100, "operand1": 2, "operand2": 5}
|
||||||
|
result = Divide.execute(operands)
|
||||||
|
assert result.value == 10.0
|
||||||
|
|
||||||
|
def test_division_with_remainder(self):
|
||||||
|
"""Test division that results in float."""
|
||||||
|
operands = {"operand0": 10, "operand1": 3}
|
||||||
|
result = Divide.execute(operands)
|
||||||
|
assert abs(result.value - 3.3333333333333335) < 1e-10
|
||||||
|
|
||||||
|
def test_division_with_negative(self):
|
||||||
|
"""Test division with negative numbers."""
|
||||||
|
operands = {"operand0": 10, "operand1": -2}
|
||||||
|
result = Divide.execute(operands)
|
||||||
|
assert result.value == -5.0
|
||||||
|
|
||||||
|
def test_division_by_one(self):
|
||||||
|
"""Test division by one."""
|
||||||
|
operands = {"operand0": 10, "operand1": 1}
|
||||||
|
result = Divide.execute(operands)
|
||||||
|
assert result.value == 10.0
|
||||||
|
|
||||||
|
def test_division_floats(self):
|
||||||
|
"""Test division with floats."""
|
||||||
|
operands = {"operand0": 10.5, "operand1": 2.5}
|
||||||
|
result = Divide.execute(operands)
|
||||||
|
assert abs(result.value - 4.2) < 1e-10
|
||||||
|
|
||||||
|
def test_validate_inputs_division_by_zero(self):
|
||||||
|
"""Test that division by zero is caught in validation."""
|
||||||
|
operands = {"operand0": 10, "operand1": 0}
|
||||||
|
validation_result = Divide.validate_inputs(operands)
|
||||||
|
assert validation_result == "Division by zero is not allowed"
|
||||||
|
|
||||||
|
def test_validate_inputs_division_by_zero_multiple(self):
|
||||||
|
"""Test that division by zero in later operands is caught."""
|
||||||
|
operands = {"operand0": 10, "operand1": 2, "operand2": 0}
|
||||||
|
validation_result = Divide.validate_inputs(operands)
|
||||||
|
assert validation_result == "Division by zero is not allowed"
|
||||||
|
|
||||||
|
def test_validate_inputs_normal_division(self):
|
||||||
|
"""Test that normal division passes validation."""
|
||||||
|
operands = {"operand0": 10, "operand1": 2}
|
||||||
|
validation_result = Divide.validate_inputs(operands)
|
||||||
|
assert validation_result is True
|
||||||
|
|
||||||
|
def test_execute_division_by_zero_raises_error(self):
|
||||||
|
"""Test that division by zero raises ValueError in execute."""
|
||||||
|
operands = {"operand0": 10, "operand1": 0}
|
||||||
|
with pytest.raises(ValueError, match="Division by zero is not allowed"):
|
||||||
|
Divide.execute(operands)
|
||||||
|
|
||||||
|
def test_execute_division_by_zero_multiple_raises_error(self):
|
||||||
|
"""Test that division by zero in later operands raises error."""
|
||||||
|
operands = {"operand0": 100, "operand1": 2, "operand2": 0}
|
||||||
|
with pytest.raises(ValueError, match="Division by zero is not allowed"):
|
||||||
|
Divide.execute(operands)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMathNodesSchema:
|
||||||
|
"""Test suite for node schemas."""
|
||||||
|
|
||||||
|
def test_add_schema(self):
|
||||||
|
"""Test Add schema definition."""
|
||||||
|
schema = Add.define_schema()
|
||||||
|
assert schema.node_id == "Add"
|
||||||
|
assert schema.display_name == "Add"
|
||||||
|
assert schema.category == "math"
|
||||||
|
|
||||||
|
def test_subtract_schema(self):
|
||||||
|
"""Test Subtract schema definition."""
|
||||||
|
schema = Subtract.define_schema()
|
||||||
|
assert schema.node_id == "Subtract"
|
||||||
|
assert schema.display_name == "Subtract"
|
||||||
|
assert schema.category == "math"
|
||||||
|
|
||||||
|
def test_multiply_schema(self):
|
||||||
|
"""Test Multiply schema definition."""
|
||||||
|
schema = Multiply.define_schema()
|
||||||
|
assert schema.node_id == "Multiply"
|
||||||
|
assert schema.display_name == "Multiply"
|
||||||
|
assert schema.category == "math"
|
||||||
|
|
||||||
|
def test_divide_schema(self):
|
||||||
|
"""Test Divide schema definition."""
|
||||||
|
schema = Divide.define_schema()
|
||||||
|
assert schema.node_id == "Divide"
|
||||||
|
assert schema.display_name == "Divide"
|
||||||
|
assert schema.category == "math"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMathExtension:
|
||||||
|
"""Test suite for MathExtension."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extension_node_list(self):
|
||||||
|
"""Test that MathExtension returns all math nodes."""
|
||||||
|
from comfy_extras.nodes_math import MathExtension
|
||||||
|
|
||||||
|
extension = MathExtension()
|
||||||
|
node_list = await extension.get_node_list()
|
||||||
|
|
||||||
|
assert len(node_list) == 4
|
||||||
|
assert Add in node_list
|
||||||
|
assert Subtract in node_list
|
||||||
|
assert Multiply in node_list
|
||||||
|
assert Divide in node_list
|
||||||
Loading…
Reference in New Issue
Block a user