diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 5301b9529..dd83a91b3 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1655,7 +1655,11 @@ class Schema: """When True, cache will invalidate when output connections change, and expected_outputs will be available. Use this for nodes that can skip computing outputs that aren't connected downstream. - Access via `get_executing_context().expected_outputs` - outputs NOT in the set are definitely unused.""" + Check `comfy_execution.utils.is_output_needed(i)` inside execute() - False means output i is definitely unused + and safe to skip. Only nodes with this flag receive expected_outputs; all others see None. + + Limitation: consumers must exist before this node runs - a subgraph expansion that + hand-builds a link to a pre-existing node's already-skipped output reads a stale value.""" def validate(self): '''Validate the schema: diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index ffd101871..efb5f1abf 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -22,7 +22,10 @@ class NodeNotFoundError(Exception): def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset: """Get the set of output indices that are connected downstream. Returns outputs that MIGHT be used. - Outputs NOT in this set are DEFINITELY not used and safe to skip. + Outputs NOT in this set are DEFINITELY not used and safe to skip + (see Schema.lazy_outputs for the one expansion-related limitation). + + Includes input links and consumers registered via add_output_consumer. """ return dynprompt.get_expected_outputs_map().get(node_id, frozenset()) @@ -35,6 +38,8 @@ class DynamicPrompt: self.ephemeral_prompt = {} self.ephemeral_parents = {} self.ephemeral_display = {} + # Output sockets consumed outside of input links (subgraph expansions) + self._external_output_consumers = {} self._expected_outputs_map = None def get_node(self, node_id): @@ -69,19 +74,22 @@ class DynamicPrompt: def all_node_ids(self): return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys())) + def add_output_consumer(self, node_id, socket): + """Record an output socket consumed outside of input links, e.g. a subgraph + expansion mapping its parent's output to this node's output.""" + self._external_output_consumers.setdefault(node_id, set()).add(socket) + self._expected_outputs_map = None + def _build_expected_outputs_map(self): result = {} for node_id in self.all_node_ids(): - try: - node_data = self.get_node(node_id) - except NodeNotFoundError: - continue + node_data = self.get_node(node_id) for value in node_data.get("inputs", {}).values(): if is_link(value): from_node_id, from_socket = value - if from_node_id not in result: - result[from_node_id] = set() - result[from_node_id].add(from_socket) + result.setdefault(from_node_id, set()).add(from_socket) + for node_id, sockets in self._external_output_consumers.items(): + result.setdefault(node_id, set()).update(sockets) self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()} def get_expected_outputs_map(self): diff --git a/comfy_execution/utils.py b/comfy_execution/utils.py index 60e707881..e3a1d71c2 100644 --- a/comfy_execution/utils.py +++ b/comfy_execution/utils.py @@ -29,6 +29,10 @@ def is_output_needed(output_index: int) -> bool: Returns True if the output might be used (should be computed). Returns False if the output is definitely not connected (safe to skip). + + Only meaningful for LAZY_OUTPUTS nodes; for all others expected_outputs is + None and this always returns True (skipping without the flag would not be + reflected in the cache key). """ ctx = get_executing_context() if ctx is None or ctx.expected_outputs is None: diff --git a/execution.py b/execution.py index 9ea2eac88..f4c3c56a0 100644 --- a/execution.py +++ b/execution.py @@ -561,7 +561,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id) + if getattr(class_def, "LAZY_OUTPUTS", False): + expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id) + else: + expected_outputs = None try: output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs) finally: @@ -620,6 +623,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if is_link(node_outputs[i]): from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1] new_output_links.append((from_node_id, from_socket)) + dynprompt.add_output_consumer(from_node_id, from_socket) cached_outputs.append((True, node_outputs)) new_node_ids = set(new_node_ids) for cache in caches.all: diff --git a/tests-unit/execution_test/expected_outputs_test.py b/tests-unit/execution_test/expected_outputs_test.py index b40786085..95d5538d5 100644 --- a/tests-unit/execution_test/expected_outputs_test.py +++ b/tests-unit/execution_test/expected_outputs_test.py @@ -148,6 +148,45 @@ class TestGetExpectedOutputsForNode: assert expected == frozenset({0, 1}) +class TestExternalOutputConsumers: + """Tests for DynamicPrompt.add_output_consumer() — out-of-band consumers + (subgraph expansion output mappings) that have no input link in the prompt.""" + + def test_external_consumer_only(self): + """A socket consumed only externally must appear in expected outputs.""" + prompt = { + "1": {"class_type": "SourceNode", "inputs": {}}, + } + dynprompt = DynamicPrompt(prompt) + assert get_expected_outputs_for_node(dynprompt, "1") == frozenset() + + dynprompt.add_output_consumer("1", 1) + assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({1}) + + def test_external_consumer_merges_with_links(self): + """External consumers merge with input-link consumers.""" + prompt = { + "1": {"class_type": "SourceNode", "inputs": {}}, + "2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}}, + } + dynprompt = DynamicPrompt(prompt) + dynprompt.add_output_consumer("1", 2) + assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0, 2}) + + def test_external_consumer_invalidates_cached_map(self): + """Registering after the map was built must invalidate the cache.""" + prompt = { + "1": {"class_type": "SourceNode", "inputs": {}}, + "2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}}, + } + dynprompt = DynamicPrompt(prompt) + # Build (and cache) the map first + assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0}) + + dynprompt.add_output_consumer("1", 1) + assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0, 1}) + + class TestExecutionContext: """Tests for ExecutionContext with expected_outputs field.""" diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index 6474326c6..86f00dbee 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -672,6 +672,46 @@ class TestExecution: assert numpy.array(images0[0]).min() == 255, "Output 0 should be white" assert numpy.array(images2[0]).min() == 255, "Output 2 should be white" + def test_expected_outputs_expansion_output_mapping(self, client: ComfyClient, builder: GraphBuilder): + """A socket consumed only via an expansion's parent-output mapping must still + be in the inner LAZY_OUTPUTS node's expected_outputs (white, not black).""" + g = builder + expander = g.node("TestExpectedOutputsExpansion", height=80, width=80) + output = g.node("PreviewImage", images=expander.out(0)) + + result = client.run(g) + + images = result.get_images(output) + assert len(images) == 1, "Should have 1 image" + assert numpy.array(images[0]).min() == 255, ( + "Inner node skipped an output that is consumed via the expansion's " + "parent-output mapping (expected white, got black)" + ) + + def test_expected_outputs_requires_opt_in(self, client: ComfyClient, builder: GraphBuilder, server): + """Nodes without LAZY_OUTPUTS must see expected_outputs=None: their cache key + ignores topology, so a skipped output would be served stale after rewiring.""" + g = builder + node = g.node("TestExpectedOutputsNotOptedIn", height=96, width=96) + output0 = g.node("PreviewImage", images=node.out(0)) + + # Only output 0 connected: correct gating -> node sees None, computes all + result1 = client.run(g) + assert numpy.array(result1.get_images(output0)[0]).min() == 255 + + # Connect output 1: key unchanged -> cache hit must still serve correct data + output1 = g.node("PreviewImage", images=node.out(1)) + result2 = client.run(g) + + if server["should_cache_results"]: + assert not result2.did_run(node), "Node should be a cache hit (key ignores topology)" + images1 = result2.get_images(output1) + assert len(images1) == 1, "Should have 1 image for output1" + assert numpy.array(images1[0]).min() == 255, ( + "Non-opted-in node observed expected_outputs and skipped output 1; " + "the stale skipped value was then served from cache" + ) + def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): # Warmup execution to ensure server is fully initialized run_warmup(client) diff --git a/tests/execution/testing_nodes/testing-pack/specific_tests.py b/tests/execution/testing_nodes/testing-pack/specific_tests.py index c9aa0bb1e..07ce1147f 100644 --- a/tests/execution/testing_nodes/testing-pack/specific_tests.py +++ b/tests/execution/testing_nodes/testing-pack/specific_tests.py @@ -6,7 +6,7 @@ from .tools import VariantSupport from comfy_execution.graph_utils import GraphBuilder from comfy.comfy_types.node_typing import ComfyNodeABC from comfy.comfy_types import IO -from comfy_execution.utils import get_executing_context +from comfy_execution.utils import get_executing_context, is_output_needed class TestLazyMixImages: @classmethod @@ -510,27 +510,76 @@ class TestExpectedOutputs: CATEGORY = "_for_testing" def execute(self, height, width): - ctx = get_executing_context() - - # Default: assume all outputs are expected (backwards compatibility) - output0_expected = True - output1_expected = True - output2_expected = True - - if ctx is not None and ctx.expected_outputs is not None: - output0_expected = 0 in ctx.expected_outputs - output1_expected = 1 in ctx.expected_outputs - output2_expected = 2 in ctx.expected_outputs - # Return white image if expected, black if not # This allows tests to verify which outputs were expected via pixel values white = torch.ones(1, height, width, 3) black = torch.zeros(1, height, width, 3) return ( - white if output0_expected else black, - white if output1_expected else black, - white if output2_expected else black, + white if is_output_needed(0) else black, + white if is_output_needed(1) else black, + white if is_output_needed(2) else black, + ) + + +class TestExpectedOutputsExpansion: + """Expands into an inner LAZY_OUTPUTS node whose output 1 is consumed ONLY via + the parent-output mapping (no input link anywhere). If that mapping is not part + of the expected-outputs map, the inner node wrongly skips it -> black not white. + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "height": ("INT", {"default": 64, "min": 1, "max": 1024}), + "width": ("INT", {"default": 64, "min": 1, "max": 1024}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "execute" + CATEGORY = "_for_testing" + + def execute(self, height, width): + g = GraphBuilder() + inner = g.node("TestExpectedOutputs", height=height, width=width) + return {"result": (inner.out(1),), "expand": g.finalize()} + + +class TestExpectedOutputsNotOptedIn: + """Reads expected_outputs WITHOUT declaring LAZY_OUTPUTS; the executor must pass + None (such nodes have no cache-key protection against output rewiring). Outputs + are white when the node correctly sees None, otherwise they encode membership. + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "height": ("INT", {"default": 64, "min": 1, "max": 1024}), + "width": ("INT", {"default": 64, "min": 1, "max": 1024}), + }, + } + + RETURN_TYPES = ("IMAGE", "IMAGE") + RETURN_NAMES = ("output0", "output1") + FUNCTION = "execute" + CATEGORY = "_for_testing" + + def execute(self, height, width): + # Raw context access (not is_output_needed): must distinguish None from a set + ctx = get_executing_context() + expected = ctx.expected_outputs if ctx is not None else None + + white = torch.ones(1, height, width, 3) + black = torch.zeros(1, height, width, 3) + + if expected is None: + return (white, white.clone()) + return ( + white if 0 in expected else black, + white if 1 in expected else black, ) @@ -551,6 +600,8 @@ TEST_NODE_CLASS_MAPPINGS = { "TestParallelSleep": TestParallelSleep, "TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput, "TestExpectedOutputs": TestExpectedOutputs, + "TestExpectedOutputsExpansion": TestExpectedOutputsExpansion, + "TestExpectedOutputsNotOptedIn": TestExpectedOutputsNotOptedIn, } TEST_NODE_DISPLAY_NAME_MAPPINGS = { @@ -570,4 +621,6 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestParallelSleep": "Test Parallel Sleep", "TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output", "TestExpectedOutputs": "Test Expected Outputs", + "TestExpectedOutputsExpansion": "Test Expected Outputs Expansion", + "TestExpectedOutputsNotOptedIn": "Test Expected Outputs Not Opted In", }