Change signature cache to bail early

This commit is contained in:
xmarre 2026-03-15 04:31:32 +01:00
parent fadd79ad48
commit 9feb26928c
2 changed files with 79 additions and 31 deletions

View File

@ -66,6 +66,12 @@ _MAX_SIGNATURE_DEPTH = 32
_MAX_SIGNATURE_CONTAINER_VISITS = 10_000 _MAX_SIGNATURE_CONTAINER_VISITS = 10_000
def _mark_signature_tainted(taint_state):
"""Record that signature sanitization hit a fail-closed condition."""
if taint_state is not None:
taint_state["tainted"] = True
def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None): def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None):
"""Return a deterministic ordering key for sanitized built-in container content.""" """Return a deterministic ordering key for sanitized built-in container content."""
if depth >= max_depth: if depth >= max_depth:
@ -117,15 +123,20 @@ def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=Non
return result return result
def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None): def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None, taint_state=None):
"""Normalize signature inputs to safe built-in containers. """Normalize signature inputs to safe built-in containers.
Preserves built-in container type, replaces opaque runtime values with Preserves built-in container type, replaces opaque runtime values with
Unhashable(), stops safely on cycles or excessive depth, and memoizes Unhashable(), stops safely on cycles or excessive depth, memoizes repeated
repeated built-in substructures so shared DAG-like inputs do not explode built-in substructures so shared DAG-like inputs do not explode into
into repeated recursive work. repeated recursive work, and optionally records when sanitization had to
fail closed anywhere in the traversed structure.
""" """
if taint_state is not None and taint_state.get("tainted"):
return Unhashable()
if depth >= max_depth: if depth >= max_depth:
_mark_signature_tainted(taint_state)
return Unhashable() return Unhashable()
if active is None: if active is None:
@ -139,16 +150,19 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti
if obj_type in _PRIMITIVE_SIGNATURE_TYPES: if obj_type in _PRIMITIVE_SIGNATURE_TYPES:
return obj return obj
if obj_type not in _CONTAINER_SIGNATURE_TYPES: if obj_type not in _CONTAINER_SIGNATURE_TYPES:
_mark_signature_tainted(taint_state)
return Unhashable() return Unhashable()
obj_id = id(obj) obj_id = id(obj)
if obj_id in memo: if obj_id in memo:
return memo[obj_id] return memo[obj_id]
if obj_id in active: if obj_id in active:
_mark_signature_tainted(taint_state)
return Unhashable() return Unhashable()
budget["remaining"] -= 1 budget["remaining"] -= 1
if budget["remaining"] < 0: if budget["remaining"] < 0:
_mark_signature_tainted(taint_state)
return Unhashable() return Unhashable()
active.add(obj_id) active.add(obj_id)
@ -159,8 +173,8 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti
sort_memo = {} sort_memo = {}
sanitized_items = [ sanitized_items = [
( (
_sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget), _sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget, taint_state),
_sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget), _sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget, taint_state),
) )
for key, value in items for key, value in items
] ]
@ -181,34 +195,40 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti
previous_sort_key, previous_item = ordered_items[index - 1] previous_sort_key, previous_item = ordered_items[index - 1]
current_sort_key, current_item = ordered_items[index] current_sort_key, current_item = ordered_items[index]
if previous_sort_key == current_sort_key and previous_item != current_item: if previous_sort_key == current_sort_key and previous_item != current_item:
_mark_signature_tainted(taint_state)
break break
else: else:
result = {key: value for _, (key, value) in ordered_items} result = {key: value for _, (key, value) in ordered_items}
except RuntimeError: except RuntimeError:
_mark_signature_tainted(taint_state)
result = Unhashable() result = Unhashable()
elif obj_type is list: elif obj_type is list:
try: try:
items = list(obj) items = list(obj)
result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items] result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items]
except RuntimeError: except RuntimeError:
_mark_signature_tainted(taint_state)
result = Unhashable() result = Unhashable()
elif obj_type is tuple: elif obj_type is tuple:
try: try:
items = list(obj) items = list(obj)
result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items) result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items)
except RuntimeError: except RuntimeError:
_mark_signature_tainted(taint_state)
result = Unhashable() result = Unhashable()
elif obj_type is set: elif obj_type is set:
try: try:
items = list(obj) items = list(obj)
result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items} result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items}
except RuntimeError: except RuntimeError:
_mark_signature_tainted(taint_state)
result = Unhashable() result = Unhashable()
else: else:
try: try:
items = list(obj) items = list(obj)
result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items) result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items)
except RuntimeError: except RuntimeError:
_mark_signature_tainted(taint_state)
result = Unhashable() result = Unhashable()
finally: finally:
active.discard(obj_id) active.discard(obj_id)
@ -377,7 +397,10 @@ class CacheKeySetInputSignature(CacheKeySet):
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
for ancestor_id in ancestors: for ancestor_id in ancestors:
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
signature = _sanitize_signature_input(signature) taint_state = {"tainted": False}
signature = _sanitize_signature_input(signature, taint_state=taint_state)
if taint_state["tainted"]:
return Unhashable()
return to_hashable(signature) return to_hashable(signature)
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):

