ComfyUI/tests-unit/execution_test/expected_outputs_test.py
bigcat88 50975a7a0d
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
add a helper function for easy use
2026-02-05 09:05:42 +02:00

299 lines
12 KiB
Python

"""Unit tests for the expected_outputs feature.
This feature allows nodes to know at runtime which outputs are connected downstream,
enabling them to skip computing outputs that aren't needed.
"""
from comfy_api.latest import IO
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
from comfy_execution.utils import (
CurrentNodeContext,
ExecutionContext,
get_executing_context,
is_output_needed,
)
class TestGetExpectedOutputsForNode:
    """Tests for get_expected_outputs_for_node().

    Each test builds a small prompt dict (node_id -> {"class_type", "inputs"}).
    A link is written as ``[source_node_id, output_index]``. The tests then check
    which output indices of a node are consumed by some downstream input.
    """

    def test_single_output_connected(self):
        """A single output consumed by one downstream node is reported."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0})

    def test_multiple_outputs_partial_connected(self):
        """Only the linked outputs of a multi-output node are reported."""
        prompt = {
            "1": {"class_type": "MultiOutputNode", "inputs": {}},
            "2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
            # Output 1 is not connected
            "3": {"class_type": "ConsumerC", "inputs": {"input": ["1", 2]}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0, 2})
        assert 1 not in expected  # Output 1 is definitely unused

    def test_no_outputs_connected(self):
        """A node nothing links to has an empty expected-output set."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "OtherNode", "inputs": {}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset()

    def test_same_output_connected_multiple_times(self):
        """An output feeding several consumers is reported once (set semantics)."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
            "3": {"class_type": "ConsumerB", "inputs": {"input": ["1", 0]}},
            "4": {"class_type": "ConsumerC", "inputs": {"input": ["1", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0})

    def test_node_not_in_prompt(self):
        """Asking about an unknown node id yields an empty set, not an error."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "999")
        assert expected == frozenset()

    def test_chained_nodes(self):
        """Each node in a linear chain sees only its own direct consumer."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "MiddleNode", "inputs": {"input": ["1", 0]}},
            "3": {"class_type": "EndNode", "inputs": {"input": ["2", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)
        # Node 1's output 0 is connected to node 2
        expected_1 = get_expected_outputs_for_node(dynprompt, "1")
        assert expected_1 == frozenset({0})
        # Node 2's output 0 is connected to node 3
        expected_2 = get_expected_outputs_for_node(dynprompt, "2")
        assert expected_2 == frozenset({0})
        # Node 3 has no downstream connections
        expected_3 = get_expected_outputs_for_node(dynprompt, "3")
        assert expected_3 == frozenset()

    def test_complex_graph(self):
        """Expected outputs are correct for every node of a fan-out/fan-in graph."""
        prompt = {
            "1": {"class_type": "MultiOutputNode", "inputs": {}},
            "2": {"class_type": "ProcessorA", "inputs": {"image": ["1", 0], "mask": ["1", 1]}},
            "3": {"class_type": "ProcessorB", "inputs": {"data": ["1", 2]}},
            "4": {"class_type": "Combiner", "inputs": {"a": ["2", 0], "b": ["3", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)
        # Node 1 has outputs 0, 1, 2 all connected (0 and 1 both go to the same consumer)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0, 1, 2})
        # The intermediate processors each feed their output 0 into the combiner
        assert get_expected_outputs_for_node(dynprompt, "2") == frozenset({0})
        assert get_expected_outputs_for_node(dynprompt, "3") == frozenset({0})
        # The combiner is a sink: nothing downstream consumes it
        assert get_expected_outputs_for_node(dynprompt, "4") == frozenset()

    def test_constant_inputs_ignored(self):
        """Literal (non-link) input values never count as connections."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {
                "class_type": "ConsumerNode",
                "inputs": {
                    "image": ["1", 0],
                    "value": 42,
                    "name": "test",
                },
            },
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0})
class TestExecutionContext:
    """Checks that ExecutionContext carries the optional expected_outputs field."""

    def test_context_with_expected_outputs(self):
        """An explicit expected_outputs set is stored next to the other fields."""
        wanted = frozenset({0, 2})
        ctx = ExecutionContext(
            prompt_id="prompt-123",
            node_id="node-456",
            list_index=0,
            expected_outputs=wanted,
        )
        assert ctx.prompt_id == "prompt-123"
        assert ctx.node_id == "node-456"
        assert ctx.list_index == 0
        assert ctx.expected_outputs == frozenset({0, 2})

    def test_context_without_expected_outputs(self):
        """Omitting expected_outputs leaves the field as None ("unknown")."""
        ctx = ExecutionContext(
            prompt_id="prompt-123",
            node_id="node-456",
            list_index=0,
        )
        assert ctx.expected_outputs is None

    def test_context_empty_expected_outputs(self):
        """An empty set is preserved as-is and is distinct from None."""
        nothing = frozenset()
        ctx = ExecutionContext(
            prompt_id="prompt-123",
            node_id="node-456",
            list_index=None,
            expected_outputs=nothing,
        )
        outputs = ctx.expected_outputs
        assert outputs == frozenset()
        assert len(outputs) == 0
