fix: precompute expected_outputs map to avoid O(n²) graph traversal

This commit is contained in:
bigcat88 2026-02-06 08:29:01 +02:00
parent 50975a7a0d
commit 01ef4e50ec
2 changed files with 47 additions and 13 deletions

View File

@ -25,19 +25,7 @@ def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset:
Returns outputs that MIGHT be used. Returns outputs that MIGHT be used.
Outputs NOT in this set are DEFINITELY not used and safe to skip. Outputs NOT in this set are DEFINITELY not used and safe to skip.
""" """
expected = set() return dynprompt.get_expected_outputs_map().get(node_id, frozenset())
for other_node_id in dynprompt.all_node_ids():
try:
node_data = dynprompt.get_node(other_node_id)
except NodeNotFoundError:
continue
inputs = node_data.get("inputs", {})
for input_name, value in inputs.items():
if is_link(value):
from_node_id, from_socket = value
if from_node_id == node_id:
expected.add(from_socket)
return frozenset(expected)
class DynamicPrompt: class DynamicPrompt:
@ -48,6 +36,7 @@ class DynamicPrompt:
self.ephemeral_prompt = {} self.ephemeral_prompt = {}
self.ephemeral_parents = {} self.ephemeral_parents = {}
self.ephemeral_display = {} self.ephemeral_display = {}
self._expected_outputs_map = None
def get_node(self, node_id): def get_node(self, node_id):
if node_id in self.ephemeral_prompt: if node_id in self.ephemeral_prompt:
@ -63,6 +52,7 @@ class DynamicPrompt:
self.ephemeral_prompt[node_id] = node_info self.ephemeral_prompt[node_id] = node_info
self.ephemeral_parents[node_id] = parent_id self.ephemeral_parents[node_id] = parent_id
self.ephemeral_display[node_id] = display_id self.ephemeral_display[node_id] = display_id
self._expected_outputs_map = None
def get_real_node_id(self, node_id): def get_real_node_id(self, node_id):
while node_id in self.ephemeral_parents: while node_id in self.ephemeral_parents:
@ -80,6 +70,26 @@ class DynamicPrompt:
def all_node_ids(self): def all_node_ids(self):
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys())) return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
def _build_expected_outputs_map(self):
result = {}
for node_id in self.all_node_ids():
try:
node_data = self.get_node(node_id)
except NodeNotFoundError:
continue
for value in node_data.get("inputs", {}).values():
if is_link(value):
from_node_id, from_socket = value
if from_node_id not in result:
result[from_node_id] = set()
result[from_node_id].add(from_socket)
self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()}
def get_expected_outputs_map(self):
if self._expected_outputs_map is None:
self._build_expected_outputs_map()
return self._expected_outputs_map
def get_original_prompt(self): def get_original_prompt(self):
return self.original_prompt return self.original_prompt

View File

@ -123,6 +123,30 @@ class TestGetExpectedOutputsForNode:
expected = get_expected_outputs_for_node(dynprompt, "1") expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0}) assert expected == frozenset({0})
def test_ephemeral_node_invalidates_cache(self):
"""Test that adding ephemeral nodes updates expected outputs."""
prompt = {
"1": {"class_type": "SourceNode", "inputs": {}},
"2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}},
}
dynprompt = DynamicPrompt(prompt)
# Initially only output 0 is connected
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0})
# Add an ephemeral node that connects to output 1
dynprompt.add_ephemeral_node(
"eph_1",
{"class_type": "EphemeralNode", "inputs": {"data": ["1", 1]}},
parent_id="2",
display_id="2",
)
# Now both outputs 0 and 1 should be expected
expected = get_expected_outputs_for_node(dynprompt, "1")
assert expected == frozenset({0, 1})
class TestExecutionContext: class TestExecutionContext:
"""Tests for ExecutionContext with expected_outputs field.""" """Tests for ExecutionContext with expected_outputs field."""