diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 78212bde3..d3a7ec52a 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,7 +1,6 @@ import asyncio import bisect import gc -import itertools import psutil import time import torch @@ -17,6 +16,7 @@ NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {} def include_unique_id_in_input(class_type: str) -> bool: + """Return whether a node class includes UNIQUE_ID among its hidden inputs.""" if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID: return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -24,52 +24,403 @@ def include_unique_id_in_input(class_type: str) -> bool: return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] class CacheKeySet(ABC): + """Base helper for building and storing cache keys for prompt nodes.""" def __init__(self, dynprompt, node_ids, is_changed_cache): + """Initialize cache-key storage for a dynamic prompt execution pass.""" self.keys = {} self.subcache_keys = {} @abstractmethod async def add_keys(self, node_ids): + """Populate cache keys for the provided node ids.""" raise NotImplementedError() def all_node_ids(self): + """Return the set of node ids currently tracked by this key set.""" return set(self.keys.keys()) def get_used_keys(self): + """Return the computed cache keys currently in use.""" return self.keys.values() def get_used_subcache_keys(self): + """Return the computed subcache keys currently in use.""" return self.subcache_keys.values() def get_data_key(self, node_id): + """Return the cache key for a node, if present.""" return self.keys.get(node_id, None) def get_subcache_key(self, node_id): + """Return the subcache key for a node, if present.""" return self.subcache_keys.get(node_id, None) class Unhashable: - def __init__(self): - self.value = float("NaN") + """Hashable identity sentinel for values that cannot be represented safely in cache keys.""" + pass -def to_hashable(obj): - # So that we 
don't infinitely recurse since frozenset and tuples - # are Sequences. - if isinstance(obj, (int, float, str, bool, bytes, type(None))): + +_PRIMITIVE_SIGNATURE_TYPES = (int, float, str, bool, bytes, type(None)) +_CONTAINER_SIGNATURE_TYPES = (dict, list, tuple, set, frozenset) +_MAX_SIGNATURE_DEPTH = 32 +_MAX_SIGNATURE_CONTAINER_VISITS = 10_000 +_FAILED_SIGNATURE = object() + + +def _shallow_is_changed_signature(value): + """Sanitize execution-time `is_changed` values without deep recursion.""" + value_type = type(value) + if value_type in _PRIMITIVE_SIGNATURE_TYPES: + return value + if value_type is list or value_type is tuple: + try: + items = tuple(value) + except RuntimeError: + return Unhashable() + if all(type(item) in _PRIMITIVE_SIGNATURE_TYPES for item in items): + container_tag = "is_changed_list" if value_type is list else "is_changed_tuple" + return (container_tag, items) + return Unhashable() + + +def _primitive_signature_sort_key(obj): + """Return a deterministic ordering key for primitive signature values.""" + obj_type = type(obj) + return ("primitive", obj_type.__module__, obj_type.__qualname__, repr(obj)) + + +def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None): + """Return a deterministic ordering key for sanitized built-in container content.""" + if depth >= max_depth: + return ("MAX_DEPTH",) + + if active is None: + active = set() + if memo is None: + memo = {} + + obj_type = type(obj) + if obj_type is Unhashable: + return ("UNHASHABLE",) + elif obj_type in _PRIMITIVE_SIGNATURE_TYPES: + return (obj_type.__module__, obj_type.__qualname__, repr(obj)) + elif obj_type not in _CONTAINER_SIGNATURE_TYPES: + return (obj_type.__module__, obj_type.__qualname__, "OPAQUE") + + obj_id = id(obj) + if obj_id in memo: + return memo[obj_id] + if obj_id in active: + return ("CYCLE",) + + active.add(obj_id) + try: + if obj_type is dict: + items = [ + ( + _sanitized_sort_key(k, depth + 1, max_depth, active, memo), + 
_sanitized_sort_key(v, depth + 1, max_depth, active, memo), + ) + for k, v in obj.items() + ] + items.sort() + result = ("dict", tuple(items)) + elif obj_type is list: + result = ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)) + elif obj_type is tuple: + result = ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)) + elif obj_type is set: + result = ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))) + else: + result = ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))) + finally: + active.discard(obj_id) + + memo[obj_id] = result + return result + + +def _signature_to_hashable_impl(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None): + """Canonicalize signature inputs directly into their final hashable form.""" + if depth >= max_depth: + return _FAILED_SIGNATURE + + if active is None: + active = set() + if memo is None: + memo = {} + if budget is None: + budget = {"remaining": _MAX_SIGNATURE_CONTAINER_VISITS} + + obj_type = type(obj) + if obj_type in _PRIMITIVE_SIGNATURE_TYPES: + return obj, _primitive_signature_sort_key(obj) + if obj_type is Unhashable or obj_type not in _CONTAINER_SIGNATURE_TYPES: + return _FAILED_SIGNATURE + + obj_id = id(obj) + if obj_id in memo: + return memo[obj_id] + if obj_id in active: + return _FAILED_SIGNATURE + + budget["remaining"] -= 1 + if budget["remaining"] < 0: + return _FAILED_SIGNATURE + + active.add(obj_id) + try: + if obj_type is dict: + try: + items = list(obj.items()) + except RuntimeError: + return _FAILED_SIGNATURE + + ordered_items = [] + for key, value in items: + key_result = _signature_to_hashable_impl(key, depth + 1, max_depth, active, memo, budget) + if key_result is _FAILED_SIGNATURE: + return _FAILED_SIGNATURE + value_result = _signature_to_hashable_impl(value, depth + 1, max_depth, active, memo, budget) + if 
def _signature_to_hashable_impl(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None):
    """Canonicalize a signature input directly into its final hashable form.

    Returns either ``_FAILED_SIGNATURE`` or a ``(value, sort_key)`` pair,
    where ``value`` is the hashable canonical form and ``sort_key`` is the
    ordering key used when ``obj`` is a member of an unordered container.

    Any failure (depth limit, opaque or ``Unhashable`` value, cycle, visit
    budget exhausted, ambiguous ordering, container mutated while being
    snapshotted) fails the whole conversion; nothing partial is returned.
    """
    if depth >= max_depth:
        return _FAILED_SIGNATURE

    active = set() if active is None else active
    memo = {} if memo is None else memo
    if budget is None:
        budget = {"remaining": _MAX_SIGNATURE_CONTAINER_VISITS}

    kind = type(obj)
    if kind in _PRIMITIVE_SIGNATURE_TYPES:
        return obj, _primitive_signature_sort_key(obj)
    if kind is Unhashable or kind not in _CONTAINER_SIGNATURE_TYPES:
        return _FAILED_SIGNATURE

    ident = id(obj)
    if ident in memo:
        return memo[ident]
    if ident in active:
        # The container is its own ancestor.
        return _FAILED_SIGNATURE

    # Only containers that are actually expanded consume budget.
    budget["remaining"] -= 1
    if budget["remaining"] < 0:
        return _FAILED_SIGNATURE

    def descend(child):
        """Canonicalize a direct child via the module-level entry point."""
        return _signature_to_hashable_impl(child, depth + 1, max_depth, active, memo, budget)

    active.add(ident)
    try:
        # Take a point-in-time copy before recursing so later mutation of the
        # live container cannot change what gets canonicalized.
        try:
            snapshot = list(obj.items()) if kind is dict else list(obj)
        except RuntimeError:
            return _FAILED_SIGNATURE

        if kind is dict:
            entries = []
            for raw_key, raw_value in snapshot:
                key_outcome = descend(raw_key)
                if key_outcome is _FAILED_SIGNATURE:
                    return _FAILED_SIGNATURE
                value_outcome = descend(raw_value)
                if value_outcome is _FAILED_SIGNATURE:
                    return _FAILED_SIGNATURE
                entries.append((key_outcome, value_outcome))

            entries.sort(key=lambda entry: (entry[0][1], entry[1][1]))
            # Two keys with the same sort key cannot be ordered reliably,
            # regardless of what their values are.
            for earlier, later in zip(entries, entries[1:]):
                if earlier[0][1] == later[0][1]:
                    return _FAILED_SIGNATURE

            value = ("dict", tuple((k[0], v[0]) for k, v in entries))
            sort_key = ("dict", tuple((k[1], v[1]) for k, v in entries))
        else:
            outcomes = []
            for child in snapshot:
                outcome = descend(child)
                if outcome is _FAILED_SIGNATURE:
                    return _FAILED_SIGNATURE
                outcomes.append(outcome)

            if kind is set or kind is frozenset:
                outcomes.sort(key=lambda outcome: outcome[1])
                # Equal sort keys with different values mean the order would
                # depend on iteration order.
                for earlier, later in zip(outcomes, outcomes[1:]):
                    if earlier[1] == later[1] and earlier[0] != later[0]:
                        return _FAILED_SIGNATURE

            # Exact built-in type, so __name__ is "list"/"tuple"/"set"/"frozenset".
            tag = kind.__name__
            value = (tag, tuple(outcome[0] for outcome in outcomes))
            sort_key = (tag, tuple(outcome[1] for outcome in outcomes))
    finally:
        active.discard(ident)

    memo[ident] = (value, sort_key)
    return memo[ident]


