fail closed on ambiguous container ordering in cache signatures

This commit is contained in:
xmarre 2026-03-15 02:32:25 +01:00
parent 763089f681
commit aceaa5e579
2 changed files with 380 additions and 100 deletions

View File

@ -60,136 +60,239 @@ class Unhashable:
pass
# Types that may appear verbatim in a cache signature.
_PRIMITIVE_SIGNATURE_TYPES = (int, float, str, bool, bytes, type(None))
# Built-in containers the signature builder is willing to walk.
_CONTAINER_SIGNATURE_TYPES = (dict, list, tuple, set, frozenset)
_MAX_SIGNATURE_DEPTH = 32
_MAX_SIGNATURE_CONTAINER_VISITS = 10_000


def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None):
    """Return a deterministic ordering key for sanitized built-in container content.

    Args:
        obj: Value to derive an ordering key for.
        depth: Current recursion depth.
        max_depth: Depth at which traversal stops with a sentinel key.
        active: Ids of containers currently on the recursion stack (cycle guard).
        memo: Cache of keys already computed for visited containers, so shared
            substructures are keyed once instead of re-walked.
    """
    if depth >= max_depth:
        return ("MAX_DEPTH",)
    if active is None:
        active = set()
    if memo is None:
        memo = {}
    obj_type = type(obj)
    if obj_type is Unhashable:
        return ("UNHASHABLE",)
    elif obj_type in _PRIMITIVE_SIGNATURE_TYPES:
        # Include module/qualname so e.g. True and 1 do not collide via repr().
        return (obj_type.__module__, obj_type.__qualname__, repr(obj))
    elif obj_type not in _CONTAINER_SIGNATURE_TYPES:
        return (obj_type.__module__, obj_type.__qualname__, "OPAQUE")
    obj_id = id(obj)
    if obj_id in memo:
        return memo[obj_id]
    if obj_id in active:
        # Already on the recursion stack: break the cycle with a sentinel key.
        return ("CYCLE",)
    active.add(obj_id)
    try:
        if obj_type is dict:
            items = [
                (
                    _sanitized_sort_key(k, depth + 1, max_depth, active, memo),
                    _sanitized_sort_key(v, depth + 1, max_depth, active, memo),
                )
                for k, v in obj.items()
            ]
            items.sort()
            result = ("dict", tuple(items))
        elif obj_type is list:
            result = ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))
        elif obj_type is tuple:
            result = ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))
        elif obj_type is set:
            result = ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)))
        else:
            result = ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)))
    finally:
        active.discard(obj_id)
    memo[obj_id] = result
    return result
def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None):
    """Normalize signature inputs to safe built-in containers.

    Preserves built-in container type, replaces opaque runtime values with
    Unhashable(), stops safely on cycles or excessive depth, and memoizes
    repeated built-in substructures so shared DAG-like inputs do not explode
    into repeated recursive work.

    Args:
        obj: Value to sanitize.
        depth: Current recursion depth.
        max_depth: Depth at which traversal fails closed with Unhashable().
        active: Ids of containers currently on the recursion stack (cycle guard).
        memo: Cache of sanitized results for already-visited containers.
        budget: Mutable counter limiting total container visits per call tree.
    """
    if depth >= max_depth:
        return Unhashable()
    if active is None:
        active = set()
    if memo is None:
        memo = {}
    if budget is None:
        budget = {"remaining": _MAX_SIGNATURE_CONTAINER_VISITS}
    obj_type = type(obj)
    if obj_type in _PRIMITIVE_SIGNATURE_TYPES:
        return obj
    # Execution-cache signatures should be built from prompt-safe values.
    # If a custom node injects a runtime object here, mark it unhashable so
    # the node won't reuse stale cache entries across runs, but do not walk
    # the foreign object and risk crashing on custom container semantics.
    if obj_type not in _CONTAINER_SIGNATURE_TYPES:
        return Unhashable()
    obj_id = id(obj)
    if obj_id in memo:
        return memo[obj_id]
    if obj_id in active:
        # Cycle: fail closed rather than recurse forever.
        return Unhashable()
    budget["remaining"] -= 1
    if budget["remaining"] < 0:
        return Unhashable()
    active.add(obj_id)
    try:
        if obj_type is dict:
            sort_memo = {}
            sanitized_items = [
                (
                    _sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget),
                    _sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget),
                )
                for key, value in obj.items()
            ]
            ordered_items = [
                (
                    (
                        _sanitized_sort_key(key, depth + 1, max_depth, memo=sort_memo),
                        _sanitized_sort_key(value, depth + 1, max_depth, memo=sort_memo),
                    ),
                    (key, value),
                )
                for key, value in sanitized_items
            ]
            ordered_items.sort(key=lambda item: item[0])
            # Fail closed when two distinct entries tie on sort key: their
            # relative order would otherwise depend on input iteration order.
            result = Unhashable()
            for index in range(1, len(ordered_items)):
                previous_sort_key, previous_item = ordered_items[index - 1]
                current_sort_key, current_item = ordered_items[index]
                if previous_sort_key == current_sort_key and previous_item != current_item:
                    break
            else:
                result = {key: value for _, (key, value) in ordered_items}
        elif obj_type is list:
            result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj]
        elif obj_type is tuple:
            result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj)
        elif obj_type is set:
            result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj}
        else:
            result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj)
    finally:
        active.discard(obj_id)
    memo[obj_id] = result
    return result
