fix: precompute expected_outputs map to avoid O(n²) graph traversal

This commit is contained in:
bigcat88 2026-02-06 08:29:01 +02:00
parent 50975a7a0d
commit 01ef4e50ec
2 changed files with 47 additions and 13 deletions

View File

@ -25,19 +25,7 @@ def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
Returns outputs that MIGHT be used.
Outputs NOT in this set are DEFINITELY not used and safe to skip.
"""
expected = set()
for other_node_id in dynprompt.all_node_ids():
try:
node_data = dynprompt.get_node(other_node_id)
except NodeNotFoundError:
continue
inputs = node_data.get("inputs", {})
for input_name, value in inputs.items():
if is_link(value):
from_node_id, from_socket = value
if from_node_id == node_id:
expected.add(from_socket)
return frozenset(expected)
return dynprompt.get_expected_outputs_map().get(node_id, frozenset())
class DynamicPrompt:
@ -48,6 +36,7 @@ class DynamicPrompt:
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
self.ephemeral_display = {}
self._expected_outputs_map = None
def get_node(self, node_id):
if node_id in self.ephemeral_prompt:
@ -63,6 +52,7 @@ class DynamicPrompt:
self.ephemeral_prompt[node_id] = node_info
self.ephemeral_parents[node_id] = parent_id
self.ephemeral_display[node_id] = display_id
self._expected_outputs_map = None
def get_real_node_id(self, node_id):
while node_id in self.ephemeral_parents:
@ -80,6 +70,26 @@ class DynamicPrompt:
def all_node_ids(self):
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
def _build_expected_outputs_map(self):
result = {}
for node_id in self.all_node_ids():
try:
node_data = self.get_node(node_id)
except NodeNotFoundError:
continue
for value in node_data.get("inputs", {}).values():
if is_link(value):
from_node_id, from_socket = value
if from_node_id not in result:
result[from_node_id] = set()
result[from_node_id].add(from_socket)
self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()}
def get_expected_outputs_map(self):
if self._expected_outputs_map is None:
self._build_expected_outputs_map()
return self._expected_outputs_map
def get_original_prompt(self):
return self.original_prompt

View File

@ -123,6 +123,30 @@ class TestGetExpectedOutputsForNode:
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0})
def test_ephemeral_node_invalidates_cache(self):
"""Test that adding ephemeral nodes updates expected outputs."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
}
dynprompt = DynamicPrompt(prompt)
# Initially only output 0 is connected
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0})
# Add an ephemeral node that connects to output 1
dynprompt.add_ephemeral_node(
"eph_1",
{"class_type": "EphemeralNode", "inputs": {"data": ["1", 1]}},
parent_id="2",
display_id="2",
)
# Now both outputs 0 and 1 should be expected
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0, 1})
class TestExecutionContext:
"""Tests for ExecutionContext with expected_outputs field."""