From 9feb26928cd755d751a04e3f6e46c12485a57be5 Mon Sep 17 00:00:00 2001 From: xmarre Date: Sun, 15 Mar 2026 04:31:32 +0100 Subject: [PATCH] Change signature cache to bail early --- comfy_execution/caching.py | 45 ++++++++++++---- tests-unit/execution_test/caching_test.py | 65 ++++++++++++++++------- 2 files changed, 79 insertions(+), 31 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 08f3f436b..73b67f8ab 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -66,6 +66,12 @@ _MAX_SIGNATURE_DEPTH = 32 _MAX_SIGNATURE_CONTAINER_VISITS = 10_000 +def _mark_signature_tainted(taint_state): + """Record that signature sanitization hit a fail-closed condition.""" + if taint_state is not None: + taint_state["tainted"] = True + + def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None): """Return a deterministic ordering key for sanitized built-in container content.""" if depth >= max_depth: @@ -117,15 +123,20 @@ def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=Non return result -def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None): +def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None, taint_state=None): """Normalize signature inputs to safe built-in containers. Preserves built-in container type, replaces opaque runtime values with - Unhashable(), stops safely on cycles or excessive depth, and memoizes - repeated built-in substructures so shared DAG-like inputs do not explode - into repeated recursive work. + Unhashable(), stops safely on cycles or excessive depth, memoizes repeated + built-in substructures so shared DAG-like inputs do not explode into + repeated recursive work, and optionally records when sanitization had to + fail closed anywhere in the traversed structure. """ + if taint_state is not None and taint_state.get("tainted"): + return Unhashable() + if depth >= max_depth: + _mark_signature_tainted(taint_state) return Unhashable() if active is None: @@ -139,16 +150,19 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti if obj_type in _PRIMITIVE_SIGNATURE_TYPES: return obj if obj_type not in _CONTAINER_SIGNATURE_TYPES: + _mark_signature_tainted(taint_state) return Unhashable() obj_id = id(obj) if obj_id in memo: return memo[obj_id] if obj_id in active: + _mark_signature_tainted(taint_state) return Unhashable() budget["remaining"] -= 1 if budget["remaining"] < 0: + _mark_signature_tainted(taint_state) return Unhashable() active.add(obj_id) @@ -159,8 +173,8 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti sort_memo = {} sanitized_items = [ ( - _sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget), - _sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget), + _sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget, taint_state), + _sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget, taint_state), ) for key, value in items ] @@ -181,34 +195,40 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti previous_sort_key, previous_item = ordered_items[index - 1] current_sort_key, current_item = ordered_items[index] if previous_sort_key == current_sort_key and previous_item != current_item: + _mark_signature_tainted(taint_state) break else: result = {key: value for _, (key, value) in ordered_items} except RuntimeError: + _mark_signature_tainted(taint_state) result = Unhashable() elif obj_type is list: try: items = list(obj) - result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items] + result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items] except RuntimeError: + _mark_signature_tainted(taint_state) result = Unhashable() elif obj_type is tuple: try: items = list(obj) - result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items) + result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items) except RuntimeError: + _mark_signature_tainted(taint_state) result = Unhashable() elif obj_type is set: try: items = list(obj) - result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items} + result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items} except RuntimeError: + _mark_signature_tainted(taint_state) result = Unhashable() else: try: items = list(obj) - result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items) + result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items) except RuntimeError: + _mark_signature_tainted(taint_state) result = Unhashable() finally: active.discard(obj_id) @@ -377,7 +397,10 @@ class CacheKeySetInputSignature(CacheKeySet): signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) for ancestor_id in ancestors: signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) - signature = _sanitize_signature_input(signature) + taint_state = {"tainted": False} + signature = _sanitize_signature_input(signature, taint_state=taint_state) + if taint_state["tainted"]: + return Unhashable() return to_hashable(signature) async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): diff --git a/tests-unit/execution_test/caching_test.py b/tests-unit/execution_test/caching_test.py index 6313faed1..390efe87b 100644 --- a/tests-unit/execution_test/caching_test.py +++ b/tests-unit/execution_test/caching_test.py @@ -49,22 +49,6 @@ class _OpaqueValue: """Hashable opaque object used to exercise fail-closed unordered hashing paths.""" -def _contains_unhashable(value, unhashable_type): - """Return whether a nested built-in structure contains an Unhashable sentinel.""" - if isinstance(value, unhashable_type): - return True - - value_type = type(value) - if value_type is dict: - return any( - _contains_unhashable(key, unhashable_type) or _contains_unhashable(item, unhashable_type) - for key, item in value.items() - ) - if value_type in (list, tuple, set, frozenset): - return any(_contains_unhashable(item, unhashable_type) for item in value) - return False - - @pytest.fixture def caching_module(monkeypatch): """Import `comfy_execution.caching` with lightweight stub dependencies.""" @@ -105,6 +89,43 @@ def test_sanitize_signature_input_handles_shared_builtin_substructures(caching_m assert sanitized[0][1]["value"] == 2 +def test_sanitize_signature_input_marks_tainted_on_opaque_values(caching_module): + """Opaque values should mark the containing signature as tainted.""" + caching, _ = caching_module + taint_state = {"tainted": False} + + sanitized = caching._sanitize_signature_input(["safe", object()], taint_state=taint_state) + + assert isinstance(sanitized, list) + assert taint_state["tainted"] is True + assert isinstance(sanitized[1], caching.Unhashable) + + +def test_sanitize_signature_input_stops_descending_after_taint(caching_module, monkeypatch): + """Once tainted, later recursive calls should return immediately without deeper descent.""" + caching, _ = caching_module + original = caching._sanitize_signature_input + marker = object() + marker_seen = False + + def tracking_sanitize(obj, *args, **kwargs): + """Track whether recursion reaches the nested marker after tainting.""" + nonlocal marker_seen + if obj is marker: + marker_seen = True + return original(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_sanitize_signature_input", tracking_sanitize) + + taint_state = {"tainted": False} + sanitized = original([object(), [marker]], taint_state=taint_state) + + assert isinstance(sanitized, list) + assert taint_state["tainted"] is True + assert marker_seen is False + assert isinstance(sanitized[1], caching.Unhashable) + + def test_sanitize_signature_input_snapshots_list_before_recursing(caching_module, monkeypatch): """List sanitization should read a point-in-time snapshot before recursive descent.""" caching, _ = caching_module @@ -241,10 +262,15 @@ def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module, assert isinstance(hashable, caching.Unhashable) -def test_get_node_signature_sanitizes_full_signature(caching_module, monkeypatch): - """Recursive `is_changed` payloads should be sanitized inside the full node signature.""" +def test_get_node_signature_returns_top_level_unhashable_for_tainted_signature(caching_module, monkeypatch): + """Tainted full signatures should fail closed before `to_hashable()` runs.""" caching, nodes_module = caching_module monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode) + monkeypatch.setattr( + caching, + "to_hashable", + lambda *_args, **_kwargs: pytest.fail("to_hashable should not run for tainted signatures"), + ) is_changed_value = [] is_changed_value.append(is_changed_value) @@ -265,5 +291,4 @@ def test_get_node_signature_sanitizes_full_signature(caching_module, monkeypatch signature = asyncio.run(key_set.get_node_signature(dynprompt, "node")) - assert signature[0] == "list" - assert _contains_unhashable(signature, caching.Unhashable) + assert isinstance(signature, caching.Unhashable)