diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 78212bde3..fecf54d1e 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,7 +1,6 @@ import asyncio import bisect import gc -import itertools import psutil import time import torch @@ -17,6 +16,7 @@ NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {} def include_unique_id_in_input(class_type: str) -> bool: + """Return whether a node class includes UNIQUE_ID among its hidden inputs.""" if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID: return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -24,52 +24,410 @@ def include_unique_id_in_input(class_type: str) -> bool: return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] class CacheKeySet(ABC): + """Base helper for building and storing cache keys for prompt nodes.""" def __init__(self, dynprompt, node_ids, is_changed_cache): + """Initialize cache-key storage for a dynamic prompt execution pass.""" self.keys = {} self.subcache_keys = {} @abstractmethod async def add_keys(self, node_ids): + """Populate cache keys for the provided node ids.""" raise NotImplementedError() def all_node_ids(self): + """Return the set of node ids currently tracked by this key set.""" return set(self.keys.keys()) def get_used_keys(self): + """Return the computed cache keys currently in use.""" return self.keys.values() def get_used_subcache_keys(self): + """Return the computed subcache keys currently in use.""" return self.subcache_keys.values() def get_data_key(self, node_id): + """Return the cache key for a node, if present.""" return self.keys.get(node_id, None) def get_subcache_key(self, node_id): + """Return the subcache key for a node, if present.""" return self.subcache_keys.get(node_id, None) class Unhashable: - def __init__(self): - self.value = float("NaN") + """Hashable identity sentinel for values that cannot be represented safely in cache keys.""" + pass -def to_hashable(obj): - # So that we don't infinitely recurse since frozenset and tuples - # are Sequences. - if isinstance(obj, (int, float, str, bool, bytes, type(None))): + +_PRIMITIVE_SIGNATURE_TYPES = (int, float, str, bool, bytes, type(None)) +_CONTAINER_SIGNATURE_TYPES = (dict, list, tuple, set, frozenset) +_MAX_SIGNATURE_DEPTH = 32 +_MAX_SIGNATURE_CONTAINER_VISITS = 10_000 +_FAILED_SIGNATURE = object() + + +def _shallow_is_changed_signature(value): + """Sanitize execution-time `is_changed` values with a small fail-closed budget.""" + value_type = type(value) + if value_type in _PRIMITIVE_SIGNATURE_TYPES: + return value + + canonical = to_hashable(value, max_nodes=64) + if type(canonical) is Unhashable: + return canonical + + if value_type is list or value_type is tuple: + container_tag = "is_changed_list" if value_type is list else "is_changed_tuple" + return (container_tag, canonical[1]) + + return canonical + + +def _primitive_signature_sort_key(obj): + """Return a deterministic ordering key for primitive signature values.""" + obj_type = type(obj) + return ("primitive", obj_type.__module__, obj_type.__qualname__, repr(obj)) + + +def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None): + """Return a deterministic ordering key for sanitized built-in container content.""" + if depth >= max_depth: + return ("MAX_DEPTH",) + + if active is None: + active = set() + if memo is None: + memo = {} + + obj_type = type(obj) + if obj_type is Unhashable: + return ("UNHASHABLE",) + elif obj_type in _PRIMITIVE_SIGNATURE_TYPES: + return (obj_type.__module__, obj_type.__qualname__, repr(obj)) + elif obj_type not in _CONTAINER_SIGNATURE_TYPES: + return (obj_type.__module__, obj_type.__qualname__, "OPAQUE") + + obj_id = id(obj) + if obj_id in memo: + return memo[obj_id] + if obj_id in active: + return ("CYCLE",) + + active.add(obj_id) + try: + if obj_type is dict: + items = [ + ( + _sanitized_sort_key(k, depth + 1, max_depth, active, memo), + _sanitized_sort_key(v, depth + 1, max_depth, active, memo), + ) + for k, v in obj.items() + ] + items.sort() + result = ("dict", tuple(items)) + elif obj_type is list: + result = ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)) + elif obj_type is tuple: + result = ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)) + elif obj_type is set: + result = ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))) + else: + result = ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))) + finally: + active.discard(obj_id) + + memo[obj_id] = result + return result + + +def _signature_to_hashable_impl(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None): + """Canonicalize signature inputs directly into their final hashable form.""" + if depth >= max_depth: + return _FAILED_SIGNATURE + + if active is None: + active = set() + if memo is None: + memo = {} + if budget is None: + budget = {"remaining": _MAX_SIGNATURE_CONTAINER_VISITS} + + obj_type = type(obj) + if obj_type in _PRIMITIVE_SIGNATURE_TYPES: + return obj, _primitive_signature_sort_key(obj) + if obj_type is Unhashable or obj_type not in _CONTAINER_SIGNATURE_TYPES: + return _FAILED_SIGNATURE + + obj_id = id(obj) + if obj_id in memo: + return memo[obj_id] + if obj_id in active: + return _FAILED_SIGNATURE + + budget["remaining"] -= 1 + if budget["remaining"] < 0: + return _FAILED_SIGNATURE + + active.add(obj_id) + try: + if obj_type is dict: + try: + items = list(obj.items()) + except RuntimeError: + return _FAILED_SIGNATURE + + ordered_items = [] + for key, value in items: + if type(key) not in _PRIMITIVE_SIGNATURE_TYPES: + return _FAILED_SIGNATURE + key_result = (key, _primitive_signature_sort_key(key)) + value_result = _signature_to_hashable_impl(value, depth + 1, max_depth, active, memo, budget) + if value_result is _FAILED_SIGNATURE: + return _FAILED_SIGNATURE + key_value, key_sort = key_result + value_value, value_sort = value_result + ordered_items.append((key_sort, value_sort, key_value, value_value)) + + ordered_items.sort(key=lambda item: (item[0], item[1])) + for index in range(1, len(ordered_items)): + previous_key_sort = ordered_items[index - 1][0] + current_key_sort = ordered_items[index][0] + if previous_key_sort == current_key_sort: + return _FAILED_SIGNATURE + + value = ("dict", tuple((key_value, value_value) for _, _, key_value, value_value in ordered_items)) + sort_key = ("dict", tuple((key_sort, value_sort) for key_sort, value_sort, _, _ in ordered_items)) + elif obj_type is list or obj_type is tuple: + try: + items = list(obj) + except RuntimeError: + return _FAILED_SIGNATURE + + child_results = [] + for item in items: + child_result = _signature_to_hashable_impl(item, depth + 1, max_depth, active, memo, budget) + if child_result is _FAILED_SIGNATURE: + return _FAILED_SIGNATURE + child_results.append(child_result) + + container_tag = "list" if obj_type is list else "tuple" + value = (container_tag, tuple(child for child, _ in child_results)) + sort_key = (container_tag, tuple(child_sort for _, child_sort in child_results)) + else: + try: + items = list(obj) + except RuntimeError: + return _FAILED_SIGNATURE + + ordered_items = [] + for item in items: + child_result = _signature_to_hashable_impl(item, depth + 1, max_depth, active, memo, budget) + if child_result is _FAILED_SIGNATURE: + return _FAILED_SIGNATURE + child_value, child_sort = child_result + ordered_items.append((child_sort, child_value)) + + ordered_items.sort(key=lambda item: item[0]) + for index in range(1, len(ordered_items)): + previous_sort_key, previous_value = ordered_items[index - 1] + current_sort_key, current_value = ordered_items[index] + if previous_sort_key == current_sort_key and previous_value != current_value: + return _FAILED_SIGNATURE + + container_tag = "set" if obj_type is set else "frozenset" + value = (container_tag, tuple(child_value for _, child_value in ordered_items)) + sort_key = (container_tag, tuple(child_sort for child_sort, _ in ordered_items)) + finally: + active.discard(obj_id) + + memo[obj_id] = (value, sort_key) + return memo[obj_id] + + +def _signature_to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS): + """Build the final cache-signature representation in one fail-closed pass.""" + try: + result = _signature_to_hashable_impl(obj, budget={"remaining": max_nodes}) + except RuntimeError: + return Unhashable() + if result is _FAILED_SIGNATURE: + return Unhashable() + return result[0] + + +def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS): + """Convert sanitized prompt inputs into a stable hashable representation. + + The input is expected to already be sanitized to plain built-in containers, + but this function still fails safe for anything unexpected. Traversal is + iterative and memoized so shared built-in substructures do not trigger + exponential re-walks during cache-key construction. + """ + obj_type = type(obj) + if obj_type in _PRIMITIVE_SIGNATURE_TYPES or obj_type is Unhashable: return obj - elif isinstance(obj, Mapping): - return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())]) - elif isinstance(obj, Sequence): - return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj])) - else: - # TODO - Support other objects like tensors? + if obj_type not in _CONTAINER_SIGNATURE_TYPES: return Unhashable() + memo = {} + active = set() + snapshots = {} + sort_memo = {} + processed = 0 + # Keep traversal state separate from container snapshots/results. + work_stack = [(obj, False)] + + def resolve_value(value): + """Resolve a child value from the completed memo table when available.""" + value_type = type(value) + if value_type in _PRIMITIVE_SIGNATURE_TYPES or value_type is Unhashable: + return value + return memo.get(id(value), Unhashable()) + + def is_failed(value): + """Return whether a resolved child value represents failed canonicalization.""" + return type(value) is Unhashable + + def resolve_unordered_values(current_items, container_tag): + """Resolve a set-like container or fail closed if ordering is ambiguous.""" + try: + ordered_items = [ + (_sanitized_sort_key(item, memo=sort_memo), resolve_value(item)) + for item in current_items + ] + if any(is_failed(value) for _, value in ordered_items): + return Unhashable() + ordered_items.sort(key=lambda item: item[0]) + except RuntimeError: + return Unhashable() + + for index in range(1, len(ordered_items)): + previous_key, previous_value = ordered_items[index - 1] + current_key, current_value = ordered_items[index] + if previous_key == current_key and previous_value != current_value: + return Unhashable() + + return (container_tag, tuple(value for _, value in ordered_items)) + + while work_stack: + entry = work_stack.pop() + if len(entry) == 3: + _, current_id, current_type = entry + current = None + expanded = True + else: + current, expanded = entry + current_type = type(current) + current_id = id(current) + + if not expanded and (current_type in _PRIMITIVE_SIGNATURE_TYPES or current_type is Unhashable): + continue + if not expanded and current_type not in _CONTAINER_SIGNATURE_TYPES: + memo[current_id] = Unhashable() + continue + + if current_id in memo: + continue + + if expanded: + active.discard(current_id) + try: + items = snapshots.pop(current_id, None) + if items is None: + memo[current_id] = Unhashable() + continue + + if current_type is dict: + ordered_items = [ + (_sanitized_sort_key(k, memo=sort_memo), k, resolve_value(v)) + for k, v in items + ] + if any(type(key) not in _PRIMITIVE_SIGNATURE_TYPES or is_failed(value) for _, key, value in ordered_items): + memo[current_id] = Unhashable() + continue + ordered_items.sort(key=lambda item: item[0]) + for index in range(1, len(ordered_items)): + if ordered_items[index - 1][0] == ordered_items[index][0]: + memo[current_id] = Unhashable() + break + else: + memo[current_id] = ( + "dict", + tuple((key, value) for _, key, value in ordered_items), + ) + elif current_type is list: + resolved_items = tuple(resolve_value(item) for item in items) + if any(is_failed(item) for item in resolved_items): + memo[current_id] = Unhashable() + else: + memo[current_id] = ("list", resolved_items) + elif current_type is tuple: + resolved_items = tuple(resolve_value(item) for item in items) + if any(is_failed(item) for item in resolved_items): + memo[current_id] = Unhashable() + else: + memo[current_id] = ("tuple", resolved_items) + elif current_type is set: + memo[current_id] = resolve_unordered_values(items, "set") + else: + memo[current_id] = resolve_unordered_values(items, "frozenset") + except RuntimeError: + memo[current_id] = Unhashable() + continue + + if current_id in active: + memo[current_id] = Unhashable() + continue + + processed += 1 + if processed > max_nodes: + return Unhashable() + + active.add(current_id) + if current_type is dict: + try: + items = list(current.items()) + snapshots[current_id] = items + except RuntimeError: + memo[current_id] = Unhashable() + active.discard(current_id) + continue + for key, value in items: + if type(key) not in _PRIMITIVE_SIGNATURE_TYPES: + snapshots.pop(current_id, None) + memo[current_id] = Unhashable() + active.discard(current_id) + break + else: + work_stack.append(("EXPANDED", current_id, current_type)) + for _, value in reversed(items): + work_stack.append((value, False)) + continue + continue + else: + try: + items = list(current) + snapshots[current_id] = items + except RuntimeError: + memo[current_id] = Unhashable() + active.discard(current_id) + continue + work_stack.append(("EXPANDED", current_id, current_type)) + for item in reversed(items): + work_stack.append((item, False)) + + return memo.get(id(obj), Unhashable()) + class CacheKeySetID(CacheKeySet): + """Cache-key strategy that keys nodes by node id and class type.""" def __init__(self, dynprompt, node_ids, is_changed_cache): + """Initialize identity-based cache keys for the supplied dynamic prompt.""" super().__init__(dynprompt, node_ids, is_changed_cache) self.dynprompt = dynprompt async def add_keys(self, node_ids): + """Populate identity-based keys for nodes that exist in the dynamic prompt.""" for node_id in node_ids: if node_id in self.keys: continue @@ -80,15 +438,19 @@ class CacheKeySetID(CacheKeySet): self.subcache_keys[node_id] = (node_id, node["class_type"]) class CacheKeySetInputSignature(CacheKeySet): + """Cache-key strategy that hashes a node's immediate inputs plus ancestor references.""" def __init__(self, dynprompt, node_ids, is_changed_cache): + """Initialize input-signature-based cache keys for the supplied dynamic prompt.""" super().__init__(dynprompt, node_ids, is_changed_cache) self.dynprompt = dynprompt self.is_changed_cache = is_changed_cache def include_node_id_in_input(self) -> bool: + """Return whether node ids should be included in computed input signatures.""" return False async def add_keys(self, node_ids): + """Populate input-signature-based keys for nodes in the dynamic prompt.""" for node_id in node_ids: if node_id in self.keys: continue @@ -99,21 +461,37 @@ class CacheKeySetInputSignature(CacheKeySet): self.subcache_keys[node_id] = (node_id, node["class_type"]) async def get_node_signature(self, dynprompt, node_id): + """Build the full cache signature for a node and its ordered ancestors.""" signature = [] ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) - signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) + immediate = await self.get_immediate_node_signature(dynprompt, node_id, order_mapping) + if type(immediate) is Unhashable: + return immediate + signature.append(immediate) for ancestor_id in ancestors: - signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) - return to_hashable(signature) + immediate = await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping) + if type(immediate) is Unhashable: + return immediate + signature.append(immediate) + return tuple(signature) async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + """Build the immediate cache-signature fragment for a node. + + Link inputs are reduced to ancestor references here. Non-link values + are canonicalized or failed closed before being appended so the final + node signature is assembled from already-hashable fragments. + """ if not dynprompt.has_node(node_id): # This node doesn't exist -- we can't cache it. - return [float("NaN")] + return Unhashable() node = dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - signature = [class_type, await self.is_changed_cache.get(node_id)] + is_changed_signature = _shallow_is_changed_signature(await self.is_changed_cache.get(node_id)) + if type(is_changed_signature) is Unhashable: + return is_changed_signature + signature = [class_type, is_changed_signature] if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): signature.append(node_id) inputs = node["inputs"] @@ -123,18 +501,23 @@ class CacheKeySetInputSignature(CacheKeySet): ancestor_index = ancestor_order_mapping[ancestor_id] signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) else: - signature.append((key, inputs[key])) - return signature + value_signature = to_hashable(inputs[key]) + if type(value_signature) is Unhashable: + return value_signature + signature.append((key, value_signature)) + return tuple(signature) # This function returns a list of all ancestors of the given node. The order of the list is # deterministic based on which specific inputs the ancestor is connected by. def get_ordered_ancestry(self, dynprompt, node_id): + """Return ancestors in deterministic traversal order and their index mapping.""" ancestors = [] order_mapping = {} self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping) return ancestors, order_mapping def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping): + """Recursively collect ancestors in input order without revisiting prior nodes.""" if not dynprompt.has_node(node_id): return inputs = dynprompt.get_node(node_id)["inputs"] diff --git a/comfy_execution/graph_utils.py b/comfy_execution/graph_utils.py index 496d2c634..57b4ef36e 100644 --- a/comfy_execution/graph_utils.py +++ b/comfy_execution/graph_utils.py @@ -1,11 +1,17 @@ def is_link(obj): - if not isinstance(obj, list): + """Return whether obj is a plain prompt link of the form [node_id, output_index].""" + # Prompt links produced by the frontend / GraphBuilder are plain Python + # lists in the form [node_id, output_index]. Some custom-node paths can + # inject foreign runtime objects into prompt inputs during on-prompt graph + # rewriting or subgraph construction. Be strict here so cache signature + # building never tries to treat list-like proxy objects as links. + if type(obj) is not list: return False if len(obj) != 2: return False - if not isinstance(obj[0], str): + if type(obj[0]) is not str: return False - if not isinstance(obj[1], int) and not isinstance(obj[1], float): + if type(obj[1]) is not int: return False return True diff --git a/tests-unit/execution_test/caching_test.py b/tests-unit/execution_test/caching_test.py new file mode 100644 index 000000000..a21dea628 --- /dev/null +++ b/tests-unit/execution_test/caching_test.py @@ -0,0 +1,473 @@ +"""Unit tests for cache-signature canonicalization hardening.""" + +import asyncio +import importlib +import sys +import types + +import pytest + + +class _DummyNode: + """Minimal node stub used to satisfy cache-signature class lookups.""" + + @staticmethod + def INPUT_TYPES(): + """Return a minimal empty input schema for unit tests.""" + return {"required": {}} + + +class _FakeDynPrompt: + """Small DynamicPrompt stand-in with only the methods these tests need.""" + + def __init__(self, nodes_by_id): + """Store test nodes by id.""" + self._nodes_by_id = nodes_by_id + + def has_node(self, node_id): + """Return whether the fake prompt contains the requested node.""" + return node_id in self._nodes_by_id + + def get_node(self, node_id): + """Return the stored node payload for the requested id.""" + return self._nodes_by_id[node_id] + + +class _FakeIsChangedCache: + """Async stub for `is_changed` lookups used by cache-key generation.""" + + def __init__(self, values): + """Store canned `is_changed` responses keyed by node id.""" + self._values = values + + async def get(self, node_id): + """Return the canned `is_changed` value for a node.""" + return self._values[node_id] + + +class _OpaqueValue: + """Hashable opaque object used to exercise fail-closed unordered hashing paths.""" + + +@pytest.fixture +def caching_module(monkeypatch): + """Import `comfy_execution.caching` with lightweight stub dependencies.""" + torch_module = types.ModuleType("torch") + psutil_module = types.ModuleType("psutil") + nodes_module = types.ModuleType("nodes") + nodes_module.NODE_CLASS_MAPPINGS = {} + graph_module = types.ModuleType("comfy_execution.graph") + + class DynamicPrompt: + """Placeholder graph type so the caching module can import cleanly.""" + + pass + + graph_module.DynamicPrompt = DynamicPrompt + + monkeypatch.setitem(sys.modules, "torch", torch_module) + monkeypatch.setitem(sys.modules, "psutil", psutil_module) + monkeypatch.setitem(sys.modules, "nodes", nodes_module) + monkeypatch.setitem(sys.modules, "comfy_execution.graph", graph_module) + monkeypatch.delitem(sys.modules, "comfy_execution.caching", raising=False) + + module = importlib.import_module("comfy_execution.caching") + module = importlib.reload(module) + return module, nodes_module + + +def test_signature_to_hashable_handles_shared_builtin_substructures(caching_module): + """Shared built-in substructures should canonicalize without collapsing to Unhashable.""" + caching, _ = caching_module + shared = [{"value": 1}, {"value": 2}] + + signature = caching._signature_to_hashable([shared, shared]) + + assert signature[0] == "list" + assert signature[1][0] == signature[1][1] + assert signature[1][0][0] == "list" + assert signature[1][0][1][0] == ("dict", (("value", 1),)) + assert signature[1][0][1][1] == ("dict", (("value", 2),)) + + +def test_signature_to_hashable_fails_closed_on_opaque_values(caching_module): + """Opaque values should collapse the full signature to Unhashable immediately.""" + caching, _ = caching_module + + signature = caching._signature_to_hashable(["safe", object()]) + + assert isinstance(signature, caching.Unhashable) + + +def test_signature_to_hashable_stops_descending_after_failure(caching_module, monkeypatch): + """Once canonicalization fails, later recursive descent should stop immediately.""" + caching, _ = caching_module + original = caching._signature_to_hashable_impl + marker = object() + marker_seen = False + + def tracking_canonicalize(obj, *args, **kwargs): + """Track whether recursion reaches the nested marker after failure.""" + nonlocal marker_seen + if obj is marker: + marker_seen = True + return original(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_signature_to_hashable_impl", tracking_canonicalize) + + signature = caching._signature_to_hashable([object(), [marker]]) + + assert isinstance(signature, caching.Unhashable) + assert marker_seen is False + + +def test_signature_to_hashable_snapshots_list_before_recursing(caching_module, monkeypatch): + """List canonicalization should read a point-in-time snapshot before recursive descent.""" + caching, _ = caching_module + original = caching._signature_to_hashable_impl + marker = ("marker",) + values = [marker, 2] + + def mutating_canonicalize(obj, *args, **kwargs): + """Mutate the live list during recursion to verify snapshot-based traversal.""" + if obj is marker: + values[1] = 3 + return original(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize) + + signature = caching._signature_to_hashable(values) + + assert signature == ("list", (("tuple", ("marker",)), 2)) + assert values[1] == 3 + + +def test_signature_to_hashable_snapshots_dict_before_recursing(caching_module, monkeypatch): + """Dict canonicalization should read a point-in-time snapshot before recursive descent.""" + caching, _ = caching_module + original = caching._signature_to_hashable_impl + marker = ("marker",) + values = {"first": marker, "second": 2} + + def mutating_canonicalize(obj, *args, **kwargs): + """Mutate the live dict during recursion to verify snapshot-based traversal.""" + if obj is marker: + values["second"] = 3 + return original(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize) + + signature = caching._signature_to_hashable(values) + + assert signature == ("dict", (("first", ("tuple", ("marker",))), ("second", 2))) + assert values["second"] == 3 + + +@pytest.mark.parametrize( + "container_factory", + [ + lambda marker: [marker], + lambda marker: (marker,), + lambda marker: {marker}, + lambda marker: frozenset({marker}), + lambda marker: {"key": marker}, + ], +) +def test_signature_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory): + """Traversal RuntimeError should degrade canonicalization to Unhashable.""" + caching, _ = caching_module + original = caching._signature_to_hashable_impl + marker = object() + + def raising_canonicalize(obj, *args, **kwargs): + """Raise a traversal RuntimeError for the marker value and delegate otherwise.""" + if obj is marker: + raise RuntimeError("container changed during iteration") + return original(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_signature_to_hashable_impl", raising_canonicalize) + + signature = caching._signature_to_hashable(container_factory(marker)) + + assert isinstance(signature, caching.Unhashable) + + +def test_to_hashable_handles_shared_builtin_substructures(caching_module): + """The legacy helper should still hash sanitized built-ins stably when used directly.""" + caching, _ = caching_module + shared = [{"value": 1}, {"value": 2}] + + sanitized = [shared, shared] + hashable = caching.to_hashable(sanitized) + + assert hashable[0] == "list" + assert hashable[1][0] == hashable[1][1] + assert hashable[1][0][0] == "list" + + +def test_to_hashable_uses_parent_snapshot_during_expanded_phase(caching_module, monkeypatch): + """Expanded-phase assembly should not reread a live parent container after snapshotting.""" + caching, _ = caching_module + original_sort_key = caching._sanitized_sort_key + outer = [{"marker"}, 2] + + def mutating_sort_key(obj, *args, **kwargs): + """Mutate the live parent while a child container is being canonicalized.""" + if obj == "marker": + outer[1] = 3 + return original_sort_key(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_sanitized_sort_key", mutating_sort_key) + + hashable = caching.to_hashable(outer) + + assert hashable == ("list", (("set", ("marker",)), 2)) + assert outer[1] == 3 + + +def test_to_hashable_fails_closed_for_ordered_container_with_opaque_child(caching_module): + """Ordered containers should fail closed when a child cannot be canonicalized.""" + caching, _ = caching_module + + result = caching.to_hashable([object()]) + + assert isinstance(result, caching.Unhashable) + + +def test_to_hashable_canonicalizes_dict_insertion_order(caching_module): + """Dicts with the same content should hash identically regardless of insertion order.""" + caching, _ = caching_module + + first = {"b": 2, "a": 1} + second = {"a": 1, "b": 2} + + assert caching.to_hashable(first) == ("dict", (("a", 1), ("b", 2))) + assert caching.to_hashable(first) == caching.to_hashable(second) + + +def test_to_hashable_fails_closed_for_opaque_dict_key(caching_module): + """Opaque dict keys should fail closed instead of being traversed during hashing.""" + caching, _ = caching_module + + hashable = caching.to_hashable({_OpaqueValue(): 1}) + + assert isinstance(hashable, caching.Unhashable) + + +@pytest.mark.parametrize( + "container_factory", + [ + set, + frozenset, + ], +) +def test_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory): + """Traversal RuntimeError should degrade unordered hash conversion to Unhashable.""" + caching, _ = caching_module + + def raising_sort_key(obj, *args, **kwargs): + """Raise a traversal RuntimeError while unordered values are canonicalized.""" + raise RuntimeError("container changed during iteration") + + monkeypatch.setattr(caching, "_sanitized_sort_key", raising_sort_key) + + hashable = caching.to_hashable(container_factory({"value"})) + + assert isinstance(hashable, caching.Unhashable) + + +def test_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module, monkeypatch): + """Ambiguous dict key ordering should fail closed instead of using insertion order.""" + caching, _ = caching_module + original_sort_key = caching._sanitized_sort_key + ambiguous = {"a": 1, "b": 1} + + def colliding_sort_key(obj, *args, **kwargs): + """Force two distinct primitive keys to share the same ordering key.""" + if obj == "a" or obj == "b": + return ("COLLIDE",) + return original_sort_key(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_sanitized_sort_key", colliding_sort_key) + + hashable = caching.to_hashable(ambiguous) + + assert isinstance(hashable, caching.Unhashable) + + +def test_signature_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module, monkeypatch): + """Ambiguous dict sort ties should fail closed instead of depending on input order.""" + caching, _ = caching_module + original_sort_key = caching._primitive_signature_sort_key + ambiguous = {"a": 1, "b": 1} + + def colliding_sort_key(obj): + """Force two distinct primitive keys to share the same ordering key.""" + if obj == "a" or obj == "b": + return ("COLLIDE",) + return original_sort_key(obj) + + monkeypatch.setattr(caching, "_primitive_signature_sort_key", colliding_sort_key) + + sanitized = caching._signature_to_hashable(ambiguous) + + assert isinstance(sanitized, caching.Unhashable) + + +def test_signature_to_hashable_fails_closed_for_opaque_dict_key(caching_module): + """Opaque dict keys should fail closed instead of being recursively canonicalized.""" + caching, _ = caching_module + + sanitized = caching._signature_to_hashable({_OpaqueValue(): 1}) + + assert isinstance(sanitized, caching.Unhashable) + + +def test_signature_to_hashable_fails_closed_on_dict_key_sort_collisions_even_with_distinct_values(caching_module, monkeypatch): + """Different values must not mask dict key-sort collisions during canonicalization.""" + caching, _ = caching_module + original_sort_key = caching._primitive_signature_sort_key + + def colliding_sort_key(obj): + """Force two distinct primitive keys to share the same ordering key.""" + if obj == "a" or obj == "b": + return ("COLLIDE",) + return original_sort_key(obj) + + monkeypatch.setattr(caching, "_primitive_signature_sort_key", colliding_sort_key) + + sanitized = caching._signature_to_hashable({"a": 1, "b": 2}) + + assert isinstance(sanitized, caching.Unhashable) + + +@pytest.mark.parametrize( + "container_factory", + [ + set, + frozenset, + ], +) +def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module, monkeypatch, container_factory): + """Ambiguous unordered values should fail closed instead of depending on iteration order.""" + caching, _ = caching_module + original_sort_key = caching._sanitized_sort_key + container = container_factory({"a", "b"}) + + def colliding_sort_key(obj, *args, **kwargs): + """Force two distinct primitive values to share the same ordering key.""" + if obj == "a" or obj == "b": + return ("COLLIDE",) + return original_sort_key(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_sanitized_sort_key", colliding_sort_key) + + hashable = caching.to_hashable(container) + + assert isinstance(hashable, caching.Unhashable) + + +def test_get_node_signature_returns_top_level_unhashable_for_tainted_signature(caching_module, monkeypatch): + """Tainted full signatures should fail closed before `to_hashable()` runs.""" + caching, nodes_module = caching_module + monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode) + monkeypatch.setattr( + caching, + "to_hashable", + lambda *_args, **_kwargs: pytest.fail("to_hashable should not run for tainted signatures"), + ) + + is_changed_value = [] + is_changed_value.append(is_changed_value) + + dynprompt = _FakeDynPrompt( + { + "node": { + "class_type": "UnitTestNode", + "inputs": {"value": 5}, + } + } + ) + key_set = caching.CacheKeySetInputSignature( + dynprompt, + ["node"], + _FakeIsChangedCache({"node": is_changed_value}), + ) + + signature = asyncio.run(key_set.get_node_signature(dynprompt, "node")) + + assert isinstance(signature, caching.Unhashable) + + +def test_shallow_is_changed_signature_accepts_primitive_lists(caching_module): + """Primitive-only `is_changed` lists should stay hashable without deep descent.""" + caching, _ = caching_module + + sanitized = caching._shallow_is_changed_signature([1, "two", None, True]) + + assert sanitized == ("is_changed_list", (1, "two", None, True)) + + +def test_shallow_is_changed_signature_accepts_structured_builtin_fingerprint_lists(caching_module): + """Structured built-in `is_changed` fingerprints should remain representable.""" + caching, _ = caching_module + + sanitized = caching._shallow_is_changed_signature([("seed", 42), {"cfg": 8}]) + + assert sanitized == ( + "is_changed_list", + ( + ("tuple", ("seed", 42)), + ("dict", (("cfg", 8),)), + ), + ) + + +def test_shallow_is_changed_signature_fails_closed_for_opaque_payload(caching_module): + """Opaque `is_changed` payloads should still fail closed.""" + caching, _ = caching_module + + sanitized = caching._shallow_is_changed_signature([_OpaqueValue()]) + + assert isinstance(sanitized, caching.Unhashable) + + +def test_get_immediate_node_signature_fails_closed_for_unhashable_is_changed(caching_module, monkeypatch): + """Recursive `is_changed` payloads should fail the full fragment closed.""" + caching, nodes_module = caching_module + monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode) + + is_changed_value = [] + is_changed_value.append(is_changed_value) + dynprompt = _FakeDynPrompt( + { + "node": { + "class_type": "UnitTestNode", + "inputs": {"value": 5}, + } + } + ) + key_set = caching.CacheKeySetInputSignature( + dynprompt, + ["node"], + _FakeIsChangedCache({"node": is_changed_value}), + ) + + signature = asyncio.run(key_set.get_immediate_node_signature(dynprompt, "node", {})) + + assert isinstance(signature, caching.Unhashable) + + +def test_get_immediate_node_signature_fails_closed_for_missing_node(caching_module): + """Missing nodes should return the fail-closed sentinel instead of a NaN tuple.""" + caching, _ = caching_module + dynprompt = _FakeDynPrompt({}) + key_set = caching.CacheKeySetInputSignature( + dynprompt, + [], + _FakeIsChangedCache({}), + ) + + signature = asyncio.run(key_set.get_immediate_node_signature(dynprompt, "missing", {})) + + assert isinstance(signature, caching.Unhashable) diff --git a/tests/execution/test_caching.py b/tests/execution/test_caching.py new file mode 100644 index 000000000..569bf5bd8 --- /dev/null +++ b/tests/execution/test_caching.py @@ -0,0 +1,198 @@ +import asyncio + +from comfy_execution import caching + + +class _StubDynPrompt: + def __init__(self, nodes): + self._nodes = nodes + + def has_node(self, node_id): + return node_id in self._nodes + + def get_node(self, node_id): + return self._nodes[node_id] + + +class _StubIsChangedCache: + async def get(self, node_id): + return None + + +class _StubNode: + @classmethod + def INPUT_TYPES(cls): + return {"required": {}} + + +def test_get_immediate_node_signature_canonicalizes_non_link_inputs(monkeypatch): + live_value = [1, {"nested": [2, 3]}] + dynprompt = _StubDynPrompt( + { + "1": { + "class_type": "TestCacheNode", + "inputs": {"value": live_value}, + } + } + ) + + monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode) + monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {}) + + keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache()) + signature = asyncio.run(keyset.get_immediate_node_signature(dynprompt, "1", {})) + + assert signature == ( + "TestCacheNode", + None, + ("value", ("list", (1, ("dict", (("nested", ("list", (2, 3))),))))), + ) + + +def test_to_hashable_walks_dicts_without_rebinding_traversal_stack(): + live_value = { + "outer": {"nested": [2, 3]}, + "items": [{"leaf": 4}], + } + + assert caching.to_hashable(live_value) == ( + "dict", + ( + ("items", ("list", (("dict", (("leaf", 4),)),))), + ("outer", ("dict", (("nested", ("list", (2, 3))),))), + ), + ) + + +def test_get_immediate_node_signature_fails_closed_for_opaque_non_link_input(monkeypatch): + class OpaqueRuntimeValue: + pass + + live_value = OpaqueRuntimeValue() + dynprompt = _StubDynPrompt( + { + "1": { + "class_type": "TestCacheNode", + "inputs": {"value": live_value}, + } + } + ) + + monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode) + monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {}) + + keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache()) + signature = asyncio.run(keyset.get_immediate_node_signature(dynprompt, "1", {})) + + assert isinstance(signature, caching.Unhashable) + + +def test_get_node_signature_propagates_unhashable_immediate_fragment(monkeypatch): + class OpaqueRuntimeValue: + pass + + dynprompt = _StubDynPrompt( + { + "1": { + "class_type": "TestCacheNode", + "inputs": {"value": OpaqueRuntimeValue()}, + } + } + ) + + monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode) + monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {}) + + keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache()) + signature = asyncio.run(keyset.get_node_signature(dynprompt, "1")) + + assert isinstance(signature, caching.Unhashable) + + +def test_get_node_signature_never_visits_raw_non_link_input(monkeypatch): + live_value = [1, 2, 3] + dynprompt = _StubDynPrompt( + { + "1": { + "class_type": "TestCacheNode", + "inputs": {"value": live_value}, + } + } + ) + + monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode) + monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {}) + monkeypatch.setattr( + caching, + "_signature_to_hashable", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + AssertionError("outer signature canonicalizer should not run") + ), + ) + + keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache()) + signature = asyncio.run(keyset.get_node_signature(dynprompt, "1")) + + assert isinstance(signature, tuple) + + +def test_get_node_signature_keeps_deep_canonicalized_input_fragment(monkeypatch): + live_value = 1 + for _ in range(8): + live_value = [live_value] + expected = caching.to_hashable(live_value) + + dynprompt = _StubDynPrompt( + { + "1": { + "class_type": "TestCacheNode", + "inputs": {"value": live_value}, + } + } + ) + + monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode) + monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {}) + + keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache()) + signature = asyncio.run(keyset.get_node_signature(dynprompt, "1")) + + assert isinstance(signature, tuple) + assert signature[0][2][0] == "value" + assert signature[0][2][1] == expected + + +def test_get_node_signature_keeps_large_precanonicalized_fragment(monkeypatch): + live_value = object() + canonical_fragment = ("tuple", tuple(("list", (index, index + 1)) for index in range(256))) + dynprompt = _StubDynPrompt( + { + "1": { + "class_type": "TestCacheNode", + "inputs": {"value": live_value}, + } + } + ) + + monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode) + monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {}) + monkeypatch.setattr( + caching, + "to_hashable", + lambda value, max_nodes=caching._MAX_SIGNATURE_CONTAINER_VISITS: ( + canonical_fragment if value is live_value else caching.Unhashable() + ), + ) + monkeypatch.setattr( + caching, + "_signature_to_hashable", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + AssertionError("outer signature canonicalizer should not run") + ), + ) + + keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache()) + signature = asyncio.run(keyset.get_node_signature(dynprompt, "1")) + + assert isinstance(signature, tuple) + assert signature[0][2] == ("value", canonical_fragment)