mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-07 12:02:37 +08:00
feat: add expected_outputs feature for lazy output computation
This commit is contained in:
parent
2b70ab9ad0
commit
d987b0d32d
@ -1430,6 +1430,11 @@ class Schema:
|
||||
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
|
||||
accept_all_inputs: bool=False
|
||||
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
|
||||
lazy_outputs: bool=False
|
||||
"""When True, cache will invalidate when output connections change, and expected_outputs will be available.
|
||||
|
||||
Use this for nodes that can skip computing outputs that aren't connected downstream.
|
||||
Access via `get_executing_context().expected_outputs` - outputs NOT in the set are definitely unused."""
|
||||
|
||||
def validate(self):
|
||||
'''Validate the schema:
|
||||
@ -1875,6 +1880,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls.GET_SCHEMA()
|
||||
return cls._ACCEPT_ALL_INPUTS
|
||||
|
||||
_LAZY_OUTPUTS = None
|
||||
@final
|
||||
@classproperty
|
||||
def LAZY_OUTPUTS(cls): # noqa
|
||||
if cls._LAZY_OUTPUTS is None:
|
||||
cls.GET_SCHEMA()
|
||||
return cls._LAZY_OUTPUTS
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> dict[str, dict]:
|
||||
@ -1917,6 +1930,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
cls._NOT_IDEMPOTENT = schema.not_idempotent
|
||||
if cls._ACCEPT_ALL_INPUTS is None:
|
||||
cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs
|
||||
if cls._LAZY_OUTPUTS is None:
|
||||
cls._LAZY_OUTPUTS = schema.lazy_outputs
|
||||
|
||||
if cls._RETURN_TYPES is None:
|
||||
output = []
|
||||
|
||||
@ -5,7 +5,7 @@ import psutil
|
||||
import time
|
||||
import torch
|
||||
from typing import Sequence, Mapping, Dict
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import nodes
|
||||
@ -115,6 +115,10 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
signature = [class_type, await self.is_changed_cache.get(node_id)]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
# Include expected_outputs in cache key for nodes that opt in via LAZY_OUTPUTS
|
||||
if hasattr(class_def, 'LAZY_OUTPUTS') and class_def.LAZY_OUTPUTS:
|
||||
expected = get_expected_outputs_for_node(dynprompt, node_id)
|
||||
signature.append(("expected_outputs", tuple(sorted(expected))))
|
||||
inputs = node["inputs"]
|
||||
for key in sorted(inputs.keys()):
|
||||
if is_link(inputs[key]):
|
||||
|
||||
@ -19,6 +19,27 @@ class NodeInputError(Exception):
|
||||
class NodeNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
|
||||
"""Get the set of output indices that are connected downstream.
|
||||
Returns outputs that MIGHT be used.
|
||||
Outputs NOT in this set are DEFINITELY not used and safe to skip.
|
||||
"""
|
||||
expected = set()
|
||||
for other_node_id in dynprompt.all_node_ids():
|
||||
try:
|
||||
node_data = dynprompt.get_node(other_node_id)
|
||||
except NodeNotFoundError:
|
||||
continue
|
||||
inputs = node_data.get("inputs", {})
|
||||
for input_name, value in inputs.items():
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
if from_node_id == node_id:
|
||||
expected.add(from_socket)
|
||||
return frozenset(expected)
|
||||
|
||||
|
||||
class DynamicPrompt:
|
||||
def __init__(self, original_prompt):
|
||||
# The original prompt provided by the user
|
||||
|
||||
@ -1,21 +1,26 @@
|
||||
import contextvars
|
||||
from typing import Optional, NamedTuple
|
||||
from typing import NamedTuple, FrozenSet
|
||||
|
||||
class ExecutionContext(NamedTuple):
|
||||
"""
|
||||
Context information about the currently executing node.
|
||||
|
||||
Attributes:
|
||||
prompt_id: The ID of the current prompt execution
|
||||
node_id: The ID of the currently executing node
|
||||
list_index: The index in a list being processed (for operations on batches/lists)
|
||||
expected_outputs: Set of output indices that might be used downstream.
|
||||
Outputs NOT in this set are definitely unused (safe to skip).
|
||||
None means the information is not available.
|
||||
"""
|
||||
prompt_id: str
|
||||
node_id: str
|
||||
list_index: Optional[int]
|
||||
list_index: int | None
|
||||
expected_outputs: FrozenSet[int] | None = None
|
||||
|
||||
current_executing_context: contextvars.ContextVar[Optional[ExecutionContext]] = contextvars.ContextVar("current_executing_context", default=None)
|
||||
current_executing_context: contextvars.ContextVar[ExecutionContext | None] = contextvars.ContextVar("current_executing_context", default=None)
|
||||
|
||||
def get_executing_context() -> Optional[ExecutionContext]:
|
||||
def get_executing_context() -> ExecutionContext | None:
|
||||
return current_executing_context.get(None)
|
||||
|
||||
class CurrentNodeContext:
|
||||
@ -25,15 +30,22 @@ class CurrentNodeContext:
|
||||
Sets the current_executing_context on enter and resets it on exit.
|
||||
|
||||
Example:
|
||||
with CurrentNodeContext(node_id="123", list_index=0):
|
||||
with CurrentNodeContext(prompt_id="abc", node_id="123", list_index=0):
|
||||
# Code that should run with the current node context set
|
||||
process_image()
|
||||
"""
|
||||
def __init__(self, prompt_id: str, node_id: str, list_index: Optional[int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
prompt_id: str,
|
||||
node_id: str,
|
||||
list_index: int | None = None,
|
||||
expected_outputs: FrozenSet[int] | None = None,
|
||||
):
|
||||
self.context = ExecutionContext(
|
||||
prompt_id= prompt_id,
|
||||
node_id= node_id,
|
||||
list_index= list_index
|
||||
prompt_id=prompt_id,
|
||||
node_id=node_id,
|
||||
list_index=list_index,
|
||||
expected_outputs=expected_outputs,
|
||||
)
|
||||
self.token = None
|
||||
|
||||
|
||||
40
execution.py
40
execution.py
@ -31,6 +31,7 @@ from comfy_execution.graph import (
|
||||
ExecutionBlocker,
|
||||
ExecutionList,
|
||||
get_input_info,
|
||||
get_expected_outputs_for_node,
|
||||
)
|
||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
@ -227,7 +228,18 @@ async def resolve_map_node_over_list_results(results):
|
||||
raise exc
|
||||
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
|
||||
|
||||
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
|
||||
async def _async_map_node_over_list(
|
||||
prompt_id,
|
||||
unique_id,
|
||||
obj,
|
||||
input_data_all,
|
||||
func,
|
||||
allow_interrupt=False,
|
||||
execution_block_cb=None,
|
||||
pre_execute_cb=None,
|
||||
v3_data=None,
|
||||
expected_outputs=None,
|
||||
):
|
||||
# check if node wants the lists
|
||||
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
|
||||
|
||||
@ -277,10 +289,12 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
||||
else:
|
||||
f = getattr(obj, func)
|
||||
if inspect.iscoroutinefunction(f):
|
||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args):
|
||||
with CurrentNodeContext(prompt_id, unique_id, list_index):
|
||||
async def async_wrapper(f, prompt_id, unique_id, list_index, args, expected_outputs):
|
||||
with CurrentNodeContext(prompt_id, unique_id, list_index, expected_outputs):
|
||||
return await f(**args)
|
||||
task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs))
|
||||
task = asyncio.create_task(
|
||||
async_wrapper(f, prompt_id, unique_id, index, args=inputs, expected_outputs=expected_outputs)
|
||||
)
|
||||
# Give the task a chance to execute without yielding
|
||||
await asyncio.sleep(0)
|
||||
if task.done():
|
||||
@ -289,7 +303,7 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
||||
else:
|
||||
results.append(task)
|
||||
else:
|
||||
with CurrentNodeContext(prompt_id, unique_id, index):
|
||||
with CurrentNodeContext(prompt_id, unique_id, index, expected_outputs):
|
||||
result = f(**inputs)
|
||||
results.append(result)
|
||||
else:
|
||||
@ -327,8 +341,17 @@ def merge_result_data(results, obj):
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||
async def get_output_data(
|
||||
prompt_id,
|
||||
unique_id,
|
||||
obj,
|
||||
input_data_all,
|
||||
execution_block_cb=None,
|
||||
pre_execute_cb=None,
|
||||
v3_data=None,
|
||||
expected_outputs=None,
|
||||
):
|
||||
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
|
||||
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
|
||||
if has_pending_task:
|
||||
return return_values, {}, False, has_pending_task
|
||||
@ -522,9 +545,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
#will cause all sorts of incompatible memory shapes to fragment the pytorch alloc
|
||||
#that we just want to cull out each model run.
|
||||
allocator = comfy.memory_management.aimdo_allocator
|
||||
expected_outputs = get_expected_outputs_for_node(dynprompt, unique_id)
|
||||
with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
|
||||
try:
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, expected_outputs=expected_outputs)
|
||||
finally:
|
||||
if allocator is not None:
|
||||
comfy.model_management.reset_cast_buffers()
|
||||
|
||||
269
tests-unit/execution_test/expected_outputs_test.py
Normal file
269
tests-unit/execution_test/expected_outputs_test.py
Normal file
@ -0,0 +1,269 @@
|
||||
"""Unit tests for the expected_outputs feature.
|
||||
|
||||
This feature allows nodes to know at runtime which outputs are connected downstream,
|
||||
enabling them to skip computing outputs that aren't needed.
|
||||
"""
|
||||
|
||||
from comfy_api.latest import IO
|
||||
from comfy_execution.graph import DynamicPrompt, get_expected_outputs_for_node
|
||||
from comfy_execution.utils import (
|
||||
CurrentNodeContext,
|
||||
ExecutionContext,
|
||||
get_executing_context,
|
||||
)
|
||||
|
||||
|
||||
class TestGetExpectedOutputsForNode:
|
||||
"""Tests for get_expected_outputs_for_node() function."""
|
||||
|
||||
def test_single_output_connected(self):
|
||||
"""Test node with single output connected to one downstream node."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected == frozenset({0})
|
||||
|
||||
def test_multiple_outputs_partial_connected(self):
|
||||
"""Test node with multiple outputs, only some connected."""
|
||||
prompt = {
|
||||
"1": {"class_type": "MultiOutputNode", "inputs": {}},
|
||||
"2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
|
||||
# Output 1 is not connected
|
||||
"3": {"class_type": "ConsumerC", "inputs": {"input": ["1", 2]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected == frozenset({0, 2})
|
||||
assert 1 not in expected # Output 1 is definitely unused
|
||||
|
||||
def test_no_outputs_connected(self):
|
||||
"""Test node with no outputs connected."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {"class_type": "OtherNode", "inputs": {}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected == frozenset()
|
||||
|
||||
def test_same_output_connected_multiple_times(self):
|
||||
"""Test same output connected to multiple downstream nodes."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {"class_type": "ConsumerA", "inputs": {"input": ["1", 0]}},
|
||||
"3": {"class_type": "ConsumerB", "inputs": {"input": ["1", 0]}},
|
||||
"4": {"class_type": "ConsumerC", "inputs": {"input": ["1", 0]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected == frozenset({0})
|
||||
|
||||
def test_node_not_in_prompt(self):
|
||||
"""Test getting expected outputs for a node not in the prompt."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
expected = get_expected_outputs_for_node(dynprompt, "999")
|
||||
assert expected == frozenset()
|
||||
|
||||
def test_chained_nodes(self):
|
||||
"""Test expected outputs in a chain of nodes."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {"class_type": "MiddleNode", "inputs": {"input": ["1", 0]}},
|
||||
"3": {"class_type": "EndNode", "inputs": {"input": ["2", 0]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
|
||||
# Node 1's output 0 is connected to node 2
|
||||
expected_1 = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected_1 == frozenset({0})
|
||||
|
||||
# Node 2's output 0 is connected to node 3
|
||||
expected_2 = get_expected_outputs_for_node(dynprompt, "2")
|
||||
assert expected_2 == frozenset({0})
|
||||
|
||||
# Node 3 has no downstream connections
|
||||
expected_3 = get_expected_outputs_for_node(dynprompt, "3")
|
||||
assert expected_3 == frozenset()
|
||||
|
||||
def test_complex_graph(self):
|
||||
"""Test expected outputs in a complex graph with multiple connections."""
|
||||
prompt = {
|
||||
"1": {"class_type": "MultiOutputNode", "inputs": {}},
|
||||
"2": {"class_type": "ProcessorA", "inputs": {"image": ["1", 0], "mask": ["1", 1]}},
|
||||
"3": {"class_type": "ProcessorB", "inputs": {"data": ["1", 2]}},
|
||||
"4": {"class_type": "Combiner", "inputs": {"a": ["2", 0], "b": ["3", 0]}},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
|
||||
# Node 1 has outputs 0, 1, 2 all connected
|
||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected == frozenset({0, 1, 2})
|
||||
|
||||
def test_constant_inputs_ignored(self):
|
||||
"""Test that constant (non-link) inputs don't affect expected outputs."""
|
||||
prompt = {
|
||||
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||
"2": {
|
||||
"class_type": "ConsumerNode",
|
||||
"inputs": {
|
||||
"image": ["1", 0],
|
||||
"value": 42,
|
||||
"name": "test",
|
||||
},
|
||||
},
|
||||
}
|
||||
dynprompt = DynamicPrompt(prompt)
|
||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||
assert expected == frozenset({0})
|
||||
|
||||
|
||||
class TestExecutionContext:
|
||||
"""Tests for ExecutionContext with expected_outputs field."""
|
||||
|
||||
def test_context_with_expected_outputs(self):
|
||||
"""Test creating ExecutionContext with expected_outputs."""
|
||||
ctx = ExecutionContext(
|
||||
prompt_id="prompt-123", node_id="node-456", list_index=0, expected_outputs=frozenset({0, 2})
|
||||
)
|
||||
assert ctx.prompt_id == "prompt-123"
|
||||
assert ctx.node_id == "node-456"
|
||||
assert ctx.list_index == 0
|
||||
assert ctx.expected_outputs == frozenset({0, 2})
|
||||
|
||||
def test_context_without_expected_outputs(self):
|
||||
"""Test ExecutionContext defaults to None for expected_outputs."""
|
||||
ctx = ExecutionContext(prompt_id="prompt-123", node_id="node-456", list_index=0)
|
||||
assert ctx.expected_outputs is None
|
||||
|
||||
def test_context_empty_expected_outputs(self):
|
||||
"""Test ExecutionContext with empty expected_outputs set."""
|
||||
ctx = ExecutionContext(
|
||||
prompt_id="prompt-123", node_id="node-456", list_index=None, expected_outputs=frozenset()
|
||||
)
|
||||
assert ctx.expected_outputs == frozenset()
|
||||
assert len(ctx.expected_outputs) == 0
|
||||
|
||||
|
||||
class TestCurrentNodeContext:
|
||||
"""Tests for CurrentNodeContext context manager with expected_outputs."""
|
||||
|
||||
def test_context_manager_with_expected_outputs(self):
|
||||
"""Test CurrentNodeContext sets and resets context correctly."""
|
||||
assert get_executing_context() is None
|
||||
|
||||
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 1})):
|
||||
ctx = get_executing_context()
|
||||
assert ctx is not None
|
||||
assert ctx.prompt_id == "prompt-1"
|
||||
assert ctx.node_id == "node-1"
|
||||
assert ctx.list_index == 0
|
||||
assert ctx.expected_outputs == frozenset({0, 1})
|
||||
|
||||
assert get_executing_context() is None
|
||||
|
||||
def test_context_manager_without_expected_outputs(self):
|
||||
"""Test CurrentNodeContext works without expected_outputs (backwards compatible)."""
|
||||
with CurrentNodeContext("prompt-1", "node-1"):
|
||||
ctx = get_executing_context()
|
||||
assert ctx is not None
|
||||
assert ctx.expected_outputs is None
|
||||
|
||||
def test_nested_context_managers(self):
|
||||
"""Test nested CurrentNodeContext managers."""
|
||||
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0})):
|
||||
ctx1 = get_executing_context()
|
||||
assert ctx1.expected_outputs == frozenset({0})
|
||||
|
||||
with CurrentNodeContext("prompt-1", "node-2", 0, frozenset({1, 2})):
|
||||
ctx2 = get_executing_context()
|
||||
assert ctx2.expected_outputs == frozenset({1, 2})
|
||||
assert ctx2.node_id == "node-2"
|
||||
|
||||
# After inner context exits, should be back to outer context
|
||||
ctx1_again = get_executing_context()
|
||||
assert ctx1_again.expected_outputs == frozenset({0})
|
||||
assert ctx1_again.node_id == "node-1"
|
||||
|
||||
def test_output_check_pattern(self):
|
||||
"""Test the typical pattern nodes will use to check expected outputs."""
|
||||
with CurrentNodeContext("prompt-1", "node-1", 0, frozenset({0, 2})):
|
||||
ctx = get_executing_context()
|
||||
|
||||
# Typical usage pattern
|
||||
if ctx and ctx.expected_outputs is not None:
|
||||
should_compute_0 = 0 in ctx.expected_outputs
|
||||
should_compute_1 = 1 in ctx.expected_outputs
|
||||
should_compute_2 = 2 in ctx.expected_outputs
|
||||
else:
|
||||
# Fallback when info not available
|
||||
should_compute_0 = should_compute_1 = should_compute_2 = True
|
||||
|
||||
assert should_compute_0 is True
|
||||
assert should_compute_1 is False # Not in expected_outputs
|
||||
assert should_compute_2 is True
|
||||
|
||||
|
||||
class TestSchemaLazyOutputs:
|
||||
"""Tests for lazy_outputs in V3 Schema."""
|
||||
|
||||
def test_schema_lazy_outputs_default(self):
|
||||
"""Test that lazy_outputs defaults to False."""
|
||||
schema = IO.Schema(
|
||||
node_id="TestNode",
|
||||
inputs=[],
|
||||
outputs=[IO.Float.Output()],
|
||||
)
|
||||
assert schema.lazy_outputs is False
|
||||
|
||||
def test_schema_lazy_outputs_true(self):
|
||||
"""Test setting lazy_outputs to True."""
|
||||
schema = IO.Schema(
|
||||
node_id="TestNode",
|
||||
lazy_outputs=True,
|
||||
inputs=[],
|
||||
outputs=[IO.Float.Output()],
|
||||
)
|
||||
assert schema.lazy_outputs is True
|
||||
|
||||
def test_v3_node_lazy_outputs_property(self):
|
||||
"""Test that LAZY_OUTPUTS property works on V3 nodes."""
|
||||
|
||||
class TestNodeWithLazyOutputs(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TestNodeWithLazyOutputs",
|
||||
lazy_outputs=True,
|
||||
inputs=[],
|
||||
outputs=[IO.Float.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls):
|
||||
return IO.NodeOutput(1.0)
|
||||
|
||||
assert TestNodeWithLazyOutputs.LAZY_OUTPUTS is True
|
||||
|
||||
def test_v3_node_lazy_outputs_default(self):
|
||||
"""Test that LAZY_OUTPUTS defaults to False on V3 nodes."""
|
||||
|
||||
class TestNodeWithoutLazyOutputs(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TestNodeWithoutLazyOutputs",
|
||||
inputs=[],
|
||||
outputs=[IO.Float.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls):
|
||||
return IO.NodeOutput(1.0)
|
||||
|
||||
assert TestNodeWithoutLazyOutputs.LAZY_OUTPUTS is False
|
||||
@ -574,6 +574,104 @@ class TestExecution:
|
||||
else:
|
||||
assert result.did_run(test_node), "The execution should have been re-run"
|
||||
|
||||
def test_expected_outputs_all_connected(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that expected_outputs contains all connected outputs."""
|
||||
g = builder
|
||||
# Create a node with 3 outputs, all connected
|
||||
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
|
||||
|
||||
# Connect all 3 outputs to preview nodes
|
||||
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
|
||||
output1 = g.node("PreviewImage", images=expected_outputs_node.out(1))
|
||||
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# All outputs should be white (255) since all are connected
|
||||
images0 = result.get_images(output0)
|
||||
images1 = result.get_images(output1)
|
||||
images2 = result.get_images(output2)
|
||||
|
||||
assert len(images0) == 1, "Should have 1 image for output0"
|
||||
assert len(images1) == 1, "Should have 1 image for output1"
|
||||
assert len(images2) == 1, "Should have 1 image for output2"
|
||||
|
||||
# White pixels = 255, meaning output was in expected_outputs
|
||||
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)"
|
||||
assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)"
|
||||
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)"
|
||||
|
||||
def test_expected_outputs_partial_connected(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that expected_outputs only contains connected outputs."""
|
||||
g = builder
|
||||
# Create a node with 3 outputs, only some connected
|
||||
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
|
||||
|
||||
# Only connect outputs 0 and 2, leave output 1 disconnected
|
||||
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
|
||||
# output1 is intentionally not connected
|
||||
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
# Connected outputs should be white (255)
|
||||
images0 = result.get_images(output0)
|
||||
images2 = result.get_images(output2)
|
||||
|
||||
assert len(images0) == 1, "Should have 1 image for output0"
|
||||
assert len(images2) == 1, "Should have 1 image for output2"
|
||||
|
||||
# White = expected, output 1 is not connected so we can't verify it directly but outputs 0 and 2 should be white
|
||||
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white (was expected)"
|
||||
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white (was expected)"
|
||||
|
||||
def test_expected_outputs_single_connected(self, client: ComfyClient, builder: GraphBuilder):
|
||||
"""Test that expected_outputs works with single connected output."""
|
||||
g = builder
|
||||
# Create a node with 3 outputs, only one connected
|
||||
expected_outputs_node = g.node("TestExpectedOutputs", height=64, width=64)
|
||||
|
||||
# Only connect output 1
|
||||
output1 = g.node("PreviewImage", images=expected_outputs_node.out(1))
|
||||
|
||||
result = client.run(g)
|
||||
|
||||
images1 = result.get_images(output1)
|
||||
assert len(images1) == 1, "Should have 1 image for output1"
|
||||
|
||||
# Output 1 should be white (connected), others are not visible in this test
|
||||
assert numpy.array(images1[0]).min() == 255, "Output 1 should be white (was expected)"
|
||||
|
||||
def test_expected_outputs_cache_invalidation(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
"""Test that cache invalidates when output connections change."""
|
||||
g = builder
|
||||
# Use unique dimensions to avoid cache collision with other expected_outputs tests
|
||||
expected_outputs_node = g.node("TestExpectedOutputs", height=32, width=32)
|
||||
|
||||
# First run: only connect output 0
|
||||
output0 = g.node("PreviewImage", images=expected_outputs_node.out(0))
|
||||
|
||||
result1 = client.run(g)
|
||||
assert result1.did_run(expected_outputs_node), "First run should execute the node"
|
||||
|
||||
# Second run: same connections, should be cached
|
||||
result2 = client.run(g)
|
||||
if server["should_cache_results"]:
|
||||
assert not result2.did_run(expected_outputs_node), "Second run should be cached"
|
||||
|
||||
# Third run: add connection to output 2
|
||||
output2 = g.node("PreviewImage", images=expected_outputs_node.out(2))
|
||||
|
||||
result3 = client.run(g)
|
||||
# Because LAZY_OUTPUTS=True, changing connections should invalidate cache
|
||||
if server["should_cache_results"]:
|
||||
assert result3.did_run(expected_outputs_node), "Adding output connection should invalidate cache"
|
||||
|
||||
# Verify both outputs are now white
|
||||
images0 = result3.get_images(output0)
|
||||
images2 = result3.get_images(output2)
|
||||
assert numpy.array(images0[0]).min() == 255, "Output 0 should be white"
|
||||
assert numpy.array(images2[0]).min() == 255, "Output 2 should be white"
|
||||
|
||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
|
||||
@ -6,6 +6,7 @@ from .tools import VariantSupport
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||
from comfy.comfy_types import IO
|
||||
from comfy_execution.utils import get_executing_context
|
||||
|
||||
class TestLazyMixImages:
|
||||
@classmethod
|
||||
@ -482,6 +483,57 @@ class TestOutputNodeWithSocketOutput:
|
||||
result = image * value
|
||||
return (result,)
|
||||
|
||||
|
||||
class TestExpectedOutputs:
|
||||
"""Test node for the expected_outputs feature.
|
||||
|
||||
This node has 3 IMAGE outputs that encode which outputs were expected:
|
||||
- White image (255) if the output was in expected_outputs
|
||||
- Black image (0) if the output was NOT in expected_outputs
|
||||
|
||||
This allows integration tests to verify which outputs were expected by checking pixel values.
|
||||
"""
|
||||
LAZY_OUTPUTS = True # Opt into cache invalidation on output connection changes
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"height": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||
"width": ("INT", {"default": 64, "min": 1, "max": 1024}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
|
||||
RETURN_NAMES = ("output0", "output1", "output2")
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def execute(self, height, width):
|
||||
ctx = get_executing_context()
|
||||
|
||||
# Default: assume all outputs are expected (backwards compatibility)
|
||||
output0_expected = True
|
||||
output1_expected = True
|
||||
output2_expected = True
|
||||
|
||||
if ctx is not None and ctx.expected_outputs is not None:
|
||||
output0_expected = 0 in ctx.expected_outputs
|
||||
output1_expected = 1 in ctx.expected_outputs
|
||||
output2_expected = 2 in ctx.expected_outputs
|
||||
|
||||
# Return white image if expected, black if not
|
||||
# This allows tests to verify which outputs were expected via pixel values
|
||||
white = torch.ones(1, height, width, 3)
|
||||
black = torch.zeros(1, height, width, 3)
|
||||
|
||||
return (
|
||||
white if output0_expected else black,
|
||||
white if output1_expected else black,
|
||||
white if output2_expected else black,
|
||||
)
|
||||
|
||||
|
||||
TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestLazyMixImages": TestLazyMixImages,
|
||||
"TestVariadicAverage": TestVariadicAverage,
|
||||
@ -498,6 +550,7 @@ TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestSleep": TestSleep,
|
||||
"TestParallelSleep": TestParallelSleep,
|
||||
"TestOutputNodeWithSocketOutput": TestOutputNodeWithSocketOutput,
|
||||
"TestExpectedOutputs": TestExpectedOutputs,
|
||||
}
|
||||
|
||||
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
@ -516,4 +569,5 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestSleep": "Test Sleep",
|
||||
"TestParallelSleep": "Test Parallel Sleep",
|
||||
"TestOutputNodeWithSocketOutput": "Test Output Node With Socket Output",
|
||||
"TestExpectedOutputs": "Test Expected Outputs",
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user