From fce22da313d8dc7f5c0d08d71cd86d65601aa7c1 Mon Sep 17 00:00:00 2001 From: xmarre Date: Mon, 16 Mar 2026 09:29:00 +0100 Subject: [PATCH] Prevent signature traversal of raw --- comfy_execution/caching.py | 8 +-- tests/execution/test_caching.py | 103 ++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 4 deletions(-) create mode 100644 tests/execution/test_caching.py diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index d3a7ec52a..0a1bd188c 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -465,9 +465,9 @@ class CacheKeySetInputSignature(CacheKeySet): async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): """Build the immediate cache-signature fragment for a node. - Link inputs are reduced to ancestor references here, while non-link - values are appended as-is. Full canonicalization happens later in - `get_node_signature()` via `_signature_to_hashable()`. + Link inputs are reduced to ancestor references here. Non-link values + are canonicalized or failed closed before being appended so the outer + node-signature pass never recurses into live prompt input containers. """ if not dynprompt.has_node(node_id): # This node doesn't exist -- we can't cache it. @@ -485,7 +485,7 @@ class CacheKeySetInputSignature(CacheKeySet): ancestor_index = ancestor_order_mapping[ancestor_id] signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) else: - signature.append((key, inputs[key])) + signature.append((key, to_hashable(inputs[key]))) return signature # This function returns a list of all ancestors of the given node. The order of the list is diff --git a/tests/execution/test_caching.py b/tests/execution/test_caching.py new file mode 100644 index 000000000..f5851c980 --- /dev/null +++ b/tests/execution/test_caching.py @@ -0,0 +1,103 @@ +import asyncio + +from comfy_execution import caching + + +class _StubDynPrompt: + def __init__(self, nodes): + self._nodes = nodes + + def has_node(self, node_id): + return node_id in self._nodes + + def get_node(self, node_id): + return self._nodes[node_id] + + +class _StubIsChangedCache: + async def get(self, node_id): + return None + + +class _StubNode: + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + + +def test_get_immediate_node_signature_canonicalizes_non_link_inputs(monkeypatch): + live_value = [1, {"nested": [2, 3]}] + dynprompt = _StubDynPrompt( + { + "1": { + "class_type": "TestCacheNode", + "inputs": {"value": live_value}, + } + } + ) + + monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode) + monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {}) + + keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache()) + signature = asyncio.run(keyset.get_immediate_node_signature(dynprompt, "1", {})) + + assert signature == [ + "TestCacheNode", + None, + ("value", ("list", (1, ("dict", (("nested", ("list", (2, 3))),))))), + ] + + +def test_get_immediate_node_signature_fails_closed_for_opaque_non_link_input(monkeypatch): + class OpaqueRuntimeValue: + pass + + live_value = OpaqueRuntimeValue() + dynprompt = _StubDynPrompt( + { + "1": { + "class_type": "TestCacheNode", + "inputs": {"value": live_value}, + } + } + ) + + monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode) + monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {}) + + keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache()) + signature = asyncio.run(keyset.get_immediate_node_signature(dynprompt, "1", {})) + + assert signature[:2] == ["TestCacheNode", None] + assert signature[2][0] == "value" + assert type(signature[2][1]) is caching.Unhashable + + +def test_get_node_signature_never_visits_raw_non_link_input(monkeypatch): + live_value = [1, 2, 3] + dynprompt = _StubDynPrompt( + { + "1": { + "class_type": "TestCacheNode", + "inputs": {"value": live_value}, + } + } + ) + + monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode) + monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {}) + + original_impl = caching._signature_to_hashable_impl + + def guarded_impl(obj, *args, **kwargs): + if obj is live_value: + raise AssertionError("raw non-link input reached outer signature canonicalizer") + return original_impl(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_signature_to_hashable_impl", guarded_impl) + + keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache()) + signature = asyncio.run(keyset.get_node_signature(dynprompt, "1")) + + assert isinstance(signature, tuple)