diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index cbf2e9de1..1ca1edcc0 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -60,136 +60,239 @@ class Unhashable: pass -def _sanitized_sort_key(obj, depth=0, max_depth=32): +_PRIMITIVE_SIGNATURE_TYPES = (int, float, str, bool, bytes, type(None)) +_CONTAINER_SIGNATURE_TYPES = (dict, list, tuple, set, frozenset) +_MAX_SIGNATURE_DEPTH = 32 +_MAX_SIGNATURE_CONTAINER_VISITS = 10_000 + + +def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None): """Return a deterministic ordering key for sanitized built-in container content.""" if depth >= max_depth: return ("MAX_DEPTH",) + if active is None: + active = set() + if memo is None: + memo = {} + obj_type = type(obj) if obj_type is Unhashable: return ("UNHASHABLE",) - elif obj_type in (int, float, str, bool, bytes, type(None)): + elif obj_type in _PRIMITIVE_SIGNATURE_TYPES: return (obj_type.__module__, obj_type.__qualname__, repr(obj)) - elif obj_type is dict: - items = [ - ( - _sanitized_sort_key(k, depth + 1, max_depth), - _sanitized_sort_key(v, depth + 1, max_depth), - ) - for k, v in obj.items() - ] - items.sort() - return ("dict", tuple(items)) - elif obj_type is list: - return ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)) - elif obj_type is tuple: - return ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)) - elif obj_type is set: - return ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))) - elif obj_type is frozenset: - return ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))) - else: + elif obj_type not in _CONTAINER_SIGNATURE_TYPES: return (obj_type.__module__, obj_type.__qualname__, "OPAQUE") + obj_id = id(obj) + if obj_id in memo: + return memo[obj_id] + if obj_id in active: + return ("CYCLE",) -def _sanitize_signature_input(obj, depth=0, max_depth=32, seen=None): + active.add(obj_id) + try: + if obj_type is dict: + items = [ + ( + _sanitized_sort_key(k, depth + 1, max_depth, active, memo), + _sanitized_sort_key(v, depth + 1, max_depth, active, memo), + ) + for k, v in obj.items() + ] + items.sort() + result = ("dict", tuple(items)) + elif obj_type is list: + result = ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)) + elif obj_type is tuple: + result = ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)) + elif obj_type is set: + result = ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))) + else: + result = ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))) + finally: + active.discard(obj_id) + + memo[obj_id] = result + return result + + +def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None): """Normalize signature inputs to safe built-in containers. Preserves built-in container type, replaces opaque runtime values with - Unhashable(), and stops safely on cycles or excessive depth. + Unhashable(), stops safely on cycles or excessive depth, and memoizes + repeated built-in substructures so shared DAG-like inputs do not explode + into repeated recursive work. """ if depth >= max_depth: return Unhashable() - if seen is None: - seen = set() + if active is None: + active = set() + if memo is None: + memo = {} + if budget is None: + budget = {"remaining": _MAX_SIGNATURE_CONTAINER_VISITS} obj_type = type(obj) - if obj_type in (dict, list, tuple, set, frozenset): - obj_id = id(obj) - if obj_id in seen: - return Unhashable() - next_seen = seen | {obj_id} - - if obj_type in (int, float, str, bool, bytes, type(None)): + if obj_type in _PRIMITIVE_SIGNATURE_TYPES: return obj - elif obj_type is dict: - sanitized_items = [ - ( - _sanitize_signature_input(key, depth + 1, max_depth, next_seen), - _sanitize_signature_input(value, depth + 1, max_depth, next_seen), - ) - for key, value in obj.items() - ] - sanitized_items.sort( - key=lambda kv: ( - _sanitized_sort_key(kv[0], depth + 1, max_depth), - _sanitized_sort_key(kv[1], depth + 1, max_depth), - ) - ) - return {key: value for key, value in sanitized_items} - elif obj_type is list: - return [_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj] - elif obj_type is tuple: - return tuple(_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj) - elif obj_type is set: - return {_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj} - elif obj_type is frozenset: - return frozenset(_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj) - else: - # Execution-cache signatures should be built from prompt-safe values. - # If a custom node injects a runtime object here, mark it unhashable so - # the node won't reuse stale cache entries across runs, but do not walk - # the foreign object and risk crashing on custom container semantics. + if obj_type not in _CONTAINER_SIGNATURE_TYPES: return Unhashable() -def to_hashable(obj, depth=0, max_depth=32, seen=None): + obj_id = id(obj) + if obj_id in memo: + return memo[obj_id] + if obj_id in active: + return Unhashable() + + budget["remaining"] -= 1 + if budget["remaining"] < 0: + return Unhashable() + + active.add(obj_id) + try: + if obj_type is dict: + sort_memo = {} + sanitized_items = [ + ( + _sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget), + _sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget), + ) + for key, value in obj.items() + ] + ordered_items = [ + ( + ( + _sanitized_sort_key(key, depth + 1, max_depth, memo=sort_memo), + _sanitized_sort_key(value, depth + 1, max_depth, memo=sort_memo), + ), + (key, value), + ) + for key, value in sanitized_items + ] + ordered_items.sort(key=lambda item: item[0]) + + result = Unhashable() + for index in range(1, len(ordered_items)): + previous_sort_key, previous_item = ordered_items[index - 1] + current_sort_key, current_item = ordered_items[index] + if previous_sort_key == current_sort_key and previous_item != current_item: + break + else: + result = {key: value for _, (key, value) in ordered_items} + elif obj_type is list: + result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj] + elif obj_type is tuple: + result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj) + elif obj_type is set: + result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj} + else: + result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj) + finally: + active.discard(obj_id) + + memo[obj_id] = result + return result + + +def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS): """Convert sanitized prompt inputs into a stable hashable representation. - Preserves built-in container type and stops safely on cycles or excessive depth. + The input is expected to already be sanitized to plain built-in containers, + but this function still fails safe for anything unexpected. Traversal is + iterative and memoized so shared built-in substructures do not trigger + exponential re-walks during cache-key construction. """ - if depth >= max_depth: - return Unhashable() - - if seen is None: - seen = set() - - # Restrict recursion to plain built-in containers. Some custom nodes insert - # runtime objects into prompt inputs for dynamic graph paths; walking those - # objects as generic Mappings / Sequences is unsafe and can destabilize the - # cache signature builder. obj_type = type(obj) - if obj_type in (int, float, str, bool, bytes, type(None)): + if obj_type in _PRIMITIVE_SIGNATURE_TYPES or obj_type is Unhashable: return obj - - if obj_type in (dict, list, tuple, set, frozenset): - obj_id = id(obj) - if obj_id in seen: - return Unhashable() - seen = seen | {obj_id} - - if obj_type is dict: - return ( - "dict", - frozenset( - ( - to_hashable(k, depth + 1, max_depth, seen), - to_hashable(v, depth + 1, max_depth, seen), - ) - for k, v in obj.items() - ), - ) - elif obj_type is list: - return ("list", tuple(to_hashable(i, depth + 1, max_depth, seen) for i in obj)) - elif obj_type is tuple: - return ("tuple", tuple(to_hashable(i, depth + 1, max_depth, seen) for i in obj)) - elif obj_type is set: - return ("set", frozenset(to_hashable(i, depth + 1, max_depth, seen) for i in obj)) - elif obj_type is frozenset: - return ("frozenset", frozenset(to_hashable(i, depth + 1, max_depth, seen) for i in obj)) - else: + if obj_type not in _CONTAINER_SIGNATURE_TYPES: return Unhashable() + memo = {} + active = set() + sort_memo = {} + processed = 0 + stack = [(obj, False)] + + def resolve_value(value): + """Resolve a child value from the completed memo table when available.""" + value_type = type(value) + if value_type in _PRIMITIVE_SIGNATURE_TYPES or value_type is Unhashable: + return value + return memo.get(id(value), Unhashable()) + + def resolve_unordered_values(current, container_tag): + """Resolve a set-like container or fail closed if ordering is ambiguous.""" + ordered_items = [ + (_sanitized_sort_key(item, memo=sort_memo), resolve_value(item)) + for item in current + ] + ordered_items.sort(key=lambda item: item[0]) + + for index in range(1, len(ordered_items)): + previous_key, previous_value = ordered_items[index - 1] + current_key, current_value = ordered_items[index] + if previous_key == current_key and previous_value != current_value: + return Unhashable() + + return (container_tag, tuple(value for _, value in ordered_items)) + + while stack: + current, expanded = stack.pop() + current_type = type(current) + + if current_type in _PRIMITIVE_SIGNATURE_TYPES or current_type is Unhashable: + continue + if current_type not in _CONTAINER_SIGNATURE_TYPES: + memo[id(current)] = Unhashable() + continue + + current_id = id(current) + if current_id in memo: + continue + + if expanded: + active.discard(current_id) + if current_type is dict: + memo[current_id] = ( + "dict", + tuple((resolve_value(k), resolve_value(v)) for k, v in current.items()), + ) + elif current_type is list: + memo[current_id] = ("list", tuple(resolve_value(item) for item in current)) + elif current_type is tuple: + memo[current_id] = ("tuple", tuple(resolve_value(item) for item in current)) + elif current_type is set: + memo[current_id] = resolve_unordered_values(current, "set") + else: + memo[current_id] = resolve_unordered_values(current, "frozenset") + continue + + if current_id in active: + memo[current_id] = Unhashable() + continue + + processed += 1 + if processed > max_nodes: + return Unhashable() + + active.add(current_id) + stack.append((current, True)) + if current_type is dict: + items = list(current.items()) + for key, value in reversed(items): + stack.append((value, False)) + stack.append((key, False)) + else: + items = list(current) + for item in reversed(items): + stack.append((item, False)) + + return memo.get(id(obj), Unhashable()) + class CacheKeySetID(CacheKeySet): """Cache-key strategy that keys nodes by node id and class type.""" def __init__(self, dynprompt, node_ids, is_changed_cache): @@ -238,6 +341,7 @@ class CacheKeySetInputSignature(CacheKeySet): signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) for ancestor_id in ancestors: signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) + signature = _sanitize_signature_input(signature) return to_hashable(signature) async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): diff --git a/tests-unit/execution_test/caching_test.py b/tests-unit/execution_test/caching_test.py new file mode 100644 index 000000000..2f088722e --- /dev/null +++ b/tests-unit/execution_test/caching_test.py @@ -0,0 +1,176 @@ +"""Unit tests for cache-signature sanitization and hash conversion hardening.""" + +import asyncio +import importlib +import sys +import types + +import pytest + + +class _DummyNode: + """Minimal node stub used to satisfy cache-signature class lookups.""" + + @staticmethod + def INPUT_TYPES(): + """Return a minimal empty input schema for unit tests.""" + return {"required": {}} + + +class _FakeDynPrompt: + """Small DynamicPrompt stand-in with only the methods these tests need.""" + + def __init__(self, nodes_by_id): + """Store test nodes by id.""" + self._nodes_by_id = nodes_by_id + + def has_node(self, node_id): + """Return whether the fake prompt contains the requested node.""" + return node_id in self._nodes_by_id + + def get_node(self, node_id): + """Return the stored node payload for the requested id.""" + return self._nodes_by_id[node_id] + + +class _FakeIsChangedCache: + """Async stub for `is_changed` lookups used by cache-key generation.""" + + def __init__(self, values): + """Store canned `is_changed` responses keyed by node id.""" + self._values = values + + async def get(self, node_id): + """Return the canned `is_changed` value for a node.""" + return self._values[node_id] + + +class _OpaqueValue: + """Hashable opaque object used to exercise fail-closed unordered hashing paths.""" + + +def _contains_unhashable(value, unhashable_type): + """Return whether a nested built-in structure contains an Unhashable sentinel.""" + if isinstance(value, unhashable_type): + return True + + value_type = type(value) + if value_type is dict: + return any( + _contains_unhashable(key, unhashable_type) or _contains_unhashable(item, unhashable_type) + for key, item in value.items() + ) + if value_type in (list, tuple, set, frozenset): + return any(_contains_unhashable(item, unhashable_type) for item in value) + return False + + +@pytest.fixture +def caching_module(monkeypatch): + """Import `comfy_execution.caching` with lightweight stub dependencies.""" + torch_module = types.ModuleType("torch") + psutil_module = types.ModuleType("psutil") + nodes_module = types.ModuleType("nodes") + nodes_module.NODE_CLASS_MAPPINGS = {} + graph_module = types.ModuleType("comfy_execution.graph") + + class DynamicPrompt: + """Placeholder graph type so the caching module can import cleanly.""" + + pass + + graph_module.DynamicPrompt = DynamicPrompt + + monkeypatch.setitem(sys.modules, "torch", torch_module) + monkeypatch.setitem(sys.modules, "psutil", psutil_module) + monkeypatch.setitem(sys.modules, "nodes", nodes_module) + monkeypatch.setitem(sys.modules, "comfy_execution.graph", graph_module) + monkeypatch.delitem(sys.modules, "comfy_execution.caching", raising=False) + + module = importlib.import_module("comfy_execution.caching") + module = importlib.reload(module) + return module, nodes_module + + +def test_sanitize_signature_input_handles_shared_builtin_substructures(caching_module): + """Shared built-in substructures should sanitize without collapsing to Unhashable.""" + caching, _ = caching_module + shared = [{"value": 1}, {"value": 2}] + + sanitized = caching._sanitize_signature_input([shared, shared]) + + assert isinstance(sanitized, list) + assert sanitized[0] == sanitized[1] + assert sanitized[0][0]["value"] == 1 + assert sanitized[0][1]["value"] == 2 + + +def test_to_hashable_handles_shared_builtin_substructures(caching_module): + """Repeated sanitized content should hash stably for shared substructures.""" + caching, _ = caching_module + shared = [{"value": 1}, {"value": 2}] + + sanitized = caching._sanitize_signature_input([shared, shared]) + hashable = caching.to_hashable(sanitized) + + assert hashable[0] == "list" + assert hashable[1][0] == hashable[1][1] + assert hashable[1][0][0] == "list" + + +def test_sanitize_signature_input_fails_closed_for_ambiguous_dict_ordering(caching_module): + """Ambiguous dict sort ties should fail closed instead of depending on input order.""" + caching, _ = caching_module + ambiguous = { + _OpaqueValue(): _OpaqueValue(), + _OpaqueValue(): _OpaqueValue(), + } + + sanitized = caching._sanitize_signature_input(ambiguous) + + assert isinstance(sanitized, caching.Unhashable) + + +@pytest.mark.parametrize( + "container_factory", + [ + set, + frozenset, + ], +) +def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module, container_factory): + """Ambiguous unordered values should fail closed instead of depending on iteration order.""" + caching, _ = caching_module + container = container_factory({_OpaqueValue(), _OpaqueValue()}) + + hashable = caching.to_hashable(container) + + assert isinstance(hashable, caching.Unhashable) + + +def test_get_node_signature_sanitizes_full_signature(caching_module, monkeypatch): + """Recursive `is_changed` payloads should be sanitized inside the full node signature.""" + caching, nodes_module = caching_module + monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode) + + is_changed_value = [] + is_changed_value.append(is_changed_value) + + dynprompt = _FakeDynPrompt( + { + "node": { + "class_type": "UnitTestNode", + "inputs": {"value": 5}, + } + } + ) + key_set = caching.CacheKeySetInputSignature( + dynprompt, + ["node"], + _FakeIsChangedCache({"node": is_changed_value}), + ) + + signature = asyncio.run(key_set.get_node_signature(dynprompt, "node")) + + assert signature[0] == "list" + assert _contains_unhashable(signature, caching.Unhashable)