Change signature cache to bail early

This commit is contained in:
xmarre 2026-03-15 04:31:32 +01:00
parent fadd79ad48
commit 9feb26928c
2 changed files with 79 additions and 31 deletions

View File

@ -66,6 +66,12 @@ _MAX_SIGNATURE_DEPTH = 32
_MAX_SIGNATURE_CONTAINER_VISITS = 10_000
def _mark_signature_tainted(taint_state):
"""Record that signature sanitization hit a fail-closed condition."""
if taint_state is not None:
taint_state["tainted"] = True
def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None):
"""Return a deterministic ordering key for sanitized built-in container content."""
if depth >= max_depth:
@ -117,15 +123,20 @@ def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=Non
return result
def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None):
def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None, taint_state=None):
"""Normalize signature inputs to safe built-in containers.
Preserves built-in container type, replaces opaque runtime values with
Unhashable(), stops safely on cycles or excessive depth, and memoizes
repeated built-in substructures so shared DAG-like inputs do not explode
into repeated recursive work.
Unhashable(), stops safely on cycles or excessive depth, memoizes repeated
built-in substructures so shared DAG-like inputs do not explode into
repeated recursive work, and optionally records when sanitization had to
fail closed anywhere in the traversed structure.
"""
if taint_state is not None and taint_state.get("tainted"):
return Unhashable()
if depth >= max_depth:
_mark_signature_tainted(taint_state)
return Unhashable()
if active is None:
@ -139,16 +150,19 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti
if obj_type in _PRIMITIVE_SIGNATURE_TYPES:
return obj
if obj_type not in _CONTAINER_SIGNATURE_TYPES:
_mark_signature_tainted(taint_state)
return Unhashable()
obj_id = id(obj)
if obj_id in memo:
return memo[obj_id]
if obj_id in active:
_mark_signature_tainted(taint_state)
return Unhashable()
budget["remaining"] -= 1
if budget["remaining"] < 0:
_mark_signature_tainted(taint_state)
return Unhashable()
active.add(obj_id)
@ -159,8 +173,8 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti
sort_memo = {}
sanitized_items = [
(
_sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget),
_sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget),
_sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget, taint_state),
_sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget, taint_state),
)
for key, value in items
]
@ -181,34 +195,40 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti
previous_sort_key, previous_item = ordered_items[index - 1]
current_sort_key, current_item = ordered_items[index]
if previous_sort_key == current_sort_key and previous_item != current_item:
_mark_signature_tainted(taint_state)
break
else:
result = {key: value for _, (key, value) in ordered_items}
except RuntimeError:
_mark_signature_tainted(taint_state)
result = Unhashable()
elif obj_type is list:
try:
items = list(obj)
result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items]
result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items]
except RuntimeError:
_mark_signature_tainted(taint_state)
result = Unhashable()
elif obj_type is tuple:
try:
items = list(obj)
result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items)
result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items)
except RuntimeError:
_mark_signature_tainted(taint_state)
result = Unhashable()
elif obj_type is set:
try:
items = list(obj)
result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items}
result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items}
except RuntimeError:
_mark_signature_tainted(taint_state)
result = Unhashable()
else:
try:
items = list(obj)
result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items)
result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items)
except RuntimeError:
_mark_signature_tainted(taint_state)
result = Unhashable()
finally:
active.discard(obj_id)
@ -377,7 +397,10 @@ class CacheKeySetInputSignature(CacheKeySet):
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
for ancestor_id in ancestors:
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
signature = _sanitize_signature_input(signature)
taint_state = {"tainted": False}
signature = _sanitize_signature_input(signature, taint_state=taint_state)
if taint_state["tainted"]:
return Unhashable()
return to_hashable(signature)
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):

View File

@ -49,22 +49,6 @@ class _OpaqueValue:
"""Hashable opaque object used to exercise fail-closed unordered hashing paths."""
def _contains_unhashable(value, unhashable_type):
"""Return whether a nested built-in structure contains an Unhashable sentinel."""
if isinstance(value, unhashable_type):
return True
value_type = type(value)
if value_type is dict:
return any(
_contains_unhashable(key, unhashable_type) or _contains_unhashable(item, unhashable_type)
for key, item in value.items()
)
if value_type in (list, tuple, set, frozenset):
return any(_contains_unhashable(item, unhashable_type) for item in value)
return False
@pytest.fixture
def caching_module(monkeypatch):
"""Import `comfy_execution.caching` with lightweight stub dependencies."""
@ -105,6 +89,43 @@ def test_sanitize_signature_input_handles_shared_builtin_substructures(caching_m
assert sanitized[0][1]["value"] == 2
def test_sanitize_signature_input_marks_tainted_on_opaque_values(caching_module):
"""Opaque values should mark the containing signature as tainted."""
caching, _ = caching_module
taint_state = {"tainted": False}
sanitized = caching._sanitize_signature_input(["safe", object()], taint_state=taint_state)
assert isinstance(sanitized, list)
assert taint_state["tainted"] is True
assert isinstance(sanitized[1], caching.Unhashable)
def test_sanitize_signature_input_stops_descending_after_taint(caching_module, monkeypatch):
"""Once tainted, later recursive calls should return immediately without deeper descent."""
caching, _ = caching_module
original = caching._sanitize_signature_input
marker = object()
marker_seen = False
def tracking_sanitize(obj, *args, **kwargs):
"""Track whether recursion reaches the nested marker after tainting."""
nonlocal marker_seen
if obj is marker:
marker_seen = True
return original(obj, *args, **kwargs)
monkeypatch.setattr(caching, "_sanitize_signature_input", tracking_sanitize)
taint_state = {"tainted": False}
sanitized = original([object(), [marker]], taint_state=taint_state)
assert isinstance(sanitized, list)
assert taint_state["tainted"] is True
assert marker_seen is False
assert isinstance(sanitized[1], caching.Unhashable)
def test_sanitize_signature_input_snapshots_list_before_recursing(caching_module, monkeypatch):
"""List sanitization should read a point-in-time snapshot before recursive descent."""
caching, _ = caching_module
@ -241,10 +262,15 @@ def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module,
assert isinstance(hashable, caching.Unhashable)
def test_get_node_signature_sanitizes_full_signature(caching_module, monkeypatch):
"""Recursive `is_changed` payloads should be sanitized inside the full node signature."""
def test_get_node_signature_returns_top_level_unhashable_for_tainted_signature(caching_module, monkeypatch):
"""Tainted full signatures should fail closed before `to_hashable()` runs."""
caching, nodes_module = caching_module
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
monkeypatch.setattr(
caching,
"to_hashable",
lambda *_args, **_kwargs: pytest.fail("to_hashable should not run for tainted signatures"),
)
is_changed_value = []
is_changed_value.append(is_changed_value)
@ -265,5 +291,4 @@ def test_get_node_signature_sanitizes_full_signature(caching_module, monkeypatch
signature = asyncio.run(key_set.get_node_signature(dynprompt, "node"))
assert signature[0] == "list"
assert _contains_unhashable(signature, caching.Unhashable)
assert isinstance(signature, caching.Unhashable)