def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
    """Convert sanitized prompt inputs into a stable hashable representation.

    The input is expected to already be sanitized to plain built-in containers,
    but this function still fails safe for anything unexpected. Traversal is
    iterative and memoized so shared built-in substructures do not trigger
    exponential re-walks during cache-key construction.

    Args:
        obj: Sanitized value to convert.
        max_nodes: Maximum number of containers to visit before failing closed.
    """
    obj_type = type(obj)
    if obj_type in _PRIMITIVE_SIGNATURE_TYPES or obj_type is Unhashable:
        return obj
    # Restrict traversal to plain built-in containers. Some custom nodes insert
    # runtime objects into prompt inputs for dynamic graph paths; walking those
    # objects as generic Mappings / Sequences is unsafe and can destabilize the
    # cache signature builder.
    if obj_type not in _CONTAINER_SIGNATURE_TYPES:
        return Unhashable()
    memo = {}
    active = set()
    sort_memo = {}
    processed = 0
    # Explicit stack of (node, expanded) pairs: a node is pushed once to expand
    # its children and a second time (expanded=True) to assemble its result.
    stack = [(obj, False)]

    def resolve_value(value):
        """Resolve a child value from the completed memo table when available."""
        value_type = type(value)
        if value_type in _PRIMITIVE_SIGNATURE_TYPES or value_type is Unhashable:
            return value
        return memo.get(id(value), Unhashable())

    def resolve_unordered_values(current, container_tag):
        """Resolve a set-like container or fail closed if ordering is ambiguous."""
        ordered_items = [
            (_sanitized_sort_key(item, memo=sort_memo), resolve_value(item))
            for item in current
        ]
        ordered_items.sort(key=lambda item: item[0])
        # Equal sort keys with unequal resolved values means the final order
        # would depend on iteration order — refuse to produce a stable key.
        for index in range(1, len(ordered_items)):
            previous_key, previous_value = ordered_items[index - 1]
            current_key, current_value = ordered_items[index]
            if previous_key == current_key and previous_value != current_value:
                return Unhashable()
        return (container_tag, tuple(value for _, value in ordered_items))

    while stack:
        current, expanded = stack.pop()
        current_type = type(current)
        if current_type in _PRIMITIVE_SIGNATURE_TYPES or current_type is Unhashable:
            continue
        if current_type not in _CONTAINER_SIGNATURE_TYPES:
            memo[id(current)] = Unhashable()
            continue
        current_id = id(current)
        if current_id in memo:
            continue
        if expanded:
            # All children have been processed; assemble this container's key.
            active.discard(current_id)
            if current_type is dict:
                memo[current_id] = (
                    "dict",
                    tuple((resolve_value(k), resolve_value(v)) for k, v in current.items()),
                )
            elif current_type is list:
                memo[current_id] = ("list", tuple(resolve_value(item) for item in current))
            elif current_type is tuple:
                memo[current_id] = ("tuple", tuple(resolve_value(item) for item in current))
            elif current_type is set:
                memo[current_id] = resolve_unordered_values(current, "set")
            else:
                memo[current_id] = resolve_unordered_values(current, "frozenset")
            continue
        if current_id in active:
            # Cycle: the container is its own (transitive) child.
            memo[current_id] = Unhashable()
            continue
        processed += 1
        if processed > max_nodes:
            return Unhashable()
        active.add(current_id)
        stack.append((current, True))
        if current_type is dict:
            items = list(current.items())
            for key, value in reversed(items):
                stack.append((value, False))
                stack.append((key, False))
        else:
            items = list(current)
            for item in reversed(items):
                stack.append((item, False))
    return memo.get(id(obj), Unhashable())
class CacheKeySetID(CacheKeySet):
"""Cache-key strategy that keys nodes by node id and class type."""
def __init__(self, dynprompt, node_ids, is_changed_cache):
@ -238,6 +341,7 @@ class CacheKeySetInputSignature(CacheKeySet):
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
for ancestor_id in ancestors:
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
signature = _sanitize_signature_input(signature)
return to_hashable(signature)
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):

View File