def _signature_to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
    """Build the final cache-signature representation in one fail-closed pass.

    Returns the canonical hashable value, or a fresh ``Unhashable()`` when
    canonicalization fails or raises ``RuntimeError``.
    """
    try:
        outcome = _signature_to_hashable_impl(obj, budget={"remaining": max_nodes})
    except RuntimeError:
        return Unhashable()
    return Unhashable() if outcome is _FAILED_SIGNATURE else outcome[0]
def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
    """Convert sanitized prompt inputs into a stable hashable representation.

    Primitives and ``Unhashable`` markers are returned unchanged. Exact
    built-in containers become tagged tuples such as ``("list", (...))`` or
    ``("dict", ((key, value), ...))``. Anything else, and any container that
    holds something that cannot be canonicalized, becomes ``Unhashable()``.

    The walk uses an explicit stack instead of recursion and records finished
    containers by ``id`` so shared substructures are processed once. At most
    ``max_nodes`` containers are expanded before the conversion gives up.
    """
    root_type = type(obj)
    if root_type in _PRIMITIVE_SIGNATURE_TYPES or root_type is Unhashable:
        return obj
    if root_type not in _CONTAINER_SIGNATURE_TYPES:
        return Unhashable()

    finished = {}        # id -> canonical form, or an Unhashable marker
    in_progress = set()  # ids of containers whose children are still pending
    pending_items = {}   # id -> snapshot of items taken when first visited
    sort_memo = {}       # memo shared by every _sanitized_sort_key call
    visited = 0
    work = [(obj, False)]

    def lookup(value):
        """Return the canonical form of a child that was already processed."""
        kind = type(value)
        if kind in _PRIMITIVE_SIGNATURE_TYPES or kind is Unhashable:
            return value
        return finished.get(id(value), Unhashable())

    def failed(value):
        """Return whether a canonical form is the failure marker."""
        return type(value) is Unhashable

    def canonical_unordered(members, tag):
        """Canonicalize set-like members, failing closed on ambiguous order."""
        try:
            keyed = [
                (_sanitized_sort_key(member, memo=sort_memo), lookup(member))
                for member in members
            ]
            if any(failed(value) for _, value in keyed):
                return Unhashable()
            keyed.sort(key=lambda pair: pair[0])
        except RuntimeError:
            return Unhashable()

        for (prev_key, prev_value), (next_key, next_value) in zip(keyed, keyed[1:]):
            if prev_key == next_key and prev_value != next_value:
                return Unhashable()

        return (tag, tuple(value for _, value in keyed))

    while work:
        node, ready = work.pop()
        kind = type(node)

        if kind in _PRIMITIVE_SIGNATURE_TYPES or kind is Unhashable:
            continue
        if kind not in _CONTAINER_SIGNATURE_TYPES:
            finished[id(node)] = Unhashable()
            continue

        node_id = id(node)
        if node_id in finished:
            continue

        if not ready:
            # First visit: schedule the children, then revisit this container.
            if node_id in in_progress:
                # Reached again before its own children finished: a cycle.
                finished[node_id] = Unhashable()
                continue

            visited += 1
            if visited > max_nodes:
                return Unhashable()

            in_progress.add(node_id)
            work.append((node, True))
            try:
                members = list(node.items()) if kind is dict else list(node)
                pending_items[node_id] = members
            except RuntimeError:
                finished[node_id] = Unhashable()
                in_progress.discard(node_id)
                continue

            # Push in reverse so children are popped in their original order.
            if kind is dict:
                for key, value in reversed(members):
                    work.append((value, False))
                    work.append((key, False))
            else:
                for member in reversed(members):
                    work.append((member, False))
            continue

        # Second visit: every child has been resolved, assemble the result.
        in_progress.discard(node_id)
        try:
            members = pending_items.pop(node_id, None)
            if members is None:
                members = list(node.items()) if kind is dict else list(node)

            if kind is dict:
                keyed = [
                    (_sanitized_sort_key(key, memo=sort_memo), lookup(key), lookup(value))
                    for key, value in members
                ]
                if any(failed(key) or failed(value) for _, key, value in keyed):
                    finished[node_id] = Unhashable()
                    continue
                keyed.sort(key=lambda entry: entry[0])
                if any(earlier[0] == later[0] for earlier, later in zip(keyed, keyed[1:])):
                    # Keys that tie on their sort key cannot be ordered reliably.
                    finished[node_id] = Unhashable()
                else:
                    finished[node_id] = ("dict", tuple((key, value) for _, key, value in keyed))
            elif kind is list or kind is tuple:
                resolved = tuple(lookup(member) for member in members)
                if any(failed(member) for member in resolved):
                    finished[node_id] = Unhashable()
                else:
                    # Exact built-in type, so __name__ is "list" or "tuple".
                    finished[node_id] = (kind.__name__, resolved)
            else:
                # set / frozenset
                finished[node_id] = canonical_unordered(members, kind.__name__)
        except RuntimeError:
            finished[node_id] = Unhashable()

    return finished.get(id(obj), Unhashable())
