From 0b512198e896179019a85bfe3171f1cb076510bb Mon Sep 17 00:00:00 2001 From: xmarre Date: Sun, 15 Mar 2026 05:41:39 +0100 Subject: [PATCH] Adopt single-pass signature hashing --- comfy_execution/caching.py | 167 +++++++++++----------- tests-unit/execution_test/caching_test.py | 105 +++++++------- 2 files changed, 131 insertions(+), 141 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 73b67f8ab..f1b5227db 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -64,12 +64,13 @@ _PRIMITIVE_SIGNATURE_TYPES = (int, float, str, bool, bytes, type(None)) _CONTAINER_SIGNATURE_TYPES = (dict, list, tuple, set, frozenset) _MAX_SIGNATURE_DEPTH = 32 _MAX_SIGNATURE_CONTAINER_VISITS = 10_000 +_FAILED_SIGNATURE = object() -def _mark_signature_tainted(taint_state): - """Record that signature sanitization hit a fail-closed condition.""" - if taint_state is not None: - taint_state["tainted"] = True +def _primitive_signature_sort_key(obj): + """Return a deterministic ordering key for primitive signature values.""" + obj_type = type(obj) + return ("primitive", obj_type.__module__, obj_type.__qualname__, repr(obj)) def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None): @@ -123,21 +124,10 @@ def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=Non return result -def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None, taint_state=None): - """Normalize signature inputs to safe built-in containers. - - Preserves built-in container type, replaces opaque runtime values with - Unhashable(), stops safely on cycles or excessive depth, memoizes repeated - built-in substructures so shared DAG-like inputs do not explode into - repeated recursive work, and optionally records when sanitization had to - fail closed anywhere in the traversed structure. - """ - if taint_state is not None and taint_state.get("tainted"): - return Unhashable() - +def _signature_to_hashable_impl(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None): + """Canonicalize signature inputs directly into their final hashable form.""" if depth >= max_depth: - _mark_signature_tainted(taint_state) - return Unhashable() + return _FAILED_SIGNATURE if active is None: active = set() @@ -148,93 +138,102 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti obj_type = type(obj) if obj_type in _PRIMITIVE_SIGNATURE_TYPES: - return obj - if obj_type not in _CONTAINER_SIGNATURE_TYPES: - _mark_signature_tainted(taint_state) - return Unhashable() + return obj, _primitive_signature_sort_key(obj) + if obj_type is Unhashable or obj_type not in _CONTAINER_SIGNATURE_TYPES: + return _FAILED_SIGNATURE obj_id = id(obj) if obj_id in memo: return memo[obj_id] if obj_id in active: - _mark_signature_tainted(taint_state) - return Unhashable() + return _FAILED_SIGNATURE budget["remaining"] -= 1 if budget["remaining"] < 0: - _mark_signature_tainted(taint_state) - return Unhashable() + return _FAILED_SIGNATURE active.add(obj_id) try: if obj_type is dict: try: items = list(obj.items()) - sort_memo = {} - sanitized_items = [ - ( - _sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget, taint_state), - _sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget, taint_state), - ) - for key, value in items - ] - ordered_items = [ - ( - ( - _sanitized_sort_key(key, depth + 1, max_depth, memo=sort_memo), - _sanitized_sort_key(value, depth + 1, max_depth, memo=sort_memo), - ), - (key, value), - ) - for key, value in sanitized_items - ] - ordered_items.sort(key=lambda item: item[0]) + except RuntimeError: + return _FAILED_SIGNATURE - result = Unhashable() - for index in range(1, len(ordered_items)): - previous_sort_key, previous_item = ordered_items[index - 1] - current_sort_key, current_item = ordered_items[index] - if previous_sort_key == current_sort_key and previous_item != current_item: - _mark_signature_tainted(taint_state) - break - else: - result = {key: value for _, (key, value) in ordered_items} - except RuntimeError: - _mark_signature_tainted(taint_state) - result = Unhashable() - elif obj_type is list: + ordered_items = [] + for key, value in items: + key_result = _signature_to_hashable_impl(key, depth + 1, max_depth, active, memo, budget) + if key_result is _FAILED_SIGNATURE: + return _FAILED_SIGNATURE + value_result = _signature_to_hashable_impl(value, depth + 1, max_depth, active, memo, budget) + if value_result is _FAILED_SIGNATURE: + return _FAILED_SIGNATURE + key_value, key_sort = key_result + value_value, value_sort = value_result + ordered_items.append((((key_sort, value_sort)), (key_value, value_value))) + + ordered_items.sort(key=lambda item: item[0]) + for index in range(1, len(ordered_items)): + previous_sort_key, previous_item = ordered_items[index - 1] + current_sort_key, current_item = ordered_items[index] + if previous_sort_key == current_sort_key and previous_item != current_item: + return _FAILED_SIGNATURE + + value = ("dict", tuple(item for _, item in ordered_items)) + sort_key = ("dict", tuple(sort_key for sort_key, _ in ordered_items)) + elif obj_type is list or obj_type is tuple: try: items = list(obj) - result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items] except RuntimeError: - _mark_signature_tainted(taint_state) - result = Unhashable() - elif obj_type is tuple: - try: - items = list(obj) - result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items) - except RuntimeError: - _mark_signature_tainted(taint_state) - result = Unhashable() - elif obj_type is set: - try: - items = list(obj) - result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items} - except RuntimeError: - _mark_signature_tainted(taint_state) - result = Unhashable() + return _FAILED_SIGNATURE + + child_results = [] + for item in items: + child_result = _signature_to_hashable_impl(item, depth + 1, max_depth, active, memo, budget) + if child_result is _FAILED_SIGNATURE: + return _FAILED_SIGNATURE + child_results.append(child_result) + + container_tag = "list" if obj_type is list else "tuple" + value = (container_tag, tuple(child for child, _ in child_results)) + sort_key = (container_tag, tuple(child_sort for _, child_sort in child_results)) else: try: items = list(obj) - result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items) except RuntimeError: - _mark_signature_tainted(taint_state) - result = Unhashable() + return _FAILED_SIGNATURE + + ordered_items = [] + for item in items: + child_result = _signature_to_hashable_impl(item, depth + 1, max_depth, active, memo, budget) + if child_result is _FAILED_SIGNATURE: + return _FAILED_SIGNATURE + child_value, child_sort = child_result + ordered_items.append((child_sort, child_value)) + + ordered_items.sort(key=lambda item: item[0]) + for index in range(1, len(ordered_items)): + previous_sort_key, previous_value = ordered_items[index - 1] + current_sort_key, current_value = ordered_items[index] + if previous_sort_key == current_sort_key and previous_value != current_value: + return _FAILED_SIGNATURE + + container_tag = "set" if obj_type is set else "frozenset" + value = (container_tag, tuple(child_value for _, child_value in ordered_items)) + sort_key = (container_tag, tuple(child_sort for child_sort, _ in ordered_items)) finally: active.discard(obj_id) - memo[obj_id] = result - return result + memo[obj_id] = (value, sort_key) + return memo[obj_id] + + +def _signature_to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS): + """Build the final cache-signature representation in one fail-closed pass.""" + result = _signature_to_hashable_impl(obj, budget={"remaining": max_nodes}) + if result is _FAILED_SIGNATURE: + return Unhashable() + return result[0] def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS): @@ -397,11 +396,7 @@ class CacheKeySetInputSignature(CacheKeySet): signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) for ancestor_id in ancestors: signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) - taint_state = {"tainted": False} - signature = _sanitize_signature_input(signature, taint_state=taint_state) - if taint_state["tainted"]: - return Unhashable() - return to_hashable(signature) + return _signature_to_hashable(signature) async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): """Build the cache-signature fragment for a node's immediate inputs. @@ -424,7 +419,7 @@ class CacheKeySetInputSignature(CacheKeySet): ancestor_index = ancestor_order_mapping[ancestor_id] signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) else: - signature.append((key, _sanitize_signature_input(inputs[key]))) + signature.append((key, inputs[key])) return signature # This function returns a list of all ancestors of the given node. The order of the list is diff --git a/tests-unit/execution_test/caching_test.py b/tests-unit/execution_test/caching_test.py index 390efe87b..a7dffcee0 100644 --- a/tests-unit/execution_test/caching_test.py +++ b/tests-unit/execution_test/caching_test.py @@ -1,4 +1,4 @@ -"""Unit tests for cache-signature sanitization and hash conversion hardening.""" +"""Unit tests for cache-signature canonicalization hardening.""" import asyncio import importlib @@ -76,96 +76,91 @@ def caching_module(monkeypatch): return module, nodes_module -def test_sanitize_signature_input_handles_shared_builtin_substructures(caching_module): - """Shared built-in substructures should sanitize without collapsing to Unhashable.""" +def test_signature_to_hashable_handles_shared_builtin_substructures(caching_module): + """Shared built-in substructures should canonicalize without collapsing to Unhashable.""" caching, _ = caching_module shared = [{"value": 1}, {"value": 2}] - sanitized = caching._sanitize_signature_input([shared, shared]) + signature = caching._signature_to_hashable([shared, shared]) - assert isinstance(sanitized, list) - assert sanitized[0] == sanitized[1] - assert sanitized[0][0]["value"] == 1 - assert sanitized[0][1]["value"] == 2 + assert signature[0] == "list" + assert signature[1][0] == signature[1][1] + assert signature[1][0][0] == "list" + assert signature[1][0][1][0] == ("dict", (("value", 1),)) + assert signature[1][0][1][1] == ("dict", (("value", 2),)) -def test_sanitize_signature_input_marks_tainted_on_opaque_values(caching_module): - """Opaque values should mark the containing signature as tainted.""" +def test_signature_to_hashable_fails_closed_on_opaque_values(caching_module): + """Opaque values should collapse the full signature to Unhashable immediately.""" caching, _ = caching_module - taint_state = {"tainted": False} - sanitized = caching._sanitize_signature_input(["safe", object()], taint_state=taint_state) + signature = caching._signature_to_hashable(["safe", object()]) - assert isinstance(sanitized, list) - assert taint_state["tainted"] is True - assert isinstance(sanitized[1], caching.Unhashable) + assert isinstance(signature, caching.Unhashable) -def test_sanitize_signature_input_stops_descending_after_taint(caching_module, monkeypatch): - """Once tainted, later recursive calls should return immediately without deeper descent.""" +def test_signature_to_hashable_stops_descending_after_failure(caching_module, monkeypatch): + """Once canonicalization fails, later recursive descent should stop immediately.""" caching, _ = caching_module - original = caching._sanitize_signature_input + original = caching._signature_to_hashable_impl marker = object() marker_seen = False - def tracking_sanitize(obj, *args, **kwargs): - """Track whether recursion reaches the nested marker after tainting.""" + def tracking_canonicalize(obj, *args, **kwargs): + """Track whether recursion reaches the nested marker after failure.""" nonlocal marker_seen if obj is marker: marker_seen = True return original(obj, *args, **kwargs) - monkeypatch.setattr(caching, "_sanitize_signature_input", tracking_sanitize) + monkeypatch.setattr(caching, "_signature_to_hashable_impl", tracking_canonicalize) - taint_state = {"tainted": False} - sanitized = original([object(), [marker]], taint_state=taint_state) + signature = caching._signature_to_hashable([object(), [marker]]) - assert isinstance(sanitized, list) - assert taint_state["tainted"] is True + assert isinstance(signature, caching.Unhashable) assert marker_seen is False - assert isinstance(sanitized[1], caching.Unhashable) -def test_sanitize_signature_input_snapshots_list_before_recursing(caching_module, monkeypatch): - """List sanitization should read a point-in-time snapshot before recursive descent.""" +def test_signature_to_hashable_snapshots_list_before_recursing(caching_module, monkeypatch): + """List canonicalization should read a point-in-time snapshot before recursive descent.""" caching, _ = caching_module - original = caching._sanitize_signature_input - marker = object() + original = caching._signature_to_hashable_impl + marker = ("marker",) values = [marker, 2] - def mutating_sanitize(obj, *args, **kwargs): + def mutating_canonicalize(obj, *args, **kwargs): """Mutate the live list during recursion to verify snapshot-based traversal.""" if obj is marker: values[1] = 3 return original(obj, *args, **kwargs) - monkeypatch.setattr(caching, "_sanitize_signature_input", mutating_sanitize) + monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize) - sanitized = original(values) + signature = caching._signature_to_hashable(values) - assert isinstance(sanitized, list) - assert sanitized[1] == 2 + assert signature == ("list", (("tuple", ("marker",)), 2)) + assert values[1] == 3 -def test_sanitize_signature_input_snapshots_dict_before_recursing(caching_module, monkeypatch): - """Dict sanitization should read a point-in-time snapshot before recursive descent.""" +def test_signature_to_hashable_snapshots_dict_before_recursing(caching_module, monkeypatch): + """Dict canonicalization should read a point-in-time snapshot before recursive descent.""" caching, _ = caching_module - original = caching._sanitize_signature_input - marker = object() + original = caching._signature_to_hashable_impl + marker = ("marker",) values = {"first": marker, "second": 2} - def mutating_sanitize(obj, *args, **kwargs): + def mutating_canonicalize(obj, *args, **kwargs): """Mutate the live dict during recursion to verify snapshot-based traversal.""" if obj is marker: values["second"] = 3 return original(obj, *args, **kwargs) - monkeypatch.setattr(caching, "_sanitize_signature_input", mutating_sanitize) + monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize) - sanitized = original(values) + signature = caching._signature_to_hashable(values) - assert isinstance(sanitized, dict) - assert sanitized["second"] == 2 + assert signature == ("dict", (("first", ("tuple", ("marker",))), ("second", 2))) + assert values["second"] == 3 @pytest.mark.parametrize( @@ -178,31 +173,31 @@ def test_sanitize_signature_input_snapshots_dict_before_recursing(caching_module lambda marker: {marker: "value"}, ], ) -def test_sanitize_signature_input_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory): - """Traversal RuntimeError should degrade sanitization to Unhashable.""" +def test_signature_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory): + """Traversal RuntimeError should degrade canonicalization to Unhashable.""" caching, _ = caching_module - original = caching._sanitize_signature_input + original = caching._signature_to_hashable_impl marker = object() - def raising_sanitize(obj, *args, **kwargs): + def raising_canonicalize(obj, *args, **kwargs): """Raise a traversal RuntimeError for the marker value and delegate otherwise.""" if obj is marker: raise RuntimeError("container changed during iteration") return original(obj, *args, **kwargs) - monkeypatch.setattr(caching, "_sanitize_signature_input", raising_sanitize) + monkeypatch.setattr(caching, "_signature_to_hashable_impl", raising_canonicalize) - sanitized = original(container_factory(marker)) + signature = caching._signature_to_hashable(container_factory(marker)) - assert isinstance(sanitized, caching.Unhashable) + assert isinstance(signature, caching.Unhashable) def test_to_hashable_handles_shared_builtin_substructures(caching_module): - """Repeated sanitized content should hash stably for shared substructures.""" + """The legacy helper should still hash sanitized built-ins stably when used directly.""" caching, _ = caching_module shared = [{"value": 1}, {"value": 2}] - sanitized = caching._sanitize_signature_input([shared, shared]) + sanitized = [shared, shared] hashable = caching.to_hashable(sanitized) assert hashable[0] == "list" @@ -232,7 +227,7 @@ def test_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, c assert isinstance(hashable, caching.Unhashable) -def test_sanitize_signature_input_fails_closed_for_ambiguous_dict_ordering(caching_module): +def test_signature_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module): """Ambiguous dict sort ties should fail closed instead of depending on input order.""" caching, _ = caching_module ambiguous = { @@ -240,7 +235,7 @@ def test_sanitize_signature_input_fails_closed_for_ambiguous_dict_ordering(cachi _OpaqueValue(): _OpaqueValue(), } - sanitized = caching._sanitize_signature_input(ambiguous) + sanitized = caching._signature_to_hashable(ambiguous) assert isinstance(sanitized, caching.Unhashable)