@ -0,0 +1,176 @@
"""Unit tests for cache-signature sanitization and hash conversion hardening."""
import asyncio
import importlib
import sys
import types
import pytest
class _DummyNode:
"""Minimal node stub used to satisfy cache-signature class lookups."""
@staticmethod
def INPUT_TYPES():
"""Return a minimal empty input schema for unit tests."""
return {"required": {}}
class _FakeDynPrompt:
"""Small DynamicPrompt stand-in with only the methods these tests need."""
def __init__(self, nodes_by_id):
"""Store test nodes by id."""
self._nodes_by_id = nodes_by_id
def has_node(self, node_id):
"""Return whether the fake prompt contains the requested node."""
return node_id in self._nodes_by_id
def get_node(self, node_id):
"""Return the stored node payload for the requested id."""
return self._nodes_by_id[node_id]
class _FakeIsChangedCache:
"""Async stub for `is_changed` lookups used by cache-key generation."""
def __init__(self, values):
"""Store canned `is_changed` responses keyed by node id."""
self._values = values
async def get(self, node_id):
"""Return the canned `is_changed` value for a node."""
return self._values[node_id]
class _OpaqueValue:
"""Hashable opaque object used to exercise fail-closed unordered hashing paths."""
def _contains_unhashable(value, unhashable_type):
"""Return whether a nested built-in structure contains an Unhashable sentinel."""
if isinstance(value, unhashable_type):
return True
value_type = type(value)
if value_type is dict:
return any(
_contains_unhashable(key, unhashable_type) or _contains_unhashable(item, unhashable_type)
for key, item in value.items()
)
if value_type in (list, tuple, set, frozenset):
return any(_contains_unhashable(item, unhashable_type) for item in value)
return False
@pytest.fixture
def caching_module(monkeypatch):
    """Import `comfy_execution.caching` with lightweight stub dependencies."""
    stub_torch = types.ModuleType("torch")
    stub_psutil = types.ModuleType("psutil")
    stub_nodes = types.ModuleType("nodes")
    stub_nodes.NODE_CLASS_MAPPINGS = {}
    stub_graph = types.ModuleType("comfy_execution.graph")

    class DynamicPrompt:
        """Placeholder graph type so the caching module can import cleanly."""

    stub_graph.DynamicPrompt = DynamicPrompt
    # Install the stubs before importing so the real heavy deps are never loaded.
    stubs = {
        "torch": stub_torch,
        "psutil": stub_psutil,
        "nodes": stub_nodes,
        "comfy_execution.graph": stub_graph,
    }
    for module_name, stub in stubs.items():
        monkeypatch.setitem(sys.modules, module_name, stub)
    monkeypatch.delitem(sys.modules, "comfy_execution.caching", raising=False)
    module = importlib.import_module("comfy_execution.caching")
    module = importlib.reload(module)
    return module, stub_nodes
def test_sanitize_signature_input_handles_shared_builtin_substructures(caching_module):
    """Shared built-in substructures should sanitize without collapsing to Unhashable."""
    caching, _ = caching_module
    common = [{"value": 1}, {"value": 2}]
    result = caching._sanitize_signature_input([common, common])
    assert isinstance(result, list)
    assert result[0] == result[1]
    assert result[0][0]["value"] == 1
    assert result[0][1]["value"] == 2
def test_to_hashable_handles_shared_builtin_substructures(caching_module):
    """Repeated sanitized content should hash stably for shared substructures."""
    caching, _ = caching_module
    common = [{"value": 1}, {"value": 2}]
    cleaned = caching._sanitize_signature_input([common, common])
    hashed = caching.to_hashable(cleaned)
    assert hashed[0] == "list"
    assert hashed[1][0] == hashed[1][1]
    assert hashed[1][0][0] == "list"
def test_sanitize_signature_input_fails_closed_for_ambiguous_dict_ordering(caching_module):
    """Ambiguous dict sort ties should fail closed instead of depending on input order."""
    caching, _ = caching_module
    # Two opaque keys sort identically but hold distinct values: a tie.
    tied_keys = {_OpaqueValue(): _OpaqueValue(), _OpaqueValue(): _OpaqueValue()}
    result = caching._sanitize_signature_input(tied_keys)
    assert isinstance(result, caching.Unhashable)
@pytest.mark.parametrize("container_factory", [set, frozenset])
def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module, container_factory):
    """Ambiguous unordered values should fail closed instead of depending on iteration order."""
    caching, _ = caching_module
    values = container_factory({_OpaqueValue(), _OpaqueValue()})
    result = caching.to_hashable(values)
    assert isinstance(result, caching.Unhashable)
def test_get_node_signature_sanitizes_full_signature(caching_module, monkeypatch):
    """Recursive `is_changed` payloads should be sanitized inside the full node signature."""
    caching, nodes_module = caching_module
    monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
    # Build a self-referential list so sanitization must break the cycle.
    recursive_payload = []
    recursive_payload.append(recursive_payload)
    prompt = _FakeDynPrompt(
        {"node": {"class_type": "UnitTestNode", "inputs": {"value": 5}}}
    )
    key_set = caching.CacheKeySetInputSignature(
        prompt,
        ["node"],
        _FakeIsChangedCache({"node": recursive_payload}),
    )
    signature = asyncio.run(key_set.get_node_signature(prompt, "node"))
    assert signature[0] == "list"
    assert _contains_unhashable(signature, caching.Unhashable)