View File

@ -49,22 +49,6 @@ class _OpaqueValue:
"""Hashable opaque object used to exercise fail-closed unordered hashing paths.""" """Hashable opaque object used to exercise fail-closed unordered hashing paths."""
def _contains_unhashable(value, unhashable_type):
"""Return whether a nested built-in structure contains an Unhashable sentinel."""
if isinstance(value, unhashable_type):
return True
value_type = type(value)
if value_type is dict:
return any(
_contains_unhashable(key, unhashable_type) or _contains_unhashable(item, unhashable_type)
for key, item in value.items()
)
if value_type in (list, tuple, set, frozenset):
return any(_contains_unhashable(item, unhashable_type) for item in value)
return False
@pytest.fixture @pytest.fixture
def caching_module(monkeypatch): def caching_module(monkeypatch):
"""Import `comfy_execution.caching` with lightweight stub dependencies.""" """Import `comfy_execution.caching` with lightweight stub dependencies."""
@ -105,6 +89,43 @@ def test_sanitize_signature_input_handles_shared_builtin_substructures(caching_m
assert sanitized[0][1]["value"] == 2 assert sanitized[0][1]["value"] == 2
def test_sanitize_signature_input_marks_tainted_on_opaque_values(caching_module):
    """An opaque element must set the taint flag and be replaced by Unhashable.

    The outer list is still returned as a list; only the opaque `object()`
    entry is swapped for the `Unhashable` sentinel.
    """
    caching, _nodes = caching_module
    state = {"tainted": False}

    result = caching._sanitize_signature_input(["safe", object()], taint_state=state)

    assert isinstance(result, list)
    assert state["tainted"] is True
    assert isinstance(result[1], caching.Unhashable)
def test_sanitize_signature_input_stops_descending_after_taint(caching_module, monkeypatch):
    """After the taint flag is set, sanitization must not reach deeper elements.

    A spy replaces the module-level `_sanitize_signature_input` attribute and
    records any call made with the nested sentinel. The entry call goes
    through the saved original function, so only subsequent calls routed via
    the module attribute are observed by the spy.
    """
    caching, _nodes = caching_module
    real_sanitize = caching._sanitize_signature_input
    sentinel = object()
    sentinel_visits = []

    def spy(obj, *args, **kwargs):
        """Note a visit to the sentinel, then defer to the real sanitizer."""
        if obj is sentinel:
            sentinel_visits.append(obj)
        return real_sanitize(obj, *args, **kwargs)

    monkeypatch.setattr(caching, "_sanitize_signature_input", spy)

    state = {"tainted": False}
    # The leading object() taints the run; the sentinel sits one level deeper
    # in the following element and should never be handed to the sanitizer.
    result = real_sanitize([object(), [sentinel]], taint_state=state)

    assert isinstance(result, list)
    assert state["tainted"] is True
    assert not sentinel_visits
    assert isinstance(result[1], caching.Unhashable)
def test_sanitize_signature_input_snapshots_list_before_recursing(caching_module, monkeypatch): def test_sanitize_signature_input_snapshots_list_before_recursing(caching_module, monkeypatch):
"""List sanitization should read a point-in-time snapshot before recursive descent.""" """List sanitization should read a point-in-time snapshot before recursive descent."""
caching, _ = caching_module caching, _ = caching_module
@ -241,10 +262,15 @@ def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module,
assert isinstance(hashable, caching.Unhashable) assert isinstance(hashable, caching.Unhashable)
def test_get_node_signature_sanitizes_full_signature(caching_module, monkeypatch): def test_get_node_signature_returns_top_level_unhashable_for_tainted_signature(caching_module, monkeypatch):
"""Recursive `is_changed` payloads should be sanitized inside the full node signature.""" """Tainted full signatures should fail closed before `to_hashable()` runs."""
caching, nodes_module = caching_module caching, nodes_module = caching_module
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode) monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
monkeypatch.setattr(
caching,
"to_hashable",
lambda *_args, **_kwargs: pytest.fail("to_hashable should not run for tainted signatures"),
)
is_changed_value = [] is_changed_value = []
is_changed_value.append(is_changed_value) is_changed_value.append(is_changed_value)
@ -265,5 +291,4 @@ def test_get_node_signature_sanitizes_full_signature(caching_module, monkeypatch
signature = asyncio.run(key_set.get_node_signature(dynprompt, "node")) signature = asyncio.run(key_set.get_node_signature(dynprompt, "node"))
assert signature[0] == "list" assert isinstance(signature, caching.Unhashable)
assert _contains_unhashable(signature, caching.Unhashable)