From 01ef4e50ec8fc4e5cd7ef659f309a6f43ad1e9d8 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Fri, 6 Feb 2026 08:29:01 +0200 Subject: [PATCH] =?UTF-8?q?fix:=20precompute=20expected=5Foutputs=20map=20?= =?UTF-8?q?to=20avoid=20O(n=C2=B2)=20graph=20traversal?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- comfy_execution/graph.py | 36 ++++++++++++------- .../execution_test/expected_outputs_test.py | 24 +++++++++++++ 2 files changed, 47 insertions(+), 13 deletions(-) diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 689fd23fc..b9cf9e943 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -25,19 +25,7 @@ def get_expected_outputs_for_node(dynprompt, node_id: str) -> frozenset: Returns outputs that MIGHT be used. Outputs NOT in this set are DEFINITELY not used and safe to skip. """ - expected = set() - for other_node_id in dynprompt.all_node_ids(): - try: - node_data = dynprompt.get_node(other_node_id) - except NodeNotFoundError: - continue - inputs = node_data.get("inputs", {}) - for input_name, value in inputs.items(): - if is_link(value): - from_node_id, from_socket = value - if from_node_id == node_id: - expected.add(from_socket) - return frozenset(expected) + return dynprompt.get_expected_outputs_map().get(node_id, frozenset()) class DynamicPrompt: @@ -48,6 +36,7 @@ class DynamicPrompt: self.ephemeral_prompt = {} self.ephemeral_parents = {} self.ephemeral_display = {} + self._expected_outputs_map = None def get_node(self, node_id): if node_id in self.ephemeral_prompt: @@ -63,6 +52,7 @@ class DynamicPrompt: self.ephemeral_prompt[node_id] = node_info self.ephemeral_parents[node_id] = parent_id self.ephemeral_display[node_id] = display_id + self._expected_outputs_map = None def get_real_node_id(self, node_id): while node_id in self.ephemeral_parents: @@ -80,6 +70,26 @@ class DynamicPrompt: def all_node_ids(self): return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys())) + def _build_expected_outputs_map(self): + result = {} + for node_id in self.all_node_ids(): + try: + node_data = self.get_node(node_id) + except NodeNotFoundError: + continue + for value in node_data.get("inputs", {}).values(): + if is_link(value): + from_node_id, from_socket = value + if from_node_id not in result: + result[from_node_id] = set() + result[from_node_id].add(from_socket) + self._expected_outputs_map = {k: frozenset(v) for k, v in result.items()} + + def get_expected_outputs_map(self): + if self._expected_outputs_map is None: + self._build_expected_outputs_map() + return self._expected_outputs_map + def get_original_prompt(self): return self.original_prompt diff --git a/tests-unit/execution_test/expected_outputs_test.py b/tests-unit/execution_test/expected_outputs_test.py index b96902826..b40786085 100644 --- a/tests-unit/execution_test/expected_outputs_test.py +++ b/tests-unit/execution_test/expected_outputs_test.py @@ -123,6 +123,30 @@ class TestGetExpectedOutputsForNode: expected = get_expected_outputs_for_node(dynprompt, "1") assert expected == frozenset({0}) + def test_ephemeral_node_invalidates_cache(self): + """Test that adding ephemeral nodes updates expected outputs.""" + prompt = { + "1": {"class_type": "SourceNode", "inputs": {}}, + "2": {"class_type": "ConsumerNode", "inputs": {"image": ["1", 0]}}, + } + dynprompt = DynamicPrompt(prompt) + + # Initially only output 0 is connected + expected = get_expected_outputs_for_node(dynprompt, "1") + assert expected == frozenset({0}) + + # Add an ephemeral node that connects to output 1 + dynprompt.add_ephemeral_node( + "eph_1", + {"class_type": "EphemeralNode", "inputs": {"data": ["1", 1]}}, + parent_id="2", + display_id="2", + ) + + # Now both outputs 0 and 1 should be expected + expected = get_expected_outputs_for_node(dynprompt, "1") + assert expected == frozenset({0, 1}) + class TestExecutionContext: """Tests for ExecutionContext with expected_outputs field."""