class TestCurrentNodeContext:
    """Tests for the CurrentNodeContext context manager with expected_outputs.

    Every test asserts that no executing context exists before entering and
    after leaving the outermost ``with``. A leaked context therefore shows up
    as a clear assertion failure.
    """

    def test_context_manager_with_expected_outputs(self):
        """CurrentNodeContext installs the context on entry and clears it on exit."""
        assert get_executing_context() is None
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 1})):
            ctx = get_executing_context()
            assert ctx is not None
            assert ctx.prompt_id == "prompt-1"
            assert ctx.node_id == "node-1"
            assert ctx.list_index == 0
            assert ctx.expected_outputs == frozenset({0, 1})
        assert get_executing_context() is None

    def test_context_manager_without_expected_outputs(self):
        """expected_outputs is optional (backwards compatible) and then is None."""
        assert get_executing_context() is None
        with CurrentNodeContext("prompt-1", "node-1"):
            ctx = get_executing_context()
            assert ctx is not None
            assert ctx.expected_outputs is None
        # The context must also be cleared when expected_outputs was omitted
        assert get_executing_context() is None

    def test_nested_context_managers(self):
        """An inner context shadows the outer one and restores it on exit."""
        assert get_executing_context() is None
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0})):
            ctx1 = get_executing_context()
            # Check for None first so a missing context fails as an assertion,
            # not as an AttributeError on the attribute access below.
            assert ctx1 is not None
            assert ctx1.expected_outputs == frozenset({0})
            with CurrentNodeContext("prompt-1", "node-2", 0, frozenset({1, 2})):
                ctx2 = get_executing_context()
                assert ctx2 is not None
                assert ctx2.expected_outputs == frozenset({1, 2})
                assert ctx2.node_id == "node-2"
            # After inner context exits, should be back to outer context
            ctx1_again = get_executing_context()
            assert ctx1_again is not None
            assert ctx1_again.expected_outputs == frozenset({0})
            assert ctx1_again.node_id == "node-1"
        # After the outer context exits, nothing should remain installed
        assert get_executing_context() is None

    def test_output_check_pattern(self):
        """Test the typical pattern nodes will use to check expected outputs."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            ctx = get_executing_context()
            # Typical usage pattern
            if ctx and ctx.expected_outputs is not None:
                should_compute_0 = 0 in ctx.expected_outputs
                should_compute_1 = 1 in ctx.expected_outputs
                should_compute_2 = 2 in ctx.expected_outputs
            else:
                # Fallback when info not available: compute everything
                should_compute_0 = should_compute_1 = should_compute_2 = True
            assert should_compute_0 is True
            assert should_compute_1 is False  # Not in expected_outputs
            assert should_compute_2 is True
class TestSchemaLazyOutputs:
    """Checks the lazy_outputs flag on V3 Schema and its LAZY_OUTPUTS mirror on nodes."""

    @staticmethod
    def _float_schema(node_id, **extra):
        """Build a minimal schema with no inputs and a single Float output."""
        return IO.Schema(node_id=node_id, inputs=[], outputs=[IO.Float.Output()], **extra)

    def test_schema_lazy_outputs_default(self):
        """A schema built without lazy_outputs reports False."""
        built = self._float_schema("TestNode")
        assert built.lazy_outputs is False

    def test_schema_lazy_outputs_true(self):
        """Passing lazy_outputs=True is reflected on the schema."""
        built = self._float_schema("TestNode", lazy_outputs=True)
        assert built.lazy_outputs is True

    def test_v3_node_lazy_outputs_property(self):
        """A node whose schema opts in exposes LAZY_OUTPUTS == True."""

        class TestNodeWithLazyOutputs(IO.ComfyNode):
            @classmethod
            def define_schema(cls):
                return IO.Schema(
                    node_id="TestNodeWithLazyOutputs",
                    inputs=[],
                    outputs=[IO.Float.Output()],
                    lazy_outputs=True,
                )

            @classmethod
            def execute(cls):
                return IO.NodeOutput(1.0)

        flag = TestNodeWithLazyOutputs.LAZY_OUTPUTS
        assert flag is True

    def test_v3_node_lazy_outputs_default(self):
        """A node whose schema does not opt in exposes LAZY_OUTPUTS == False."""

        class TestNodeWithoutLazyOutputs(IO.ComfyNode):
            @classmethod
            def define_schema(cls):
                return IO.Schema(
                    node_id="TestNodeWithoutLazyOutputs",
                    inputs=[],
                    outputs=[IO.Float.Output()],
                )

            @classmethod
            def execute(cls):
                return IO.NodeOutput(1.0)

        flag = TestNodeWithoutLazyOutputs.LAZY_OUTPUTS
        assert flag is False
class TestIsOutputNeeded:
    """Checks the is_output_needed() convenience helper."""

    def test_output_needed_when_in_expected(self):
        """Indices listed in expected_outputs are reported as needed."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            for index in (0, 2):
                assert is_output_needed(index) is True

    def test_output_not_needed_when_not_in_expected(self):
        """Indices absent from expected_outputs are reported as not needed."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            for index in (1, 3):
                assert is_output_needed(index) is False

    def test_output_needed_when_no_context(self):
        """Without an executing context the helper errs on the side of 'needed'."""
        assert get_executing_context() is None
        for index in (0, 1):
            assert is_output_needed(index) is True

    def test_output_needed_when_expected_outputs_is_none(self):
        """A context with expected_outputs=None also means 'needed'."""
        with CurrentNodeContext("prompt-1", "node-1", 0, None):
            for index in (0, 1):
                assert is_output_needed(index) is True