order_mapping)) for ancestor_id in ancestors: signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) - return to_hashable(signature) + return _signature_to_hashable(signature) async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + """Build the immediate cache-signature fragment for a node. + + Link inputs are reduced to ancestor references here, while non-link + values are appended as-is. Full canonicalization happens later in + `get_node_signature()` via `_signature_to_hashable()`. + """ if not dynprompt.has_node(node_id): # This node doesn't exist -- we can't cache it. return [float("NaN")] node = dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - signature = [class_type, await self.is_changed_cache.get(node_id)] + signature = [class_type, _shallow_is_changed_signature(await self.is_changed_cache.get(node_id))] if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): signature.append(node_id) inputs = node["inputs"] @@ -129,12 +491,14 @@ class CacheKeySetInputSignature(CacheKeySet): # This function returns a list of all ancestors of the given node. The order of the list is # deterministic based on which specific inputs the ancestor is connected by. 
def is_link(obj):
    """Return whether obj is a plain prompt link of the form [node_id, output_index]."""
    # Links are plain two-element lists: an exact `str` node id followed by an
    # exact `int` output index. Exact type checks (no isinstance) keep list
    # subclasses, list-like proxy objects, bools and floats from being
    # mistaken for links while cache signatures are built.
    if type(obj) is not list or len(obj) != 2:
        return False
    source_id, output_index = obj
    return type(source_id) is str and type(output_index) is int
