fix: track expansion output consumers and gate expected_outputs behind LAZY_OUTPUTS
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run

Signed-off-by: bigcat88 <bigcat88@icloud.com>
This commit is contained in:
bigcat88 2026-07-02 17:38:02 +03:00
parent bcac499524
commit 9211284444
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
7 changed files with 178 additions and 26 deletions

View File

@ -1655,7 +1655,11 @@ class Schema:
"""When True, cache will invalidate when output connections change, and expected_outputs will be available.
Use this for nodes that can skip computing outputs that aren't connected downstream.
Access via `get_executing_context().expected_outputs` - outputs NOT in the set are definitely unused."""
Check `comfy_execution.utils.is_output_needed(i)` inside execute() - False means output i is definitely unused
and safe to skip. Only nodes with this flag receive expected_outputs; all others see None.
Limitation: consumers must exist before this node runs - a subgraph expansion that
hand-builds a link to a pre-existing node's already-skipped output reads a stale value."""
def validate(self):
'''Validate the schema:

View File

@ -22,7 +22,10 @@ class NodeNotFoundError(Exception):
def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
"""Get the set of output indices that are connected downstream.
Returns outputs that MIGHT be used.
Outputs NOT in this set are DEFINITELY not used and safe to skip.
Outputs NOT in this set are DEFINITELY not used and safe to skip
(see Schema.lazy_outputs for the one expansion-related limitation).
Includes input links and consumers registered via add_output_consumer.
"""
return dynprompt.get_expected_outputs_map().get(node_id, frozenset())
@ -35,6 +38,8 @@ class DynamicPrompt:
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
self.ephemeral_display = {}
# Output sockets consumed outside of input links (subgraph expansions)
self._external_output_consumers = {}
self._expected_outputs_map = None
def get_node(self, node_id):
@ -69,19 +74,22 @@ class DynamicPrompt:
def all_node_ids(self):
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
def add_output_consumer(self, node_id, socket):
"""Record an output socket consumed outside of input links, e.g. a subgraph
expansion mapping its parent's output to this node's output."""
self._external_output_consumers.setdefault(node_id, set()).add(socket)
self._expected_outputs_map = None
def _build_expected_outputs_map(self):
result = {}
for node_id in self.all_node_ids():
try:
node_data = self.get_node(node_id)
except NodeNotFoundError:
continue
node_data = self.get_node(node_id)
for value in node_data.get("inputs", {}).values():
if is_link(value):
from_node_id, from_socket = value
if from_node_id not in result:
result[from_node_id] = set()
result[from_node_id].add(from_socket)
result.setdefault(from_node_id, set()).add(from_socket)
for node_id, sockets in self._external_output_consumers.items():
result.setdefault(node_id, set()).update(sockets)
self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()}
def get_expected_outputs_map(self):

View File

@ -29,6 +29,10 @@ def is_output_needed(output_index: int) -> bool:
Returns True if the output might be used (should be computed).
Returns False if the output is definitely not connected (safe to skip).
Only meaningful for LAZY_OUTPUTS nodes; for all others expected_outputs is
None and this always returns True (skipping without the flag would not be
reflected in the cache key).
"""
ctx = get_executing_context()
if ctx is None or ctx.expected_outputs is None:

View File

@ -561,7 +561,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
if getattr(class_def, "LAZY_OUTPUTS", False):
expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
else:
expected_outputs = None
try:
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
finally:
@ -620,6 +623,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if is_link(node_outputs[i]):
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
new_output_links.append((from_node_id, from_socket))
dynprompt.add_output_consumer(from_node_id, from_socket)
cached_outputs.append((True, node_outputs))
new_node_ids = set(new_node_ids)
for cache in caches.all:

View File

@ -148,6 +148,45 @@ class TestGetExpectedOutputsForNode:
assert expected == frozenset({0, 1})
class TestExternalOutputConsumers:
"""Tests for DynamicPrompt.add_output_consumer() — out-of-band consumers
(subgraph expansion output mappings) that have no input link in the prompt."""
def test_external_consumer_only(self):
"""A socket consumed only externally must appear in expected outputs."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
}
dynprompt = DynamicPrompt(prompt)
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset()
dynprompt.add_output_consumer("1", 1)
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({1})
def test_external_consumer_merges_with_links(self):
"""External consumers merge with input-link consumers."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
}
dynprompt = DynamicPrompt(prompt)
dynprompt.add_output_consumer("1", 2)
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0, 2})
def test_external_consumer_invalidates_cached_map(self):
"""Registering after the map was built must invalidate the cache."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
}
dynprompt = DynamicPrompt(prompt)
# Build (and cache) the map first
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0})
dynprompt.add_output_consumer("1", 1)
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0, 1})
class TestExecutionContext:
"""Tests for ExecutionContext with expected_outputs field."""

View File

@ -672,6 +672,46 @@ class TestExecution:
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white"
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white"
def test_expected_outputs_expansion_output_mapping(self, client: ComfyClient, builder: GraphBuilder):
"""A socket consumed only via an expansion's parent-output mapping must still
be in the inner LAZY_OUTPUTS node's expected_outputs (white, not black)."""
g = builder
expander = g.node("TestExpectedOutputsExpansion", height=80, width=80)
output = g.node("PreviewImage", images=expander.out(0))
result = client.run(g)
images = result.get_images(output)
assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 255, (
"Inner node skipped an output that is consumed via the expansion's "
"parent-output mapping (expected white, got black)"
)
def test_expected_outputs_requires_opt_in(self, client: ComfyClient, builder: GraphBuilder, server):
"""Nodes without LAZY_OUTPUTS must see expected_outputs=None: their cache key
ignores topology, so a skipped output would be served stale after rewiring."""
g = builder
node = g.node("TestExpectedOutputsNotOptedIn", height=96, width=96)
output0 = g.node("PreviewImage", images=node.out(0))
# Only output 0 connected: correct gating -> node sees None, computes all
result1 = client.run(g)
assert numpy.array(result1.get_images(output0)[0]).min() == 255
# Connect output 1: key unchanged -> cache hit must still serve correct data
output1 = g.node("PreviewImage", images=node.out(1))
result2 = client.run(g)
if server["should_cache_results"]:
assert not result2.did_run(node), "Node should be a cache hit (key ignores topology)"
images1 = result2.get_images(output1)
assert len(images1) == 1, "Should have 1 image for output1"
assert numpy.array(images1[0]).min() == 255, (
"Non-opted-in node observed expected_outputs and skipped output 1; "
"the stale skipped value was then served from cache"
)
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
# Warmup execution to ensure server is fully initialized
run_warmup(client)

View File

@ -6,7 +6,7 @@ from .tools import VariantSupport
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO
from comfy_execution.utils import get_executing_context
from comfy_execution.utils import get_executing_context, is_output_needed
class TestLazyMixImages:
@classmethod
@ -510,27 +510,76 @@ class TestExpectedOutputs:
CATEGORY = "_for_testing"
def execute(self, height, width):
ctx = get_executing_context()
# Default: assume all outputs are expected (backwards compatibility)
output0_expected = True
output1_expected = True
output2_expected = True
if ctx is not None and ctx.expected_outputs is not None:
output0_expected = 0 in ctx.expected_outputs
output1_expected = 1 in ctx.expected_outputs
output2_expected = 2 in ctx.expected_outputs
# Return white image if expected, black if not
# This allows tests to verify which outputs were expected via pixel values
white = torch.ones(1, height, width, 3)
black = torch.zeros(1, height, width, 3)
return (
white if output0_expected else black,
white if output1_expected else black,
white if output2_expected else black,
white if is_output_needed(0) else black,
white if is_output_needed(1) else black,
white if is_output_needed(2) else black,
)
class TestExpectedOutputsExpansion:
"""Expands into an inner LAZY_OUTPUTS node whose output 1 is consumed ONLY via
the parent-output mapping (no input link anywhere). If that mapping is not part
of the expected-outputs map, the inner node wrongly skips it -> black not white.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"height": ("INT", {"default": 64, "min": 1, "max": 1024}),
"width": ("INT", {"default": 64, "min": 1, "max": 1024}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "execute"
CATEGORY = "_for_testing"
def execute(self, height, width):
g = GraphBuilder()
inner = g.node("TestExpectedOutputs", height=height, width=width)
return {"result": (inner.out(1),), "expand": g.finalize()}
class TestExpectedOutputsNotOptedIn:
"""Reads expected_outputs WITHOUT declaring LAZY_OUTPUTS; the executor must pass
None (such nodes have no cache-key protection against output rewiring). Outputs
are white when the node correctly sees None, otherwise they encode membership.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"height": ("INT", {"default": 64, "min": 1, "max": 1024}),
"width": ("INT", {"default": 64, "min": 1, "max": 1024}),
},
}
RETURN_TYPES = ("IMAGE", "IMAGE")
RETURN_NAMES = ("output0", "output1")
FUNCTION = "execute"
CATEGORY = "_for_testing"
def execute(self, height, width):
# Raw context access (not is_output_needed): must distinguish None from a set
ctx = get_executing_context()
expected = ctx.expected_outputs if ctx is not None else None
white = torch.ones(1, height, width, 3)
black = torch.zeros(1, height, width, 3)
if expected is None:
return (white, white.clone())
return (
white if 0 in expected else black,
white if 1 in expected else black,
)
@ -551,6 +600,8 @@ TEST_NODE_CLASS_MAPPINGS = {
"TestParallelSleep": TestParallelSleep,
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
"TestExpectedOutputs": TestExpectedOutputs,
"TestExpectedOutputsExpansion": TestExpectedOutputsExpansion,
"TestExpectedOutputsNotOptedIn": TestExpectedOutputsNotOptedIn,
}
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
@ -570,4 +621,6 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestParallelSleep": "Test Parallel Sleep",
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
"TestExpectedOutputs": "Test Expected Outputs",
"TestExpectedOutputsExpansion": "Test Expected Outputs Expansion",
"TestExpectedOutputsNotOptedIn": "Test Expected Outputs Not Opted In",
}