mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
fix: track expansion output consumers and gate expected_outputs behind LAZY_OUTPUTS
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Signed-off-by: bigcat88 <bigcat88@icloud.com>
This commit is contained in:
parent
bcac499524
commit
9211284444
@ -1655,7 +1655,11 @@ class Schema:
|
|||||||
"""When True, cache will invalidate when output connections change, and expected_outputs will be available.
|
"""When True, cache will invalidate when output connections change, and expected_outputs will be available.
|
||||||
|
|
||||||
Use this for nodes that can skip computing outputs that aren't connected downstream.
|
Use this for nodes that can skip computing outputs that aren't connected downstream.
|
||||||
Access via `get_executing_context().expected_outputs` - outputs NOT in the set are definitely unused."""
|
Check `comfy_execution.utils.is_output_needed(i)` inside execute() - False means output i is definitely unused
|
||||||
|
and safe to skip. Only nodes with this flag receive expected_outputs; all others see None.
|
||||||
|
|
||||||
|
Limitation: consumers must exist before this node runs - a subgraph expansion that
|
||||||
|
hand-builds a link to a pre-existing node's already-skipped output reads a stale value."""
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
'''Validate the schema:
|
'''Validate the schema:
|
||||||
|
|||||||
@ -22,7 +22,10 @@ class NodeNotFoundError(Exception):
|
|||||||
def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
|
def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
|
||||||
"""Get the set of output indices that are connected downstream.
|
"""Get the set of output indices that are connected downstream.
|
||||||
Returns outputs that MIGHT be used.
|
Returns outputs that MIGHT be used.
|
||||||
Outputs NOT in this set are DEFINITELY not used and safe to skip.
|
Outputs NOT in this set are DEFINITELY not used and safe to skip
|
||||||
|
(see Schema.lazy_outputs for the one expansion-related limitation).
|
||||||
|
|
||||||
|
Includes input links and consumers registered via add_output_consumer.
|
||||||
"""
|
"""
|
||||||
return dynprompt.get_expected_outputs_map().get(node_id, frozenset())
|
return dynprompt.get_expected_outputs_map().get(node_id, frozenset())
|
||||||
|
|
||||||
@ -35,6 +38,8 @@ class DynamicPrompt:
|
|||||||
self.ephemeral_prompt = {}
|
self.ephemeral_prompt = {}
|
||||||
self.ephemeral_parents = {}
|
self.ephemeral_parents = {}
|
||||||
self.ephemeral_display = {}
|
self.ephemeral_display = {}
|
||||||
|
# Output sockets consumed outside of input links (subgraph expansions)
|
||||||
|
self._external_output_consumers = {}
|
||||||
self._expected_outputs_map = None
|
self._expected_outputs_map = None
|
||||||
|
|
||||||
def get_node(self, node_id):
|
def get_node(self, node_id):
|
||||||
@ -69,19 +74,22 @@ class DynamicPrompt:
|
|||||||
def all_node_ids(self):
|
def all_node_ids(self):
|
||||||
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
|
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
|
||||||
|
|
||||||
|
def add_output_consumer(self, node_id, socket):
|
||||||
|
"""Record an output socket consumed outside of input links, e.g. a subgraph
|
||||||
|
expansion mapping its parent's output to this node's output."""
|
||||||
|
self._external_output_consumers.setdefault(node_id, set()).add(socket)
|
||||||
|
self._expected_outputs_map = None
|
||||||
|
|
||||||
def _build_expected_outputs_map(self):
|
def _build_expected_outputs_map(self):
|
||||||
result = {}
|
result = {}
|
||||||
for node_id in self.all_node_ids():
|
for node_id in self.all_node_ids():
|
||||||
try:
|
node_data = self.get_node(node_id)
|
||||||
node_data = self.get_node(node_id)
|
|
||||||
except NodeNotFoundError:
|
|
||||||
continue
|
|
||||||
for value in node_data.get("inputs", {}).values():
|
for value in node_data.get("inputs", {}).values():
|
||||||
if is_link(value):
|
if is_link(value):
|
||||||
from_node_id, from_socket = value
|
from_node_id, from_socket = value
|
||||||
if from_node_id not in result:
|
result.setdefault(from_node_id, set()).add(from_socket)
|
||||||
result[from_node_id] = set()
|
for node_id, sockets in self._external_output_consumers.items():
|
||||||
result[from_node_id].add(from_socket)
|
result.setdefault(node_id, set()).update(sockets)
|
||||||
self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()}
|
self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()}
|
||||||
|
|
||||||
def get_expected_outputs_map(self):
|
def get_expected_outputs_map(self):
|
||||||
|
|||||||
@ -29,6 +29,10 @@ def is_output_needed(output_index: int) -> bool:
|
|||||||
|
|
||||||
Returns True if the output might be used (should be computed).
|
Returns True if the output might be used (should be computed).
|
||||||
Returns False if the output is definitely not connected (safe to skip).
|
Returns False if the output is definitely not connected (safe to skip).
|
||||||
|
|
||||||
|
Only meaningful for LAZY_OUTPUTS nodes; for all others expected_outputs is
|
||||||
|
None and this always returns True (skipping without the flag would not be
|
||||||
|
reflected in the cache key).
|
||||||
"""
|
"""
|
||||||
ctx = get_executing_context()
|
ctx = get_executing_context()
|
||||||
if ctx is None or ctx.expected_outputs is None:
|
if ctx is None or ctx.expected_outputs is None:
|
||||||
|
|||||||
@ -561,7 +561,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
|
||||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||||
|
|
||||||
expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
|
if getattr(class_def, "LAZY_OUTPUTS", False):
|
||||||
|
expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
|
||||||
|
else:
|
||||||
|
expected_outputs = None
|
||||||
try:
|
try:
|
||||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
|
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
|
||||||
finally:
|
finally:
|
||||||
@ -620,6 +623,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
if is_link(node_outputs[i]):
|
if is_link(node_outputs[i]):
|
||||||
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
|
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
|
||||||
new_output_links.append((from_node_id, from_socket))
|
new_output_links.append((from_node_id, from_socket))
|
||||||
|
dynprompt.add_output_consumer(from_node_id, from_socket)
|
||||||
cached_outputs.append((True, node_outputs))
|
cached_outputs.append((True, node_outputs))
|
||||||
new_node_ids = set(new_node_ids)
|
new_node_ids = set(new_node_ids)
|
||||||
for cache in caches.all:
|
for cache in caches.all:
|
||||||
|
|||||||
@ -148,6 +148,45 @@ class TestGetExpectedOutputsForNode:
|
|||||||
assert expected == frozenset({0, 1})
|
assert expected == frozenset({0, 1})
|
||||||
|
|
||||||
|
|
||||||
|
class TestExternalOutputConsumers:
|
||||||
|
"""Tests for DynamicPrompt.add_output_consumer() — out-of-band consumers
|
||||||
|
(subgraph expansion output mappings) that have no input link in the prompt."""
|
||||||
|
|
||||||
|
def test_external_consumer_only(self):
|
||||||
|
"""A socket consumed only externally must appear in expected outputs."""
|
||||||
|
prompt = {
|
||||||
|
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||||
|
}
|
||||||
|
dynprompt = DynamicPrompt(prompt)
|
||||||
|
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset()
|
||||||
|
|
||||||
|
dynprompt.add_output_consumer("1", 1)
|
||||||
|
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({1})
|
||||||
|
|
||||||
|
def test_external_consumer_merges_with_links(self):
|
||||||
|
"""External consumers merge with input-link consumers."""
|
||||||
|
prompt = {
|
||||||
|
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||||
|
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
|
||||||
|
}
|
||||||
|
dynprompt = DynamicPrompt(prompt)
|
||||||
|
dynprompt.add_output_consumer("1", 2)
|
||||||
|
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0, 2})
|
||||||
|
|
||||||
|
def test_external_consumer_invalidates_cached_map(self):
|
||||||
|
"""Registering after the map was built must invalidate the cache."""
|
||||||
|
prompt = {
|
||||||
|
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||||
|
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
|
||||||
|
}
|
||||||
|
dynprompt = DynamicPrompt(prompt)
|
||||||
|
# Build (and cache) the map first
|
||||||
|
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0})
|
||||||
|
|
||||||
|
dynprompt.add_output_consumer("1", 1)
|
||||||
|
assert get_expected_outputs_for_node(dynprompt, "1") == frozenset({0, 1})
|
||||||
|
|
||||||
|
|
||||||
class TestExecutionContext:
|
class TestExecutionContext:
|
||||||
"""Tests for ExecutionContext with expected_outputs field."""
|
"""Tests for ExecutionContext with expected_outputs field."""
|
||||||
|
|
||||||
|
|||||||
@ -672,6 +672,46 @@ class TestExecution:
|
|||||||
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white"
|
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white"
|
||||||
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white"
|
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white"
|
||||||
|
|
||||||
|
def test_expected_outputs_expansion_output_mapping(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
"""A socket consumed only via an expansion's parent-output mapping must still
|
||||||
|
be in the inner LAZY_OUTPUTS node's expected_outputs (white, not black)."""
|
||||||
|
g = builder
|
||||||
|
expander = g.node("TestExpectedOutputsExpansion", height=80, width=80)
|
||||||
|
output = g.node("PreviewImage", images=expander.out(0))
|
||||||
|
|
||||||
|
result = client.run(g)
|
||||||
|
|
||||||
|
images = result.get_images(output)
|
||||||
|
assert len(images) == 1, "Should have 1 image"
|
||||||
|
assert numpy.array(images[0]).min() == 255, (
|
||||||
|
"Inner node skipped an output that is consumed via the expansion's "
|
||||||
|
"parent-output mapping (expected white, got black)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_expected_outputs_requires_opt_in(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||||
|
"""Nodes without LAZY_OUTPUTS must see expected_outputs=None: their cache key
|
||||||
|
ignores topology, so a skipped output would be served stale after rewiring."""
|
||||||
|
g = builder
|
||||||
|
node = g.node("TestExpectedOutputsNotOptedIn", height=96, width=96)
|
||||||
|
output0 = g.node("PreviewImage", images=node.out(0))
|
||||||
|
|
||||||
|
# Only output 0 connected: correct gating -> node sees None, computes all
|
||||||
|
result1 = client.run(g)
|
||||||
|
assert numpy.array(result1.get_images(output0)[0]).min() == 255
|
||||||
|
|
||||||
|
# Connect output 1: key unchanged -> cache hit must still serve correct data
|
||||||
|
output1 = g.node("PreviewImage", images=node.out(1))
|
||||||
|
result2 = client.run(g)
|
||||||
|
|
||||||
|
if server["should_cache_results"]:
|
||||||
|
assert not result2.did_run(node), "Node should be a cache hit (key ignores topology)"
|
||||||
|
images1 = result2.get_images(output1)
|
||||||
|
assert len(images1) == 1, "Should have 1 image for output1"
|
||||||
|
assert numpy.array(images1[0]).min() == 255, (
|
||||||
|
"Non-opted-in node observed expected_outputs and skipped output 1; "
|
||||||
|
"the stale skipped value was then served from cache"
|
||||||
|
)
|
||||||
|
|
||||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
||||||
# Warmup execution to ensure server is fully initialized
|
# Warmup execution to ensure server is fully initialized
|
||||||
run_warmup(client)
|
run_warmup(client)
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from .tools import VariantSupport
|
|||||||
from comfy_execution.graph_utils import GraphBuilder
|
from comfy_execution.graph_utils import GraphBuilder
|
||||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||||
from comfy.comfy_types import IO
|
from comfy.comfy_types import IO
|
||||||
from comfy_execution.utils import get_executing_context
|
from comfy_execution.utils import get_executing_context, is_output_needed
|
||||||
|
|
||||||
class TestLazyMixImages:
|
class TestLazyMixImages:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -510,27 +510,76 @@ class TestExpectedOutputs:
|
|||||||
CATEGORY = "_for_testing"
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
def execute(self, height, width):
|
def execute(self, height, width):
|
||||||
ctx = get_executing_context()
|
|
||||||
|
|
||||||
# Default: assume all outputs are expected (backwards compatibility)
|
|
||||||
output0_expected = True
|
|
||||||
output1_expected = True
|
|
||||||
output2_expected = True
|
|
||||||
|
|
||||||
if ctx is not None and ctx.expected_outputs is not None:
|
|
||||||
output0_expected = 0 in ctx.expected_outputs
|
|
||||||
output1_expected = 1 in ctx.expected_outputs
|
|
||||||
output2_expected = 2 in ctx.expected_outputs
|
|
||||||
|
|
||||||
# Return white image if expected, black if not
|
# Return white image if expected, black if not
|
||||||
# This allows tests to verify which outputs were expected via pixel values
|
# This allows tests to verify which outputs were expected via pixel values
|
||||||
white = torch.ones(1, height, width, 3)
|
white = torch.ones(1, height, width, 3)
|
||||||
black = torch.zeros(1, height, width, 3)
|
black = torch.zeros(1, height, width, 3)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
white if output0_expected else black,
|
white if is_output_needed(0) else black,
|
||||||
white if output1_expected else black,
|
white if is_output_needed(1) else black,
|
||||||
white if output2_expected else black,
|
white if is_output_needed(2) else black,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExpectedOutputsExpansion:
|
||||||
|
"""Expands into an inner LAZY_OUTPUTS node whose output 1 is consumed ONLY via
|
||||||
|
the parent-output mapping (no input link anywhere). If that mapping is not part
|
||||||
|
of the expected-outputs map, the inner node wrongly skips it -> black not white.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"height": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||||
|
"width": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def execute(self, height, width):
|
||||||
|
g = GraphBuilder()
|
||||||
|
inner = g.node("TestExpectedOutputs", height=height, width=width)
|
||||||
|
return {"result": (inner.out(1),), "expand": g.finalize()}
|
||||||
|
|
||||||
|
|
||||||
|
class TestExpectedOutputsNotOptedIn:
|
||||||
|
"""Reads expected_outputs WITHOUT declaring LAZY_OUTPUTS; the executor must pass
|
||||||
|
None (such nodes have no cache-key protection against output rewiring). Outputs
|
||||||
|
are white when the node correctly sees None, otherwise they encode membership.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"height": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||||
|
"width": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE", "IMAGE")
|
||||||
|
RETURN_NAMES = ("output0", "output1")
|
||||||
|
FUNCTION = "execute"
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def execute(self, height, width):
|
||||||
|
# Raw context access (not is_output_needed): must distinguish None from a set
|
||||||
|
ctx = get_executing_context()
|
||||||
|
expected = ctx.expected_outputs if ctx is not None else None
|
||||||
|
|
||||||
|
white = torch.ones(1, height, width, 3)
|
||||||
|
black = torch.zeros(1, height, width, 3)
|
||||||
|
|
||||||
|
if expected is None:
|
||||||
|
return (white, white.clone())
|
||||||
|
return (
|
||||||
|
white if 0 in expected else black,
|
||||||
|
white if 1 in expected else black,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -551,6 +600,8 @@ TEST_NODE_CLASS_MAPPINGS = {
|
|||||||
"TestParallelSleep": TestParallelSleep,
|
"TestParallelSleep": TestParallelSleep,
|
||||||
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
|
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
|
||||||
"TestExpectedOutputs": TestExpectedOutputs,
|
"TestExpectedOutputs": TestExpectedOutputs,
|
||||||
|
"TestExpectedOutputsExpansion": TestExpectedOutputsExpansion,
|
||||||
|
"TestExpectedOutputsNotOptedIn": TestExpectedOutputsNotOptedIn,
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@ -570,4 +621,6 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"TestParallelSleep": "Test Parallel Sleep",
|
"TestParallelSleep": "Test Parallel Sleep",
|
||||||
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
|
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
|
||||||
"TestExpectedOutputs": "Test Expected Outputs",
|
"TestExpectedOutputs": "Test Expected Outputs",
|
||||||
|
"TestExpectedOutputsExpansion": "Test Expected Outputs Expansion",
|
||||||
|
"TestExpectedOutputsNotOptedIn": "Test Expected Outputs Not Opted In",
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user