mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-07 20:12:35 +08:00
fix: precompute expected_outputs map to avoid O(n²) graph traversal
This commit is contained in:
parent
50975a7a0d
commit
01ef4e50ec
@ -25,19 +25,7 @@ def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
|
|||||||
Returns outputs that MIGHT be used.
|
Returns outputs that MIGHT be used.
|
||||||
Outputs NOT in this set are DEFINITELY not used and safe to skip.
|
Outputs NOT in this set are DEFINITELY not used and safe to skip.
|
||||||
"""
|
"""
|
||||||
expected = set()
|
return dynprompt.get_expected_outputs_map().get(node_id, frozenset())
|
||||||
for other_node_id in dynprompt.all_node_ids():
|
|
||||||
try:
|
|
||||||
node_data = dynprompt.get_node(other_node_id)
|
|
||||||
except NodeNotFoundError:
|
|
||||||
continue
|
|
||||||
inputs = node_data.get("inputs", {})
|
|
||||||
for input_name, value in inputs.items():
|
|
||||||
if is_link(value):
|
|
||||||
from_node_id, from_socket = value
|
|
||||||
if from_node_id == node_id:
|
|
||||||
expected.add(from_socket)
|
|
||||||
return frozenset(expected)
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicPrompt:
|
class DynamicPrompt:
|
||||||
@ -48,6 +36,7 @@ class DynamicPrompt:
|
|||||||
self.ephemeral_prompt = {}
|
self.ephemeral_prompt = {}
|
||||||
self.ephemeral_parents = {}
|
self.ephemeral_parents = {}
|
||||||
self.ephemeral_display = {}
|
self.ephemeral_display = {}
|
||||||
|
self._expected_outputs_map = None
|
||||||
|
|
||||||
def get_node(self, node_id):
|
def get_node(self, node_id):
|
||||||
if node_id in self.ephemeral_prompt:
|
if node_id in self.ephemeral_prompt:
|
||||||
@ -63,6 +52,7 @@ class DynamicPrompt:
|
|||||||
self.ephemeral_prompt[node_id] = node_info
|
self.ephemeral_prompt[node_id] = node_info
|
||||||
self.ephemeral_parents[node_id] = parent_id
|
self.ephemeral_parents[node_id] = parent_id
|
||||||
self.ephemeral_display[node_id] = display_id
|
self.ephemeral_display[node_id] = display_id
|
||||||
|
self._expected_outputs_map = None
|
||||||
|
|
||||||
def get_real_node_id(self, node_id):
|
def get_real_node_id(self, node_id):
|
||||||
while node_id in self.ephemeral_parents:
|
while node_id in self.ephemeral_parents:
|
||||||
@ -80,6 +70,26 @@ class DynamicPrompt:
|
|||||||
def all_node_ids(self):
|
def all_node_ids(self):
|
||||||
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
|
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
|
||||||
|
|
||||||
|
def _build_expected_outputs_map(self):
|
||||||
|
result = {}
|
||||||
|
for node_id in self.all_node_ids():
|
||||||
|
try:
|
||||||
|
node_data = self.get_node(node_id)
|
||||||
|
except NodeNotFoundError:
|
||||||
|
continue
|
||||||
|
for value in node_data.get("inputs", {}).values():
|
||||||
|
if is_link(value):
|
||||||
|
from_node_id, from_socket = value
|
||||||
|
if from_node_id not in result:
|
||||||
|
result[from_node_id] = set()
|
||||||
|
result[from_node_id].add(from_socket)
|
||||||
|
self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()}
|
||||||
|
|
||||||
|
def get_expected_outputs_map(self):
|
||||||
|
if self._expected_outputs_map is None:
|
||||||
|
self._build_expected_outputs_map()
|
||||||
|
return self._expected_outputs_map
|
||||||
|
|
||||||
def get_original_prompt(self):
|
def get_original_prompt(self):
|
||||||
return self.original_prompt
|
return self.original_prompt
|
||||||
|
|
||||||
|
|||||||
@ -123,6 +123,30 @@ class TestGetExpectedOutputsForNode:
|
|||||||
expected = get_expected_outputs_for_node(dynprompt, "1")
|
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||||
assert expected == frozenset({0})
|
assert expected == frozenset({0})
|
||||||
|
|
||||||
|
def test_ephemeral_node_invalidates_cache(self):
|
||||||
|
"""Test that adding ephemeral nodes updates expected outputs."""
|
||||||
|
prompt = {
|
||||||
|
"1": {"class_type": "SourceNode", "inputs": {}},
|
||||||
|
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
|
||||||
|
}
|
||||||
|
dynprompt = DynamicPrompt(prompt)
|
||||||
|
|
||||||
|
# Initially only output 0 is connected
|
||||||
|
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||||
|
assert expected == frozenset({0})
|
||||||
|
|
||||||
|
# Add an ephemeral node that connects to output 1
|
||||||
|
dynprompt.add_ephemeral_node(
|
||||||
|
"eph_1",
|
||||||
|
{"class_type": "EphemeralNode", "inputs": {"data": ["1", 1]}},
|
||||||
|
parent_id="2",
|
||||||
|
display_id="2",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now both outputs 0 and 1 should be expected
|
||||||
|
expected = get_expected_outputs_for_node(dynprompt, "1")
|
||||||
|
assert expected == frozenset({0, 1})
|
||||||
|
|
||||||
|
|
||||||
class TestExecutionContext:
|
class TestExecutionContext:
|
||||||
"""Tests for ExecutionContext with expected_outputs field."""
|
"""Tests for ExecutionContext with expected_outputs field."""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user