feat: add expected_outputs feature for lazy output computation

This commit is contained in:
bigcat88 2026-02-04 14:26:16 +02:00
parent 2b70ab9ad0
commit d987b0d32d
8 changed files with 515 additions and 18 deletions

View File

@@ -1430,6 +1430,11 @@ class Schema:
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
accept_all_inputs: bool=False
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
lazy_outputs: bool=False
"""When True, cache will invalidate when output connections change, and expected_outputs will be available.
Use this for nodes that can skip computing outputs that aren't connected downstream.
Access via `get_executing_context().expected_outputs` - outputs NOT in the set are definitely unused."""
def validate(self):
'''Validate the schema:
@@ -1875,6 +1880,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA()
return cls._ACCEPT_ALL_INPUTS
_LAZY_OUTPUTS = None
@final
@classproperty
def LAZY_OUTPUTS(cls): # noqa
if cls._LAZY_OUTPUTS is None:
cls.GET_SCHEMA()
return cls._LAZY_OUTPUTS
@final
@classmethod
def INPUT_TYPES(cls) -> dict[str, dict]:
@@ -1917,6 +1930,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._NOT_IDEMPOTENT = schema.not_idempotent
if cls._ACCEPT_ALL_INPUTS is None:
cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs
if cls._LAZY_OUTPUTS is None:
cls._LAZY_OUTPUTS = schema.lazy_outputs
if cls._RETURN_TYPES is None:
output = []

View File

@@ -5,7 +5,7 @@ import psutil
import time
import torch
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
from abc import ABC, abstractmethod
import nodes
@@ -115,6 +115,10 @@ class CacheKeySetInputSignature(CacheKeySet):
signature = [class_type, await self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
# Include expected_outputs in cache key for nodes that opt in via LAZY_OUTPUTS
if hasattr(class_def, 'LAZY_OUTPUTS') and class_def.LAZY_OUTPUTS:
expected = get_expected_outputs_for_node(dynprompt, node_id)
signature.append(("expected_outputs", tuple(sorted(expected))))
inputs = node["inputs"]
for key in sorted(inputs.keys()):
if is_link(inputs[key]):
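
In effect, rewiring a lazy-output node's downstream consumers changes its cache signature even when its own inputs are untouched. A simplified, self-contained illustration (placeholder class name and a None stand-in for the is_changed value; not the real cache code):

# First run: only output 0 is connected; later the user also wires up output 2.
sig_first_run = ["ExampleLazyStats", None, ("expected_outputs", tuple(sorted({0})))]
sig_after_rewire = ["ExampleLazyStats", None, ("expected_outputs", tuple(sorted({0, 2})))]
assert sig_first_run != sig_after_rewire  # key differs, so the cached result is not reused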

View File

@@ -19,6 +19,27 @@ class NodeInputError(Exception):
class NodeNotFoundError(Exception):
pass
def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
"""Get the set of output indices that are connected downstream.
Returns outputs that MIGHT be used.
Outputs NOT in this set are DEFINITELY not used and safe to skip.
"""
expected = set()
for other_node_id in dynprompt.all_node_ids():
try:
node_data = dynprompt.get_node(other_node_id)
except NodeNotFoundError:
continue
inputs = node_data.get("inputs", {})
for input_name, value in inputs.items():
if is_link(value):
from_node_id, from_socket = value
if from_node_id == node_id:
expected.add(from_socket)
return frozenset(expected)
class DynamicPrompt:
def __init__(self, original_prompt):
# The original prompt provided by the user
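
A quick usage sketch of the helper above on a hand-written prompt (it mirrors the unit tests added later in this commit; node names are arbitrary):

from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node

prompt = {
    "1": {"class_type": "SourceNode", "inputs": {}},
    "2": {"class_type": "Consumer", "inputs": {"image": ["1", 0], "strength": 0.5}},
}
dynprompt = DynamicPrompt(prompt)
# Output 0 of node "1" feeds node "2"; constant inputs like "strength" are ignored,
# and every other output index of node "1" is definitely unused.
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0})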

View File

@@ -1,21 +1,26 @@
import contextvars
from typing import NamedTuple, FrozenSet
class ExecutionContext(NamedTuple):
"""
Context information about the currently executing node.
Attributes:
prompt_id: The ID of the current prompt execution
node_id: The ID of the currently executing node
list_index: The index in a list being processed (for operations on batches/lists)
expected_outputs: Set of output indices that might be used downstream.
Outputs NOT in this set are definitely unused (safe to skip).
None means the information is not available.
""" """
prompt_id: str prompt_id: str
node_id: str node_id: str
list_index: Optional[int] list_index: int | None
expected_outputs: FrozenSet[int] | None = None
current_executing_context: contextvars.ContextVar[ExecutionContext | None] = contextvars.ContextVar("current_executing_context", default=None)
def get_executing_context() -> ExecutionContext | None:
return current_executing_context.get(None)
class CurrentNodeContext:
@@ -25,15 +30,22 @@ class CurrentNodeContext:
Sets the current_executing_context on enter and resets it on exit.
Example:
with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0):
# Code that should run with the current node context set
process_image()
"""
def __init__(
self,
prompt_id: str,
node_id: str,
list_index: int | None = None,
expected_outputs: FrozenSet[int] | None = None,
):
self.context = ExecutionContext(
prompt_id=prompt_id,
node_id=node_id,
list_index=list_index,
expected_outputs=expected_outputs,
)
self.token = None
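
The executor is the only writer of this context; node code just reads it. A small sketch (not from the commit) of both sides, following the same defensive pattern the new tests check:

from comfy_execution.utils import CurrentNodeContext, get_executing_context

# Executor side: enter the context before calling into the node.
with CurrentNodeContext("prompt-abc", "7", list_index=None, expected_outputs=frozenset({0})):
    # Node side: read the context and fall back to computing everything if it is missing.
    ctx = get_executing_context()
    if ctx is not None and ctx.expected_outputs is not None:
        skip_output_1 = 1 not in ctx.expected_outputs  # definitely unused downstream
    else:
        skip_output_1 = False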

View File

@@ -31,6 +31,7 @@ from comfy_execution.graph import (
ExecutionBlocker,
ExecutionList,
get_input_info,
get_expected_outputs_for_node,
)
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_execution.validation import validate_node_input
@@ -227,7 +228,18 @@ async def resolve_map_node_over_list_results(results):
raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
async def _async_map_node_over_list(
prompt_id,
unique_id,
obj,
input_data_all,
func,
allow_interrupt=False,
execution_block_cb=None,
pre_execute_cb=None,
v3_data=None,
expected_outputs=None,
):
# check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@@ -277,10 +289,12 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
else:
f = getattr(obj, func)
if inspect.iscoroutinefunction(f):
async def async_wrapper(f, prompt_id, unique_id, list_index, args, expected_outputs):
with CurrentNodeContext(prompt_id, unique_id, list_index, expected_outputs):
return await f(**args)
task = asyncio.create_task(
async_wrapper(f, prompt_id, unique_id, index, args=inputs, expected_outputs=expected_outputs)
)
# Give the task a chance to execute without yielding
await asyncio.sleep(0)
if task.done():
@@ -289,7 +303,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
else:
results.append(task)
else:
with CurrentNodeContext(prompt_id, unique_id, index, expected_outputs):
result = f(**inputs)
results.append(result)
else:
@@ -327,8 +341,17 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
async def get_output_data(
prompt_id,
unique_id,
obj,
input_data_all,
execution_block_cb=None,
pre_execute_cb=None,
v3_data=None,
expected_outputs=None,
):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
@@ -522,9 +545,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
#will cause all sorts of incompatible memory shapes to fragment the pytorch alloc
#that we just want to cull out each model run.
allocator = comfy.memory_management.aimdo_allocator
expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
try:
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
finally:
if allocator is not None:
comfy.model_management.reset_cast_buffers()

View File

@@ -0,0 +1,269 @@
"""Unit tests for the expected_outputs feature.
This feature allows nodes to know at runtime which outputs are connected downstream,
enabling them to skip computing outputs that aren't needed.
"""
from comfy_api.latest import IO
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
from comfy_execution.utils import (
CurrentNodeContext,
ExecutionContext,
get_executing_context,
)
class TestGetExpectedOutputsForNode:
"""Tests for get_expected_outputs_for_node() function."""
def test_single_output_connected(self):
"""Test node with single output connected to one downstream node."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0})
def test_multiple_outputs_partial_connected(self):
"""Test node with multiple outputs, only some connected."""
prompt = {
"1": {"class_type": "MultiOutputNode", "inputs": {}},
"2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
# Output 1 is not connected
"3": {"class_type": "ConsumerC", "inputs": {"input": ["1", 2]}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0, 2})
assert 1 not in expected # Output 1 is definitely unused
def test_no_outputs_connected(self):
"""Test node with no outputs connected."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "OtherNode", "inputs": {}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset()
def test_same_output_connected_multiple_times(self):
"""Test same output connected to multiple downstream nodes."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
"3": {"class_type": "ConsumerB", "inputs": {"input": ["1", 0]}},
"4": {"class_type": "ConsumerC", "inputs": {"input": ["1", 0]}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0})
def test_node_not_in_prompt(self):
"""Test getting expected outputs for a node not in the prompt."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "999")
assert expected == frozenset()
def test_chained_nodes(self):
"""Test expected outputs in a chain of nodes."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "MiddleNode", "inputs": {"input": ["1", 0]}},
"3": {"class_type": "EndNode", "inputs": {"input": ["2", 0]}},
}
dynprompt = DynamicPrompt(prompt)
# Node 1's output 0 is connected to node 2
expected_1 = get_expected_outputs_for_node(dynprompt, "1")
assert expected_1 == frozenset({0})
# Node 2's output 0 is connected to node 3
expected_2 = get_expected_outputs_for_node(dynprompt, "2")
assert expected_2 == frozenset({0})
# Node 3 has no downstream connections
expected_3 = get_expected_outputs_for_node(dynprompt, "3")
assert expected_3 == frozenset()
def test_complex_graph(self):
"""Test expected outputs in a complex graph with multiple connections."""
prompt = {
"1": {"class_type": "MultiOutputNode", "inputs": {}},
"2": {"class_type": "ProcessorA", "inputs": {"image": ["1", 0], "mask": ["1", 1]}},
"3": {"class_type": "ProcessorB", "inputs": {"data": ["1", 2]}},
"4": {"class_type": "Combiner", "inputs": {"a": ["2", 0], "b": ["3", 0]}},
}
dynprompt = DynamicPrompt(prompt)
# Node 1 has outputs 0, 1, 2 all connected
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0, 1, 2})
def test_constant_inputs_ignored(self):
"""Test that constant (non-link) inputs don't affect expected outputs."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {
"class_type": "ConsumerNode",
"inputs": {
"image": ["1", 0],
"value": 42,
"name": "test",
},
},
}
dynprompt = DynamicPrompt(prompt)
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0})
class TestExecutionContext:
"""Tests for ExecutionContext with expected_outputs field."""
def test_context_with_expected_outputs(self):
"""Test creating ExecutionContext with expected_outputs."""
ctx = ExecutionContext(
prompt_id="prompt-123", node_id="node-456", list_index=0, expected_outputs=frozenset({0, 2})
)
assert ctx.prompt_id == "prompt-123"
assert ctx.node_id == "node-456"
assert ctx.list_index == 0
assert ctx.expected_outputs == frozenset({0, 2})
def test_context_without_expected_outputs(self):
"""Test ExecutionContext defaults to None for expected_outputs."""
ctx = ExecutionContext(prompt_id="prompt-123", node_id="node-456", list_index=0)
assert ctx.expected_outputs is None
def test_context_empty_expected_outputs(self):
"""Test ExecutionContext with empty expected_outputs set."""
ctx = ExecutionContext(
prompt_id="prompt-123", node_id="node-456", list_index=None, expected_outputs=frozenset()
)
assert ctx.expected_outputs == frozenset()
assert len(ctx.expected_outputs) == 0
class TestCurrentNodeContext:
"""Tests for CurrentNodeContext context manager with expected_outputs."""
def test_context_manager_with_expected_outputs(self):
"""Test CurrentNodeContext sets and resets context correctly."""
assert get_executing_context() is None
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 1})):
ctx = get_executing_context()
assert ctx is not None
assert ctx.prompt_id == "prompt-1"
assert ctx.node_id == "node-1"
assert ctx.list_index == 0
assert ctx.expected_outputs == frozenset({0, 1})
assert get_executing_context() is None
def test_context_manager_without_expected_outputs(self):
"""Test CurrentNodeContext works without expected_outputs (backwards compatible)."""
with CurrentNodeContext("prompt-1", "node-1"):
ctx = get_executing_context()
assert ctx is not None
assert ctx.expected_outputs is None
def test_nested_context_managers(self):
"""Test nested CurrentNodeContext managers."""
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0})):
ctx1 = get_executing_context()
assert ctx1.expected_outputs == frozenset({0})
with CurrentNodeContext("prompt-1", "node-2", 0, frozenset({1, 2})):
ctx2 = get_executing_context()
assert ctx2.expected_outputs == frozenset({1, 2})
assert ctx2.node_id == "node-2"
# After inner context exits, should be back to outer context
ctx1_again = get_executing_context()
assert ctx1_again.expected_outputs == frozenset({0})
assert ctx1_again.node_id == "node-1"
def test_output_check_pattern(self):
"""Test the typical pattern nodes will use to check expected outputs."""
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
ctx = get_executing_context()
# Typical usage pattern
if ctx and ctx.expected_outputs is not None:
should_compute_0 = 0 in ctx.expected_outputs
should_compute_1 = 1 in ctx.expected_outputs
should_compute_2 = 2 in ctx.expected_outputs
else:
# Fallback when info not available
should_compute_0 = should_compute_1 = should_compute_2 = True
assert should_compute_0 is True
assert should_compute_1 is False # Not in expected_outputs
assert should_compute_2 is True
class TestSchemaLazyOutputs:
"""Tests for lazy_outputs in V3 Schema."""
def test_schema_lazy_outputs_default(self):
"""Test that lazy_outputs defaults to False."""
schema = IO.Schema(
node_id="TestNode",
inputs=[],
outputs=[IO.Float.Output()],
)
assert schema.lazy_outputs is False
def test_schema_lazy_outputs_true(self):
"""Test setting lazy_outputs to True."""
schema = IO.Schema(
node_id="TestNode",
lazy_outputs=True,
inputs=[],
outputs=[IO.Float.Output()],
)
assert schema.lazy_outputs is True
def test_v3_node_lazy_outputs_property(self):
"""Test that LAZY_OUTPUTS property works on V3 nodes."""
class TestNodeWithLazyOutputs(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TestNodeWithLazyOutputs",
lazy_outputs=True,
inputs=[],
outputs=[IO.Float.Output()],
)
@classmethod
def execute(cls):
return IO.NodeOutput(1.0)
assert TestNodeWithLazyOutputs.LAZY_OUTPUTS is True
def test_v3_node_lazy_outputs_default(self):
"""Test that LAZY_OUTPUTS defaults to False on V3 nodes."""
class TestNodeWithoutLazyOutputs(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TestNodeWithoutLazyOutputs",
inputs=[],
outputs=[IO.Float.Output()],
)
@classmethod
def execute(cls):
return IO.NodeOutput(1.0)
assert TestNodeWithoutLazyOutputs.LAZY_OUTPUTS is False

View File

@@ -574,6 +574,104 @@ class TestExecution:
else:
assert result.did_run(test_node), "The execution should have been re-run"
def test_expected_outputs_all_connected(self, client: ComfyClient, builder: GraphBuilder):
"""Test that expected_outputs contains all connected outputs."""
g = builder
# Create a node with 3 outputs, all connected
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
# Connect all 3 outputs to preview nodes
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
output1 = g.node("PreviewImage", images=expected_outputs_node.out(1))
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
result = client.run(g)
# All outputs should be white (255) since all are connected
images0 = result.get_images(output0)
images1 = result.get_images(output1)
images2 = result.get_images(output2)
assert len(images0) == 1, "Should have 1 image for output0"
assert len(images1) == 1, "Should have 1 image for output1"
assert len(images2) == 1, "Should have 1 image for output2"
# White pixels = 255, meaning output was in expected_outputs
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)"
assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)"
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)"
def test_expected_outputs_partial_connected(self, client: ComfyClient, builder: GraphBuilder):
"""Test that expected_outputs only contains connected outputs."""
g = builder
# Create a node with 3 outputs, only some connected
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
# Only connect outputs 0 and 2, leave output 1 disconnected
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
# output1 is intentionally not connected
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
result = client.run(g)
# Connected outputs should be white (255)
images0 = result.get_images(output0)
images2 = result.get_images(output2)
assert len(images0) == 1, "Should have 1 image for output0"
assert len(images2) == 1, "Should have 1 image for output2"
# White = expected. Output 1 is not connected, so we can't verify it directly, but outputs 0 and 2 should be white
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)"
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)"
def test_expected_outputs_single_connected(self, client: ComfyClient, builder: GraphBuilder):
"""Test that expected_outputs works with single connected output."""
g = builder
# Create a node with 3 outputs, only one connected
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
# Only connect output 1
output1 = g.node("PreviewImage", images=expected_outputs_node.out(1))
result = client.run(g)
images1 = result.get_images(output1)
assert len(images1) == 1, "Should have 1 image for output1"
# Output 1 should be white (connected), others are not visible in this test
assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)"
def test_expected_outputs_cache_invalidation(self, client: ComfyClient, builder: GraphBuilder, server):
"""Test that cache invalidates when output connections change."""
g = builder
# Use unique dimensions to avoid cache collision with other expected_outputs tests
expected_outputs_node = g.node("TestExpectedOutputs", height=32, width=32)
# First run: only connect output 0
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
result1 = client.run(g)
assert result1.did_run(expected_outputs_node), "First run should execute the node"
# Second run: same connections, should be cached
result2 = client.run(g)
if server["should_cache_results"]:
assert not result2.did_run(expected_outputs_node), "Second run should be cached"
# Third run: add connection to output 2
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
result3 = client.run(g)
# Because LAZY_OUTPUTS=True, changing connections should invalidate cache
if server["should_cache_results"]:
assert result3.did_run(expected_outputs_node), "Adding output connection should invalidate cache"
# Verify both outputs are now white
images0 = result3.get_images(output0)
images2 = result3.get_images(output2)
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white"
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white"
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
# Warmup execution to ensure server is fully initialized

View File

@@ -6,6 +6,7 @@ from .tools import VariantSupport
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO
from comfy_execution.utils import get_executing_context
class TestLazyMixImages:
@classmethod
@@ -482,6 +483,57 @@ class TestOutputNodeWithSocketOutput:
result = image * value
return (result,)
class TestExpectedOutputs:
"""Test node for the expected_outputs feature.
This node has 3 IMAGE outputs that encode which outputs were expected:
- White image (255) if the output was in expected_outputs
- Black image (0) if the output was NOT in expected_outputs
This allows integration tests to verify which outputs were expected by checking pixel values.
"""
LAZY_OUTPUTS = True # Opt into cache invalidation on output connection changes
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"height": ("INT", {"default": 64, "min": 1, "max": 1024}),
"width": ("INT", {"default": 64, "min": 1, "max": 1024}),
},
}
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
RETURN_NAMES = ("output0", "output1", "output2")
FUNCTION = "execute"
CATEGORY = "_for_testing"
def execute(self, height, width):
ctx = get_executing_context()
# Default: assume all outputs are expected (backwards compatibility)
output0_expected = True
output1_expected = True
output2_expected = True
if ctx is not None and ctx.expected_outputs is not None:
output0_expected = 0 in ctx.expected_outputs
output1_expected = 1 in ctx.expected_outputs
output2_expected = 2 in ctx.expected_outputs
# Return white image if expected, black if not
# This allows tests to verify which outputs were expected via pixel values
white = torch.ones(1, height, width, 3)
black = torch.zeros(1, height, width, 3)
return (
white if output0_expected else black,
white if output1_expected else black,
white if output2_expected else black,
)
TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage,
@@ -498,6 +550,7 @@ TEST_NODE_CLASS_MAPPINGS = {
"TestSleep": TestSleep,
"TestParallelSleep": TestParallelSleep,
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
"TestExpectedOutputs": TestExpectedOutputs,
}
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@@ -516,4 +569,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestSleep": "Test Sleep",
"TestParallelSleep": "Test Parallel Sleep",
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
"TestExpectedOutputs": "Test Expected Outputs",
}