mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
fix: track expansion output consumers and gate expected_outputs behind LAZY_OUTPUTS
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Signed-off-by: bigcat88 <bigcat88@icloud.com>
This commit is contained in:
parent
bcac499524
commit
9211284444
@ -1655,7 +1655,11 @@ class Schema:
|
||||
"""When True, cache will invalidate when output connections change, and expected_outputs will be available.
|
||||
|
||||
Use this for nodes that can skip computing outputs that aren't connected downstream.
|
||||
Access via `get_executing_context().expected_outputs` - outputs NOT in the set are definitely unused."""
|
||||
Check `comfy_execution.utils.is_output_needed(i)` inside execute() - False means output i is definitely unused
|
||||
and safe to skip. Only nodes with this flag receive expected_outputs; all others see None.
|
||||
|
||||
Limitation: consumers must exist before this node runs - a subgraph expansion that
|
||||
hand-builds a link to a pre-existing node's already-skipped output reads a stale value."""
|
||||
|
||||
def validate(self):
|
||||
'''Validate the schema:
|
||||
|
||||
@ -22,7 +22,10 @@ class NodeNotFoundError(Exception):
|
||||
def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
|
||||
"""Get the set of output indices that are connected downstream.
|
||||
Returns outputs that MIGHT be used.
|
||||
Outputs NOT in this set are DEFINITELY not used and safe to skip.
|
||||
Outputs NOT in this set are DEFINITELY not used and safe to skip
|
||||
(see Schema.lazy_outputs for the one expansion-related limitation).
|
||||
|
||||
Includes input links and consumers registered via add_output_consumer.
|
||||
"""
|
||||
return dynprompt.get_expected_outputs_map().get(node_id, frozenset())
|
||||
|
||||
@ -35,6 +38,8 @@ class DynamicPrompt:
|
||||
self.ephemeral_prompt = {}
|
||||
self.ephemeral_parents = {}
|
||||
self.ephemeral_display = {}
|
||||
# Output sockets consumed outside of input links (subgraph expansions)
|
||||
self._external_output_consumers = {}
|
||||
self._expected_outputs_map = None
|
||||
|
||||
def get_node(self, node_id):
|
||||
@ -69,19 +74,22 @@ class DynamicPrompt:
|
||||
def all_node_ids(self):
|
||||
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
|
||||
|
||||
def add_output_consumer(self, node_id, socket):
|
||||
"""Record an output socket consumed outside of input links, e.g. a subgraph
|
||||
expansion mapping its parent's output to this node's output."""
|
||||
self._external_output_consumers.setdefault(node_id, set()).add(socket)
|
||||
self._expected_outputs_map = None
|
||||
|
||||
def _build_expected_outputs_map(self):
|
||||
result = {}
|
||||
for node_id in self.all_node_ids():
|
||||
try:
|
||||
node_data = self.get_node(node_id)
|
||||
except NodeNotFoundError:
|
||||
continue
|
||||
node_data = self.get_node(node_id)
|
||||
for value in node_data.get("inputs", {}).values():
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
if from_node_id not in result:
|
||||
result[from_node_id] = set()
|
||||
result[from_node_id].add(from_socket)
|
||||
result.setdefault(from_node_id, set()).add(from_socket)
|
||||
for node_id, sockets in self._external_output_consumers.items():
|
||||
result.setdefault(node_id, set()).update(sockets)
|
||||
self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()}
|
||||
|
||||
def get_expected_outputs_map(self):
|
||||
|
||||
@ -29,6 +29,10 @@ def is_output_needed(output_index: int) -> bool:
|
||||
|
||||
Returns True if the output might be used (should be computed).
|
||||
Returns False if the output is definitely not connected (safe to skip).
|
||||
|
||||
Only meaningful for LAZY_OUTPUTS nodes; for all others expected_outputs is
|
||||
None and this always returns True (skipping without the flag would not be
|
||||
reflected in the cache key).
|
||||
"""
|
||||
ctx = get_executing_context()
|
||||
if ctx is None or ctx.expected_outputs is None:
|
||||
|
||||
@ -561,7 +561,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||
|
||||
expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
|
||||
if getattr(class_def, "LAZY_OUTPUTS", False):
|
||||
expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
|
||||
else:
|
||||
expected_outputs = None
|
||||
try:
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
|
||||
finally:
|
||||
@ -620,6 +623,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
if is_link(node_outputs[i]):
|
||||
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
|
||||
new_output_links.append((from_node_id, from_socket))
|
||||
dynprompt.add_output_consumer(from_node_id, from_socket)
|
||||
cached_outputs.append((True, node_outputs))
|
||||
new_node_ids = set(new_node_ids)
|
||||
for cache in caches.all:
|
||||
|
||||
@ -148,6 +148,45 @@ class TestGetExpectedOutputsForNode:
|
||||
assert expected == frozenset({0, 1})
|
||||
|
||||
|
||||
class TestExternalOutputConsumers:
|
||||
"""Tests for DynamicPrompt.add_output_consumer() — out-of-band consumers
|
||||
(subgraph expansion output mappings) that have no input link in the prompt."""
|
||||
|
||||
def test_external_consumer_only(self):
|
||||
"""A socket consumed only externally must appear in expected outputs."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset()
|
||||
|
||||
dynprompt.add_output_consumer("1", 1)
|
||||
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({1})
|
||||
|
||||
def test_external_consumer_merges_with_links(self):
|
||||
"""External consumers merge with input-link consumers."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
dynprompt.add_output_consumer("1", 2)
|
||||
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0, 2})
|
||||
|
||||
def test_external_consumer_invalidates_cached_map(self):
|
||||
"""Registering after the map was built must invalidate the cache."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
# Build (and cache) the map first
|
||||
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0})
|
||||
|
||||
dynprompt.add_output_consumer("1", 1)
|
||||
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0, 1})
|
||||
|
||||
|
||||
class TestExecutionContext:
|
||||
"""Tests for ExecutionContext with expected_outputs field."""
|
||||
|
||||
|
||||
@ -672,6 +672,46 @@ class TestExecution:
|
||||
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white"
|
||||
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white"
|
||||
|
||||
def test_expected_outputs_expansion_output_mapping(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""A socket consumed only via an expansion's parent-output mapping must still
|
||||
be in the inner LAZY_OUTPUTS node's expected_outputs (white, not black)."""
|
||||
g = builder
|
||||
expander = g.node("TestExpectedOutputsExpansion", height=80, width=80)
|
||||
output = g.node("PreviewImage", images=expander.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 1, "Should have 1 image"
|
||||
assert numpy.array(images[0]).min() == 255, (
|
||||
"Inner node skipped an output that is consumed via the expansion's "
|
||||
"parent-output mapping (expected white, got black)"
|
||||
)
|
||||
|
||||
def test_expected_outputs_requires_opt_in(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
"""Nodes without LAZY_OUTPUTS must see expected_outputs=None: their cache key
|
||||
ignores topology, so a skipped output would be served stale after rewiring."""
|
||||
g = builder
|
||||
node = g.node("TestExpectedOutputsNotOptedIn", height=96, width=96)
|
||||
output0 = g.node("PreviewImage", images=node.out(0))
|
||||
|
||||
# Only output 0 connected: correct gating -> node sees None, computes all
|
||||
result1 = client.run(g)
|
||||
assert numpy.array(result1.get_images(output0)[0]).min() == 255
|
||||
|
||||
# Connect output 1: key unchanged -> cache hit must still serve correct data
|
||||
output1 = g.node("PreviewImage", images=node.out(1))
|
||||
result2 = client.run(g)
|
||||
|
||||
if server["should_cache_results"]:
|
||||
assert not result2.did_run(node), "Node should be a cache hit (key ignores topology)"
|
||||
images1 = result2.get_images(output1)
|
||||
assert len(images1) == 1, "Should have 1 image for output1"
|
||||
assert numpy.array(images1[0]).min() == 255, (
|
||||
"Non-opted-in node observed expected_outputs and skipped output 1; "
|
||||
"the stale skipped value was then served from cache"
|
||||
)
|
||||
|
||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
run_warmup(client)
|
||||
|
||||
@ -6,7 +6,7 @@ from .tools import VariantSupport
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||
from comfy.comfy_types import IO
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_execution.utils import get_executing_context, is_output_needed
|
||||
|
||||
class TestLazyMixImages:
|
||||
@classmethod
|
||||
@ -510,27 +510,76 @@ class TestExpectedOutputs:
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def execute(self, height, width):
|
||||
ctx = get_executing_context()
|
||||
|
||||
# Default: assume all outputs are expected (backwards compatibility)
|
||||
output0_expected = True
|
||||
output1_expected = True
|
||||
output2_expected = True
|
||||
|
||||
if ctx is not None and ctx.expected_outputs is not None:
|
||||
output0_expected = 0 in ctx.expected_outputs
|
||||
output1_expected = 1 in ctx.expected_outputs
|
||||
output2_expected = 2 in ctx.expected_outputs
|
||||
|
||||
# Return white image if expected, black if not
|
||||
# This allows tests to verify which outputs were expected via pixel values
|
||||
white = torch.ones(1, height, width, 3)
|
||||
black = torch.zeros(1, height, width, 3)
|
||||
|
||||
return (
|
||||
white if output0_expected else black,
|
||||
white if output1_expected else black,
|
||||
white if output2_expected else black,
|
||||
white if is_output_needed(0) else black,
|
||||
white if is_output_needed(1) else black,
|
||||
white if is_output_needed(2) else black,
|
||||
)
|
||||
|
||||
|
||||
class TestExpectedOutputsExpansion:
|
||||
"""Expands into an inner LAZY_OUTPUTS node whose output 1 is consumed ONLY via
|
||||
the parent-output mapping (no input link anywhere). If that mapping is not part
|
||||
of the expected-outputs map, the inner node wrongly skips it -> black not white.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"height": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||
"width": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def execute(self, height, width):
|
||||
g = GraphBuilder()
|
||||
inner = g.node("TestExpectedOutputs", height=height, width=width)
|
||||
return {"result": (inner.out(1),), "expand": g.finalize()}
|
||||
|
||||
|
||||
class TestExpectedOutputsNotOptedIn:
|
||||
"""Reads expected_outputs WITHOUT declaring LAZY_OUTPUTS; the executor must pass
|
||||
None (such nodes have no cache-key protection against output rewiring). Outputs
|
||||
are white when the node correctly sees None, otherwise they encode membership.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"height": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||
"width": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "IMAGE")
|
||||
RETURN_NAMES = ("output0", "output1")
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def execute(self, height, width):
|
||||
# Raw context access (not is_output_needed): must distinguish None from a set
|
||||
ctx = get_executing_context()
|
||||
expected = ctx.expected_outputs if ctx is not None else None
|
||||
|
||||
white = torch.ones(1, height, width, 3)
|
||||
black = torch.zeros(1, height, width, 3)
|
||||
|
||||
if expected is None:
|
||||
return (white, white.clone())
|
||||
return (
|
||||
white if 0 in expected else black,
|
||||
white if 1 in expected else black,
|
||||
)
|
||||
|
||||
|
||||
@ -551,6 +600,8 @@ TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestParallelSleep": TestParallelSleep,
|
||||
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
|
||||
"TestExpectedOutputs": TestExpectedOutputs,
|
||||
"TestExpectedOutputsExpansion": TestExpectedOutputsExpansion,
|
||||
"TestExpectedOutputsNotOptedIn": TestExpectedOutputsNotOptedIn,
|
||||
}
|
||||
|
||||
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@ -570,4 +621,6 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestParallelSleep": "Test Parallel Sleep",
|
||||
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
|
||||
"TestExpectedOutputs": "Test Expected Outputs",
|
||||
"TestExpectedOutputsExpansion": "Test Expected Outputs Expansion",
|
||||
"TestExpectedOutputsNotOptedIn": "Test Expected Outputs Not Opted In",
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user