mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 11:32:31 +08:00
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
299 lines
12 KiB
Python
299 lines
12 KiB
Python
"""Unit tests for the expected_outputs feature.

This feature lets a node know at runtime which of its outputs are connected
downstream, so it can skip computing outputs that aren't needed.
"""
|
|
|
|
from comfy_api.latest import IO
|
|
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
|
|
from comfy_execution.utils import (
|
|
CurrentNodeContext,
|
|
ExecutionContext,
|
|
get_executing_context,
|
|
is_output_needed,
|
|
)
|
|
|
|
|
|
class TestGetExpectedOutputsForNode:
    """Tests for get_expected_outputs_for_node().

    Each test builds a prompt dict mapping node id -> {"class_type", "inputs"},
    where a linked input is written as ``[source_node_id, output_index]``. It wraps
    the dict in a DynamicPrompt and checks which output indices of a node are
    reported as consumed downstream.
    """

    def test_single_output_connected(self):
        """A single output consumed by one downstream node is reported."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0})

    def test_multiple_outputs_partial_connected(self):
        """Only the consumed outputs of a multi-output node are reported."""
        prompt = {
            "1": {"class_type": "MultiOutputNode", "inputs": {}},
            "2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
            # Output 1 is not connected
            "3": {"class_type": "ConsumerC", "inputs": {"input": ["1", 2]}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0, 2})
        assert 1 not in expected  # Output 1 is definitely unused

    def test_no_outputs_connected(self):
        """A node that nothing links to reports an empty set."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "OtherNode", "inputs": {}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset()

    def test_same_output_connected_multiple_times(self):
        """An output consumed by several downstream nodes is reported only once."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
            "3": {"class_type": "ConsumerB", "inputs": {"input": ["1", 0]}},
            "4": {"class_type": "ConsumerC", "inputs": {"input": ["1", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0})

    def test_node_not_in_prompt(self):
        """Querying an id absent from the prompt yields an empty set."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "999")
        assert expected == frozenset()

    def test_chained_nodes(self):
        """Each node in a linear chain reports only its own consumed outputs."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {"class_type": "MiddleNode", "inputs": {"input": ["1", 0]}},
            "3": {"class_type": "EndNode", "inputs": {"input": ["2", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)

        # Node 1's output 0 is connected to node 2
        expected_1 = get_expected_outputs_for_node(dynprompt, "1")
        assert expected_1 == frozenset({0})

        # Node 2's output 0 is connected to node 3
        expected_2 = get_expected_outputs_for_node(dynprompt, "2")
        assert expected_2 == frozenset({0})

        # Node 3 has no downstream connections
        expected_3 = get_expected_outputs_for_node(dynprompt, "3")
        assert expected_3 == frozenset()

    def test_complex_graph(self):
        """Every node in a diamond-shaped graph reports exactly its consumed outputs."""
        prompt = {
            "1": {"class_type": "MultiOutputNode", "inputs": {}},
            "2": {"class_type": "ProcessorA", "inputs": {"image": ["1", 0], "mask": ["1", 1]}},
            "3": {"class_type": "ProcessorB", "inputs": {"data": ["1", 2]}},
            "4": {"class_type": "Combiner", "inputs": {"a": ["2", 0], "b": ["3", 0]}},
        }
        dynprompt = DynamicPrompt(prompt)

        # Node 1 feeds outputs 0 and 1 into node 2 and output 2 into node 3.
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0, 1, 2})

        # Nodes 2 and 3 each have output 0 consumed by the combiner (node 4).
        assert get_expected_outputs_for_node(dynprompt, "2") == frozenset({0})
        assert get_expected_outputs_for_node(dynprompt, "3") == frozenset({0})

        # The combiner is a sink: no node reads from it.
        assert get_expected_outputs_for_node(dynprompt, "4") == frozenset()

    def test_constant_inputs_ignored(self):
        """Constant (non-link) inputs do not contribute to expected outputs."""
        prompt = {
            "1": {"class_type": "SourceNode", "inputs": {}},
            "2": {
                "class_type": "ConsumerNode",
                "inputs": {
                    "image": ["1", 0],
                    "value": 42,
                    "name": "test",
                },
            },
        }
        dynprompt = DynamicPrompt(prompt)
        expected = get_expected_outputs_for_node(dynprompt, "1")
        assert expected == frozenset({0})
|
|
|
|
|
|
class TestExecutionContext:
    """Checks that ExecutionContext carries expected_outputs alongside its other fields."""

    def test_context_with_expected_outputs(self):
        """All four constructor fields, including expected_outputs, are stored as given."""
        wanted = frozenset({0, 2})
        ctx = ExecutionContext(
            prompt_id="prompt-123",
            node_id="node-456",
            list_index=0,
            expected_outputs=wanted,
        )

        observed = (ctx.prompt_id, ctx.node_id, ctx.list_index, ctx.expected_outputs)
        assert observed == ("prompt-123", "node-456", 0, frozenset({0, 2}))

    def test_context_without_expected_outputs(self):
        """Leaving out expected_outputs yields None, not an empty set."""
        ctx = ExecutionContext(
            prompt_id="prompt-123",
            node_id="node-456",
            list_index=0,
        )
        assert ctx.expected_outputs is None

    def test_context_empty_expected_outputs(self):
        """An explicitly empty expected_outputs is kept as an empty frozenset."""
        empty = frozenset()
        ctx = ExecutionContext(
            prompt_id="prompt-123",
            node_id="node-456",
            list_index=None,
            expected_outputs=empty,
        )

        stored = ctx.expected_outputs
        assert stored == frozenset()
        assert len(stored) == 0
|
|
|
|
|
|
class TestCurrentNodeContext:
    """Tests for the CurrentNodeContext context manager with expected_outputs."""

    def test_context_manager_with_expected_outputs(self):
        """CurrentNodeContext installs the context on entry and clears it on exit."""
        assert get_executing_context() is None

        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 1})):
            ctx = get_executing_context()
            assert ctx is not None
            assert ctx.prompt_id == "prompt-1"
            assert ctx.node_id == "node-1"
            assert ctx.list_index == 0
            assert ctx.expected_outputs == frozenset({0, 1})

        assert get_executing_context() is None

    def test_context_manager_without_expected_outputs(self):
        """Omitting expected_outputs (backwards-compatible call) leaves it as None."""
        with CurrentNodeContext("prompt-1", "node-1"):
            ctx = get_executing_context()
            assert ctx is not None
            assert ctx.expected_outputs is None

        # Like the explicit-expected_outputs case, the context is cleared on exit.
        assert get_executing_context() is None

    def test_nested_context_managers(self):
        """An inner CurrentNodeContext shadows the outer one and restores it on exit."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0})):
            ctx1 = get_executing_context()
            # Guard first so a missing context fails as an assertion, not an AttributeError.
            assert ctx1 is not None
            assert ctx1.expected_outputs == frozenset({0})

            with CurrentNodeContext("prompt-1", "node-2", 0, frozenset({1, 2})):
                ctx2 = get_executing_context()
                assert ctx2 is not None
                assert ctx2.expected_outputs == frozenset({1, 2})
                assert ctx2.node_id == "node-2"

            # After inner context exits, should be back to outer context
            ctx1_again = get_executing_context()
            assert ctx1_again is not None
            assert ctx1_again.expected_outputs == frozenset({0})
            assert ctx1_again.node_id == "node-1"

        # Leaving the outermost context clears it entirely.
        assert get_executing_context() is None

    def test_output_check_pattern(self):
        """Exercise the typical pattern a node uses to decide which outputs to compute."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            ctx = get_executing_context()

            # Typical usage pattern
            if ctx and ctx.expected_outputs is not None:
                should_compute_0 = 0 in ctx.expected_outputs
                should_compute_1 = 1 in ctx.expected_outputs
                should_compute_2 = 2 in ctx.expected_outputs
            else:
                # Fallback when info not available: compute everything.
                should_compute_0 = should_compute_1 = should_compute_2 = True

            assert should_compute_0 is True
            assert should_compute_1 is False  # Not in expected_outputs
            assert should_compute_2 is True
|
|
|
|
|
|
class TestSchemaLazyOutputs:
    """Covers the lazy_outputs flag on V3 IO.Schema and the LAZY_OUTPUTS value nodes expose."""

    def test_schema_lazy_outputs_default(self):
        """A schema built without lazy_outputs reports False."""
        float_out = IO.Float.Output()
        schema = IO.Schema(node_id="TestNode", inputs=[], outputs=[float_out])
        assert schema.lazy_outputs is False

    def test_schema_lazy_outputs_true(self):
        """Passing lazy_outputs=True is stored on the schema."""
        schema = IO.Schema(
            node_id="TestNode",
            inputs=[],
            outputs=[IO.Float.Output()],
            lazy_outputs=True,
        )
        assert schema.lazy_outputs is True

    def test_v3_node_lazy_outputs_property(self):
        """A V3 node whose schema sets lazy_outputs=True exposes LAZY_OUTPUTS as True."""

        class TestNodeWithLazyOutputs(IO.ComfyNode):
            @classmethod
            def define_schema(cls):
                schema = IO.Schema(
                    node_id="TestNodeWithLazyOutputs",
                    inputs=[],
                    outputs=[IO.Float.Output()],
                    lazy_outputs=True,
                )
                return schema

            @classmethod
            def execute(cls):
                return IO.NodeOutput(1.0)

        assert TestNodeWithLazyOutputs.LAZY_OUTPUTS is True

    def test_v3_node_lazy_outputs_default(self):
        """A V3 node whose schema omits lazy_outputs exposes LAZY_OUTPUTS as False."""

        class TestNodeWithoutLazyOutputs(IO.ComfyNode):
            @classmethod
            def define_schema(cls):
                schema = IO.Schema(
                    node_id="TestNodeWithoutLazyOutputs",
                    inputs=[],
                    outputs=[IO.Float.Output()],
                )
                return schema

            @classmethod
            def execute(cls):
                return IO.NodeOutput(1.0)

        assert TestNodeWithoutLazyOutputs.LAZY_OUTPUTS is False
|
|
|
|
|
|
class TestIsOutputNeeded:
    """Covers is_output_needed(), the single-index check built on the executing context."""

    def test_output_needed_when_in_expected(self):
        """Indices present in expected_outputs are reported as needed."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            for index in (0, 2):
                assert is_output_needed(index) is True

    def test_output_not_needed_when_not_in_expected(self):
        """Indices absent from expected_outputs are reported as not needed."""
        with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
            for index in (1, 3):
                assert is_output_needed(index) is False

    def test_output_needed_when_no_context(self):
        """With no executing context, every output is treated as needed."""
        assert get_executing_context() is None
        for index in (0, 1):
            assert is_output_needed(index) is True

    def test_output_needed_when_expected_outputs_is_none(self):
        """A context whose expected_outputs is None treats every output as needed."""
        with CurrentNodeContext("prompt-1", "node-1", 0, None):
            for index in (0, 1):
                assert is_output_needed(index) is True