types.ModuleType("psutil") + nodes_module = types.ModuleType("nodes") + nodes_module.NODE_CLASS_MAPPINGS = {} + graph_module = types.ModuleType("comfy_execution.graph") + + class DynamicPrompt: + """Placeholder graph type so the caching module can import cleanly.""" + + pass + + graph_module.DynamicPrompt = DynamicPrompt + + monkeypatch.setitem(sys.modules, "torch", torch_module) + monkeypatch.setitem(sys.modules, "psutil", psutil_module) + monkeypatch.setitem(sys.modules, "nodes", nodes_module) + monkeypatch.setitem(sys.modules, "comfy_execution.graph", graph_module) + monkeypatch.delitem(sys.modules, "comfy_execution.caching", raising=False) + + module = importlib.import_module("comfy_execution.caching") + module = importlib.reload(module) + return module, nodes_module + + +def test_signature_to_hashable_handles_shared_builtin_substructures(caching_module): + """Shared built-in substructures should canonicalize without collapsing to Unhashable.""" + caching, _ = caching_module + shared = [{"value": 1}, {"value": 2}] + + signature = caching._signature_to_hashable([shared, shared]) + + assert signature[0] == "list" + assert signature[1][0] == signature[1][1] + assert signature[1][0][0] == "list" + assert signature[1][0][1][0] == ("dict", (("value", 1),)) + assert signature[1][0][1][1] == ("dict", (("value", 2),)) + + +def test_signature_to_hashable_fails_closed_on_opaque_values(caching_module): + """Opaque values should collapse the full signature to Unhashable immediately.""" + caching, _ = caching_module + + signature = caching._signature_to_hashable(["safe", object()]) + + assert isinstance(signature, caching.Unhashable) + + +def test_signature_to_hashable_stops_descending_after_failure(caching_module, monkeypatch): + """Once canonicalization fails, later recursive descent should stop immediately.""" + caching, _ = caching_module + original = caching._signature_to_hashable_impl + marker = object() + marker_seen = False + + def tracking_canonicalize(obj, *args, 
**kwargs): + """Track whether recursion reaches the nested marker after failure.""" + nonlocal marker_seen + if obj is marker: + marker_seen = True + return original(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_signature_to_hashable_impl", tracking_canonicalize) + + signature = caching._signature_to_hashable([object(), [marker]]) + + assert isinstance(signature, caching.Unhashable) + assert marker_seen is False + + +def test_signature_to_hashable_snapshots_list_before_recursing(caching_module, monkeypatch): + """List canonicalization should read a point-in-time snapshot before recursive descent.""" + caching, _ = caching_module + original = caching._signature_to_hashable_impl + marker = ("marker",) + values = [marker, 2] + + def mutating_canonicalize(obj, *args, **kwargs): + """Mutate the live list during recursion to verify snapshot-based traversal.""" + if obj is marker: + values[1] = 3 + return original(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize) + + signature = caching._signature_to_hashable(values) + + assert signature == ("list", (("tuple", ("marker",)), 2)) + assert values[1] == 3 + + +def test_signature_to_hashable_snapshots_dict_before_recursing(caching_module, monkeypatch): + """Dict canonicalization should read a point-in-time snapshot before recursive descent.""" + caching, _ = caching_module + original = caching._signature_to_hashable_impl + marker = ("marker",) + values = {"first": marker, "second": 2} + + def mutating_canonicalize(obj, *args, **kwargs): + """Mutate the live dict during recursion to verify snapshot-based traversal.""" + if obj is marker: + values["second"] = 3 + return original(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize) + + signature = caching._signature_to_hashable(values) + + assert signature == ("dict", (("first", ("tuple", ("marker",))), ("second", 2))) + assert values["second"] == 3 + + 
+@pytest.mark.parametrize( + "container_factory", + [ + lambda marker: [marker], + lambda marker: (marker,), + lambda marker: {marker}, + lambda marker: frozenset({marker}), + lambda marker: {marker: "value"}, + ], +) +def test_signature_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory): + """Traversal RuntimeError should degrade canonicalization to Unhashable.""" + caching, _ = caching_module + original = caching._signature_to_hashable_impl + marker = object() + + def raising_canonicalize(obj, *args, **kwargs): + """Raise a traversal RuntimeError for the marker value and delegate otherwise.""" + if obj is marker: + raise RuntimeError("container changed during iteration") + return original(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_signature_to_hashable_impl", raising_canonicalize) + + signature = caching._signature_to_hashable(container_factory(marker)) + + assert isinstance(signature, caching.Unhashable) + + +def test_to_hashable_handles_shared_builtin_substructures(caching_module): + """The legacy helper should still hash sanitized built-ins stably when used directly.""" + caching, _ = caching_module + shared = [{"value": 1}, {"value": 2}] + + sanitized = [shared, shared] + hashable = caching.to_hashable(sanitized) + + assert hashable[0] == "list" + assert hashable[1][0] == hashable[1][1] + assert hashable[1][0][0] == "list" + + +def test_to_hashable_fails_closed_for_ordered_container_with_opaque_child(caching_module): + """Ordered containers should fail closed when a child cannot be canonicalized.""" + caching, _ = caching_module + + result = caching.to_hashable([object()]) + + assert isinstance(result, caching.Unhashable) + + +def test_to_hashable_canonicalizes_dict_insertion_order(caching_module): + """Dicts with the same content should hash identically regardless of insertion order.""" + caching, _ = caching_module + + first = {"b": 2, "a": 1} + second = {"a": 1, "b": 2} + + assert 
caching.to_hashable(first) == ("dict", (("a", 1), ("b", 2))) + assert caching.to_hashable(first) == caching.to_hashable(second) + + +@pytest.mark.parametrize( + "container_factory", + [ + set, + frozenset, + ], +) +def test_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory): + """Traversal RuntimeError should degrade unordered hash conversion to Unhashable.""" + caching, _ = caching_module + + def raising_sort_key(obj, *args, **kwargs): + """Raise a traversal RuntimeError while unordered values are canonicalized.""" + raise RuntimeError("container changed during iteration") + + monkeypatch.setattr(caching, "_sanitized_sort_key", raising_sort_key) + + hashable = caching.to_hashable(container_factory({"value"})) + + assert isinstance(hashable, caching.Unhashable) + + +def test_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module): + """Ambiguous dict key ordering should fail closed instead of using insertion order.""" + caching, _ = caching_module + ambiguous = { + _OpaqueValue(): 1, + _OpaqueValue(): 2, + } + + hashable = caching.to_hashable(ambiguous) + + assert isinstance(hashable, caching.Unhashable) + + +def test_signature_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module): + """Ambiguous dict sort ties should fail closed instead of depending on input order.""" + caching, _ = caching_module + ambiguous = { + _OpaqueValue(): _OpaqueValue(), + _OpaqueValue(): _OpaqueValue(), + } + + sanitized = caching._signature_to_hashable(ambiguous) + + assert isinstance(sanitized, caching.Unhashable) + + +def test_signature_to_hashable_fails_closed_on_dict_key_sort_collisions_even_with_distinct_values(caching_module, monkeypatch): + """Different values must not mask dict key-sort collisions during canonicalization.""" + caching, _ = caching_module + original = caching._signature_to_hashable_impl + key_a = object() + key_b = object() + + def colliding_key_canonicalize(obj, *args, **kwargs): + """Force two 
distinct raw keys to share the same canonical sort key.""" + if obj is key_a: + return ("key-a", ("COLLIDE",)) + if obj is key_b: + return ("key-b", ("COLLIDE",)) + return original(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_signature_to_hashable_impl", colliding_key_canonicalize) + + sanitized = caching._signature_to_hashable({key_a: 1, key_b: 2}) + + assert isinstance(sanitized, caching.Unhashable) + + +@pytest.mark.parametrize( + "container_factory", + [ + set, + frozenset, + ], +) +def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module, container_factory): + """Ambiguous unordered values should fail closed instead of depending on iteration order.""" + caching, _ = caching_module + container = container_factory({_OpaqueValue(), _OpaqueValue()}) + + hashable = caching.to_hashable(container) + + assert isinstance(hashable, caching.Unhashable) + + +def test_get_node_signature_returns_top_level_unhashable_for_tainted_signature(caching_module, monkeypatch): + """Tainted full signatures should fail closed before `to_hashable()` runs.""" + caching, nodes_module = caching_module + monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode) + monkeypatch.setattr( + caching, + "to_hashable", + lambda *_args, **_kwargs: pytest.fail("to_hashable should not run for tainted signatures"), + ) + + is_changed_value = [] + is_changed_value.append(is_changed_value) + + dynprompt = _FakeDynPrompt( + { + "node": { + "class_type": "UnitTestNode", + "inputs": {"value": 5}, + } + } + ) + key_set = caching.CacheKeySetInputSignature( + dynprompt, + ["node"], + _FakeIsChangedCache({"node": is_changed_value}), + ) + + signature = asyncio.run(key_set.get_node_signature(dynprompt, "node")) + + assert isinstance(signature, caching.Unhashable) + + +def test_shallow_is_changed_signature_accepts_primitive_lists(caching_module): + """Primitive-only `is_changed` lists should stay hashable without deep descent.""" + caching, _ = 
caching_module + + sanitized = caching._shallow_is_changed_signature([1, "two", None, True]) + + assert sanitized == ("is_changed_list", (1, "two", None, True)) + + +def test_shallow_is_changed_signature_fails_closed_on_nested_containers(caching_module): + """Nested containers from `is_changed` should be rejected immediately.""" + caching, _ = caching_module + + sanitized = caching._shallow_is_changed_signature([1, ["nested"]]) + + assert isinstance(sanitized, caching.Unhashable) + + +def test_get_immediate_node_signature_marks_recursive_is_changed_unhashable(caching_module, monkeypatch): + """Recursive `is_changed` payloads should be cut off before signature canonicalization.""" + caching, nodes_module = caching_module + monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode) + + is_changed_value = [] + is_changed_value.append(is_changed_value) + dynprompt = _FakeDynPrompt( + { + "node": { + "class_type": "UnitTestNode", + "inputs": {"value": 5}, + } + } + ) + key_set = caching.CacheKeySetInputSignature( + dynprompt, + ["node"], + _FakeIsChangedCache({"node": is_changed_value}), + ) + + signature = asyncio.run(key_set.get_immediate_node_signature(dynprompt, "node", {})) + + assert isinstance(signature[1], caching.Unhashable)