From d987b0d32d2f0e7cc28af2b346effb5e6472eb2d Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Wed, 4 Feb 2026 14:26:16 +0200 Subject: [PATCH] feat: add expected_outputs feature for lazy output computation --- comfy_api/latest/_io.py | 15 + comfy_execution/caching.py | 6 +- comfy_execution/graph.py | 21 ++ comfy_execution/utils.py | 30 +- execution.py | 40 ++- .../execution_test/expected_outputs_test.py | 269 ++++++++++++++++++ tests/execution/test_execution.py | 98 +++++++ .../testing-pack/specific_tests.py | 54 ++++ 8 files changed, 515 insertions(+), 18 deletions(-) create mode 100644 tests-unit/execution_test/expected_outputs_test.py diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 93cf482ca..a08d91e83 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1430,6 +1430,11 @@ class Schema: """Flags a node as expandable, allowing NodeOutput to include 'expand' property.""" accept_all_inputs: bool=False """When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema.""" + lazy_outputs: bool=False + """When True, cache will invalidate when output connections change, and expected_outputs will be available. + + Use this for nodes that can skip computing outputs that aren't connected downstream. + Access via `get_executing_context().expected_outputs` - outputs NOT in the set are definitely unused.""" def validate(self): '''Validate the schema: @@ -1875,6 +1880,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls.GET_SCHEMA() return cls._ACCEPT_ALL_INPUTS + _LAZY_OUTPUTS = None + @final + @classproperty + def LAZY_OUTPUTS(cls): # noqa + if cls._LAZY_OUTPUTS is None: + cls.GET_SCHEMA() + return cls._LAZY_OUTPUTS + @final @classmethod def INPUT_TYPES(cls) -> dict[str, dict]: @@ -1917,6 +1930,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls._NOT_IDEMPOTENT = schema.not_idempotent if cls._ACCEPT_ALL_INPUTS is None: cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs + if cls._LAZY_OUTPUTS is None: + cls._LAZY_OUTPUTS = schema.lazy_outputs if cls._RETURN_TYPES is None: output = [] diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 326a279fc..a2b66c95d 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -5,7 +5,7 @@ import psutil import time import torch from typing import Sequence, Mapping, Dict -from comfy_execution.graph import DynamicPrompt +from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node from abc import ABC, abstractmethod import nodes @@ -115,6 +115,10 @@ class CacheKeySetInputSignature(CacheKeySet): signature = [class_type, await self.is_changed_cache.get(node_id)] if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): signature.append(node_id) + # Include expected_outputs in cache key for nodes that opt in via LAZY_OUTPUTS + if hasattr(class_def, 'LAZY_OUTPUTS') and class_def.LAZY_OUTPUTS: + expected = get_expected_outputs_for_node(dynprompt, node_id) + signature.append(("expected_outputs", tuple(sorted(expected)))) inputs = node["inputs"] for key in sorted(inputs.keys()): if is_link(inputs[key]): diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 9d170b16e..689fd23fc 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -19,6 +19,27 @@ class NodeInputError(Exception): class NodeNotFoundError(Exception): pass + +def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset: + """Get the set of output indices that are connected downstream. + Returns outputs that MIGHT be used. + Outputs NOT in this set are DEFINITELY not used and safe to skip. + """ + expected = set() + for other_node_id in dynprompt.all_node_ids(): + try: + node_data = dynprompt.get_node(other_node_id) + except NodeNotFoundError: + continue + inputs = node_data.get("inputs", {}) + for input_name, value in inputs.items(): + if is_link(value): + from_node_id, from_socket = value + if from_node_id == node_id: + expected.add(from_socket) + return frozenset(expected) + + class DynamicPrompt: def __init__(self, original_prompt): # The original prompt provided by the user diff --git a/comfy_execution/utils.py b/comfy_execution/utils.py index 62d32f101..3bca3e995 100644 --- a/comfy_execution/utils.py +++ b/comfy_execution/utils.py @@ -1,21 +1,26 @@ import contextvars -from typing import Optional, NamedTuple +from typing import NamedTuple, FrozenSet class ExecutionContext(NamedTuple): """ Context information about the currently executing node. Attributes: + prompt_id: The ID of the current prompt execution node_id: The ID of the currently executing node list_index: The index in a list being processed (for operations on batches/lists) + expected_outputs: Set of output indices that might be used downstream. + Outputs NOT in this set are definitely unused (safe to skip). + None means the information is not available. """ prompt_id: str node_id: str - list_index: Optional[int] + list_index: int | None + expected_outputs: FrozenSet[int] | None = None -current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None) +current_executing_context: contextvars.ContextVar[ExecutionContext | None] = contextvars.ContextVar("current_executing_context", default=None) -def get_executing_context() -> Optional[ExecutionContext]: +def get_executing_context() -> ExecutionContext | None: return current_executing_context.get(None) class CurrentNodeContext: @@ -25,15 +30,22 @@ class CurrentNodeContext: Sets the current_executing_context on enter and resets it on exit. Example: - with CurrentNodeContext(node_id="123", list_index=0): + with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0): # Code that should run with the current node context set process_image() """ - def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None): + def __init__( + self, + prompt_id: str, + node_id: str, + list_index: int | None = None, + expected_outputs: FrozenSet[int] | None = None, + ): self.context = ExecutionContext( - prompt_id= prompt_id, - node_id= node_id, - list_index= list_index + prompt_id=prompt_id, + node_id=node_id, + list_index=list_index, + expected_outputs=expected_outputs, ) self.token = None diff --git a/execution.py b/execution.py index 3dbab82e6..06cb17144 100644 --- a/execution.py +++ b/execution.py @@ -31,6 +31,7 @@ from comfy_execution.graph import ( ExecutionBlocker, ExecutionList, get_input_info, + get_expected_outputs_for_node, ) from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.validation import validate_node_input @@ -227,7 +228,18 @@ async def resolve_map_node_over_list_results(results): raise exc return [x.result() if isinstance(x, asyncio.Task) else x for x in results] -async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None): +async def _async_map_node_over_list( + prompt_id, + unique_id, + obj, + input_data_all, + func, + allow_interrupt=False, + execution_block_cb=None, + pre_execute_cb=None, + v3_data=None, + expected_outputs=None, +): # check if node wants the lists input_is_list = getattr(obj, "INPUT_IS_LIST", False) @@ -277,10 +289,12 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f else: f = getattr(obj, func) if inspect.iscoroutinefunction(f): - async def async_wrapper(f, prompt_id, unique_id, list_index, args): - with CurrentNodeContext(prompt_id, unique_id, list_index): + async def async_wrapper(f, prompt_id, unique_id, list_index, args, expected_outputs): + with CurrentNodeContext(prompt_id, unique_id, list_index, expected_outputs): return await f(**args) - task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs)) + task = asyncio.create_task( + async_wrapper(f, prompt_id, unique_id, index, args=inputs, expected_outputs=expected_outputs) + ) # Give the task a chance to execute without yielding await asyncio.sleep(0) if task.done(): @@ -289,7 +303,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f else: results.append(task) else: - with CurrentNodeContext(prompt_id, unique_id, index): + with CurrentNodeContext(prompt_id, unique_id, index, expected_outputs): result = f(**inputs) results.append(result) else: @@ -327,8 +341,17 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None): - return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) +async def get_output_data( + prompt_id, + unique_id, + obj, + input_data_all, + execution_block_cb=None, + pre_execute_cb=None, + v3_data=None, + expected_outputs=None, +): + return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs) has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) if has_pending_task: return return_values, {}, False, has_pending_task @@ -522,9 +545,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, #will cause all sorts of incompatible memory shapes to fragment the pytorch alloc #that we just want to cull out each model run. allocator = comfy.memory_management.aimdo_allocator + expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id) with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())): try: - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs) finally: if allocator is not None: comfy.model_management.reset_cast_buffers() diff --git a/tests-unit/execution_test/expected_outputs_test.py b/tests-unit/execution_test/expected_outputs_test.py new file mode 100644 index 000000000..e3b610674 --- /dev/null +++ b/tests-unit/execution_test/expected_outputs_test.py @@ -0,0 +1,269 @@ +"""Unit tests for the expected_outputs feature. + +This feature allows nodes to know at runtime which outputs are connected downstream, +enabling them to skip computing outputs that aren't needed. +""" + +from comfy_api.latest import IO +from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node +from comfy_execution.utils import ( + CurrentNodeContext, + ExecutionContext, + get_executing_context, +) + + +class TestGetExpectedOutputsForNode: + """Tests for get_expected_outputs_for_node() function.""" + + def test_single_output_connected(self): + """Test node with single output connected to one downstream node.""" + prompt = { + "1": {"class_type": "SourceNode", "inputs": {}}, + "2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}}, + } + dynprompt = DynamicPrompt(prompt) + expected = get_expected_outputs_for_node(dynprompt, "1") + assert expected == frozenset({0}) + + def test_multiple_outputs_partial_connected(self): + """Test node with multiple outputs, only some connected.""" + prompt = { + "1": {"class_type": "MultiOutputNode", "inputs": {}}, + "2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}}, + # Output 1 is not connected + "3": {"class_type": "ConsumerC", "inputs": {"input": ["1", 2]}}, + } + dynprompt = DynamicPrompt(prompt) + expected = get_expected_outputs_for_node(dynprompt, "1") + assert expected == frozenset({0, 2}) + assert 1 not in expected # Output 1 is definitely unused + + def test_no_outputs_connected(self): + """Test node with no outputs connected.""" + prompt = { + "1": {"class_type": "SourceNode", "inputs": {}}, + "2": {"class_type": "OtherNode", "inputs": {}}, + } + dynprompt = DynamicPrompt(prompt) + expected = get_expected_outputs_for_node(dynprompt, "1") + assert expected == frozenset() + + def test_same_output_connected_multiple_times(self): + """Test same output connected to multiple downstream nodes.""" + prompt = { + "1": {"class_type": "SourceNode", "inputs": {}}, + "2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}}, + "3": {"class_type": "ConsumerB", "inputs": {"input": ["1", 0]}}, + "4": {"class_type": "ConsumerC", "inputs": {"input": ["1", 0]}}, + } + dynprompt = DynamicPrompt(prompt) + expected = get_expected_outputs_for_node(dynprompt, "1") + assert expected == frozenset({0}) + + def test_node_not_in_prompt(self): + """Test getting expected outputs for a node not in the prompt.""" + prompt = { + "1": {"class_type": "SourceNode", "inputs": {}}, + } + dynprompt = DynamicPrompt(prompt) + expected = get_expected_outputs_for_node(dynprompt, "999") + assert expected == frozenset() + + def test_chained_nodes(self): + """Test expected outputs in a chain of nodes.""" + prompt = { + "1": {"class_type": "SourceNode", "inputs": {}}, + "2": {"class_type": "MiddleNode", "inputs": {"input": ["1", 0]}}, + "3": {"class_type": "EndNode", "inputs": {"input": ["2", 0]}}, + } + dynprompt = DynamicPrompt(prompt) + + # Node 1's output 0 is connected to node 2 + expected_1 = get_expected_outputs_for_node(dynprompt, "1") + assert expected_1 == frozenset({0}) + + # Node 2's output 0 is connected to node 3 + expected_2 = get_expected_outputs_for_node(dynprompt, "2") + assert expected_2 == frozenset({0}) + + # Node 3 has no downstream connections + expected_3 = get_expected_outputs_for_node(dynprompt, "3") + assert expected_3 == frozenset() + + def test_complex_graph(self): + """Test expected outputs in a complex graph with multiple connections.""" + prompt = { + "1": {"class_type": "MultiOutputNode", "inputs": {}}, + "2": {"class_type": "ProcessorA", "inputs": {"image": ["1", 0], "mask": ["1", 1]}}, + "3": {"class_type": "ProcessorB", "inputs": {"data": ["1", 2]}}, + "4": {"class_type": "Combiner", "inputs": {"a": ["2", 0], "b": ["3", 0]}}, + } + dynprompt = DynamicPrompt(prompt) + + # Node 1 has outputs 0, 1, 2 all connected + expected = get_expected_outputs_for_node(dynprompt, "1") + assert expected == frozenset({0, 1, 2}) + + def test_constant_inputs_ignored(self): + """Test that constant (non-link) inputs don't affect expected outputs.""" + prompt = { + "1": {"class_type": "SourceNode", "inputs": {}}, + "2": { + "class_type": "ConsumerNode", + "inputs": { + "image": ["1", 0], + "value": 42, + "name": "test", + }, + }, + } + dynprompt = DynamicPrompt(prompt) + expected = get_expected_outputs_for_node(dynprompt, "1") + assert expected == frozenset({0}) + + +class TestExecutionContext: + """Tests for ExecutionContext with expected_outputs field.""" + + def test_context_with_expected_outputs(self): + """Test creating ExecutionContext with expected_outputs.""" + ctx = ExecutionContext( + prompt_id="prompt-123", node_id="node-456", list_index=0, expected_outputs=frozenset({0, 2}) + ) + assert ctx.prompt_id == "prompt-123" + assert ctx.node_id == "node-456" + assert ctx.list_index == 0 + assert ctx.expected_outputs == frozenset({0, 2}) + + def test_context_without_expected_outputs(self): + """Test ExecutionContext defaults to None for expected_outputs.""" + ctx = ExecutionContext(prompt_id="prompt-123", node_id="node-456", list_index=0) + assert ctx.expected_outputs is None + + def test_context_empty_expected_outputs(self): + """Test ExecutionContext with empty expected_outputs set.""" + ctx = ExecutionContext( + prompt_id="prompt-123", node_id="node-456", list_index=None, expected_outputs=frozenset() + ) + assert ctx.expected_outputs == frozenset() + assert len(ctx.expected_outputs) == 0 + + +class TestCurrentNodeContext: + """Tests for CurrentNodeContext context manager with expected_outputs.""" + + def test_context_manager_with_expected_outputs(self): + """Test CurrentNodeContext sets and resets context correctly.""" + assert get_executing_context() is None + + with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 1})): + ctx = get_executing_context() + assert ctx is not None + assert ctx.prompt_id == "prompt-1" + assert ctx.node_id == "node-1" + assert ctx.list_index == 0 + assert ctx.expected_outputs == frozenset({0, 1}) + + assert get_executing_context() is None + + def test_context_manager_without_expected_outputs(self): + """Test CurrentNodeContext works without expected_outputs (backwards compatible).""" + with CurrentNodeContext("prompt-1", "node-1"): + ctx = get_executing_context() + assert ctx is not None + assert ctx.expected_outputs is None + + def test_nested_context_managers(self): + """Test nested CurrentNodeContext managers.""" + with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0})): + ctx1 = get_executing_context() + assert ctx1.expected_outputs == frozenset({0}) + + with CurrentNodeContext("prompt-1", "node-2", 0, frozenset({1, 2})): + ctx2 = get_executing_context() + assert ctx2.expected_outputs == frozenset({1, 2}) + assert ctx2.node_id == "node-2" + + # After inner context exits, should be back to outer context + ctx1_again = get_executing_context() + assert ctx1_again.expected_outputs == frozenset({0}) + assert ctx1_again.node_id == "node-1" + + def test_output_check_pattern(self): + """Test the typical pattern nodes will use to check expected outputs.""" + with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})): + ctx = get_executing_context() + + # Typical usage pattern + if ctx and ctx.expected_outputs is not None: + should_compute_0 = 0 in ctx.expected_outputs + should_compute_1 = 1 in ctx.expected_outputs + should_compute_2 = 2 in ctx.expected_outputs + else: + # Fallback when info not available + should_compute_0 = should_compute_1 = should_compute_2 = True + + assert should_compute_0 is True + assert should_compute_1 is False # Not in expected_outputs + assert should_compute_2 is True + + +class TestSchemaLazyOutputs: + """Tests for lazy_outputs in V3 Schema.""" + + def test_schema_lazy_outputs_default(self): + """Test that lazy_outputs defaults to False.""" + schema = IO.Schema( + node_id="TestNode", + inputs=[], + outputs=[IO.Float.Output()], + ) + assert schema.lazy_outputs is False + + def test_schema_lazy_outputs_true(self): + """Test setting lazy_outputs to True.""" + schema = IO.Schema( + node_id="TestNode", + lazy_outputs=True, + inputs=[], + outputs=[IO.Float.Output()], + ) + assert schema.lazy_outputs is True + + def test_v3_node_lazy_outputs_property(self): + """Test that LAZY_OUTPUTS property works on V3 nodes.""" + + class TestNodeWithLazyOutputs(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TestNodeWithLazyOutputs", + lazy_outputs=True, + inputs=[], + outputs=[IO.Float.Output()], + ) + + @classmethod + def execute(cls): + return IO.NodeOutput(1.0) + + assert TestNodeWithLazyOutputs.LAZY_OUTPUTS is True + + def test_v3_node_lazy_outputs_default(self): + """Test that LAZY_OUTPUTS defaults to False on V3 nodes.""" + + class TestNodeWithoutLazyOutputs(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TestNodeWithoutLazyOutputs", + inputs=[], + outputs=[IO.Float.Output()], + ) + + @classmethod + def execute(cls): + return IO.NodeOutput(1.0) + + assert TestNodeWithoutLazyOutputs.LAZY_OUTPUTS is False diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index f73ca7e3c..3b5e1df20 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -574,6 +574,104 @@ class TestExecution: else: assert result.did_run(test_node), "The execution should have been re-run" + def test_expected_outputs_all_connected(self, client: ComfyClient, builder: GraphBuilder): + """Test that expected_outputs contains all connected outputs.""" + g = builder + # Create a node with 3 outputs, all connected + expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64) + + # Connect all 3 outputs to preview nodes + output0 = g.node("PreviewImage", images=expected_outputs_node.out(0)) + output1 = g.node("PreviewImage", images=expected_outputs_node.out(1)) + output2 = g.node("PreviewImage", images=expected_outputs_node.out(2)) + + result = client.run(g) + + # All outputs should be white (255) since all are connected + images0 = result.get_images(output0) + images1 = result.get_images(output1) + images2 = result.get_images(output2) + + assert len(images0) == 1, "Should have 1 image for output0" + assert len(images1) == 1, "Should have 1 image for output1" + assert len(images2) == 1, "Should have 1 image for output2" + + # White pixels = 255, meaning output was in expected_outputs + assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)" + assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)" + assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)" + + def test_expected_outputs_partial_connected(self, client: ComfyClient, builder: GraphBuilder): + """Test that expected_outputs only contains connected outputs.""" + g = builder + # Create a node with 3 outputs, only some connected + expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64) + + # Only connect outputs 0 and 2, leave output 1 disconnected + output0 = g.node("PreviewImage", images=expected_outputs_node.out(0)) + # output1 is intentionally not connected + output2 = g.node("PreviewImage", images=expected_outputs_node.out(2)) + + result = client.run(g) + + # Connected outputs should be white (255) + images0 = result.get_images(output0) + images2 = result.get_images(output2) + + assert len(images0) == 1, "Should have 1 image for output0" + assert len(images2) == 1, "Should have 1 image for output2" + + # White = expected, output 1 is not connected so we can't verify it directly but outputs 0 and 2 should be white + assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)" + assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)" + + def test_expected_outputs_single_connected(self, client: ComfyClient, builder: GraphBuilder): + """Test that expected_outputs works with single connected output.""" + g = builder + # Create a node with 3 outputs, only one connected + expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64) + + # Only connect output 1 + output1 = g.node("PreviewImage", images=expected_outputs_node.out(1)) + + result = client.run(g) + + images1 = result.get_images(output1) + assert len(images1) == 1, "Should have 1 image for output1" + + # Output 1 should be white (connected), others are not visible in this test + assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)" + + def test_expected_outputs_cache_invalidation(self, client: ComfyClient, builder: GraphBuilder, server): + """Test that cache invalidates when output connections change.""" + g = builder + # Use unique dimensions to avoid cache collision with other expected_outputs tests + expected_outputs_node = g.node("TestExpectedOutputs", height=32, width=32) + + # First run: only connect output 0 + output0 = g.node("PreviewImage", images=expected_outputs_node.out(0)) + + result1 = client.run(g) + assert result1.did_run(expected_outputs_node), "First run should execute the node" + + # Second run: same connections, should be cached + result2 = client.run(g) + if server["should_cache_results"]: + assert not result2.did_run(expected_outputs_node), "Second run should be cached" + + # Third run: add connection to output 2 + output2 = g.node("PreviewImage", images=expected_outputs_node.out(2)) + + result3 = client.run(g) + # Because LAZY_OUTPUTS=True, changing connections should invalidate cache + if server["should_cache_results"]: + assert result3.did_run(expected_outputs_node), "Adding output connection should invalidate cache" + + # Verify both outputs are now white + images0 = result3.get_images(output0) + images2 = result3.get_images(output2) + assert numpy.array(images0[0]).min() == 255, "Output 0 should be white" + assert numpy.array(images2[0]).min() == 255, "Output 2 should be white" def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): # Warmup execution to ensure server is fully initialized diff --git a/tests/execution/testing_nodes/testing-pack/specific_tests.py b/tests/execution/testing_nodes/testing-pack/specific_tests.py index 4f8f01ae4..1910668e9 100644 --- a/tests/execution/testing_nodes/testing-pack/specific_tests.py +++ b/tests/execution/testing_nodes/testing-pack/specific_tests.py @@ -6,6 +6,7 @@ from .tools import VariantSupport from comfy_execution.graph_utils import GraphBuilder from comfy.comfy_types.node_typing import ComfyNodeABC from comfy.comfy_types import IO +from comfy_execution.utils import get_executing_context class TestLazyMixImages: @classmethod @@ -482,6 +483,57 @@ class TestOutputNodeWithSocketOutput: result = image * value return (result,) + +class TestExpectedOutputs: + """Test node for the expected_outputs feature. + + This node has 3 IMAGE outputs that encode which outputs were expected: + - White image (255) if the output was in expected_outputs + - Black image (0) if the output was NOT in expected_outputs + + This allows integration tests to verify which outputs were expected by checking pixel values. + """ + LAZY_OUTPUTS = True # Opt into cache invalidation on output connection changes + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "height": ("INT", {"default": 64, "min": 1, "max": 1024}), + "width": ("INT", {"default": 64, "min": 1, "max": 1024}), + }, + } + + RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE") + RETURN_NAMES = ("output0", "output1", "output2") + FUNCTION = "execute" + CATEGORY = "_for_testing" + + def execute(self, height, width): + ctx = get_executing_context() + + # Default: assume all outputs are expected (backwards compatibility) + output0_expected = True + output1_expected = True + output2_expected = True + + if ctx is not None and ctx.expected_outputs is not None: + output0_expected = 0 in ctx.expected_outputs + output1_expected = 1 in ctx.expected_outputs + output2_expected = 2 in ctx.expected_outputs + + # Return white image if expected, black if not + # This allows tests to verify which outputs were expected via pixel values + white = torch.ones(1, height, width, 3) + black = torch.zeros(1, height, width, 3) + + return ( + white if output0_expected else black, + white if output1_expected else black, + white if output2_expected else black, + ) + + TEST_NODE_CLASS_MAPPINGS = { "TestLazyMixImages": TestLazyMixImages, "TestVariadicAverage": TestVariadicAverage, @@ -498,6 +550,7 @@ TEST_NODE_CLASS_MAPPINGS = { "TestSleep": TestSleep, "TestParallelSleep": TestParallelSleep, "TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput, + "TestExpectedOutputs": TestExpectedOutputs, } TEST_NODE_DISPLAY_NAME_MAPPINGS = { @@ -516,4 +569,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestSleep": "Test Sleep", "TestParallelSleep": "Test Parallel Sleep", "TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output", + "TestExpectedOutputs": "Test Expected Outputs", }