mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-16 22:58:19 +08:00
Change signature cache to bail early
This commit is contained in:
parent
fadd79ad48
commit
9feb26928c
@ -66,6 +66,12 @@ _MAX_SIGNATURE_DEPTH = 32
|
||||
_MAX_SIGNATURE_CONTAINER_VISITS = 10_000
|
||||
|
||||
|
||||
def _mark_signature_tainted(taint_state):
|
||||
"""Record that signature sanitization hit a fail-closed condition."""
|
||||
if taint_state is not None:
|
||||
taint_state["tainted"] = True
|
||||
|
||||
|
||||
def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None):
|
||||
"""Return a deterministic ordering key for sanitized built-in container content."""
|
||||
if depth >= max_depth:
|
||||
@ -117,15 +123,20 @@ def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=Non
|
||||
return result
|
||||
|
||||
|
||||
def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None):
|
||||
def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None, taint_state=None):
|
||||
"""Normalize signature inputs to safe built-in containers.
|
||||
|
||||
Preserves built-in container type, replaces opaque runtime values with
|
||||
Unhashable(), stops safely on cycles or excessive depth, and memoizes
|
||||
repeated built-in substructures so shared DAG-like inputs do not explode
|
||||
into repeated recursive work.
|
||||
Unhashable(), stops safely on cycles or excessive depth, memoizes repeated
|
||||
built-in substructures so shared DAG-like inputs do not explode into
|
||||
repeated recursive work, and optionally records when sanitization had to
|
||||
fail closed anywhere in the traversed structure.
|
||||
"""
|
||||
if taint_state is not None and taint_state.get("tainted"):
|
||||
return Unhashable()
|
||||
|
||||
if depth >= max_depth:
|
||||
_mark_signature_tainted(taint_state)
|
||||
return Unhashable()
|
||||
|
||||
if active is None:
|
||||
@ -139,16 +150,19 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti
|
||||
if obj_type in _PRIMITIVE_SIGNATURE_TYPES:
|
||||
return obj
|
||||
if obj_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||
_mark_signature_tainted(taint_state)
|
||||
return Unhashable()
|
||||
|
||||
obj_id = id(obj)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id]
|
||||
if obj_id in active:
|
||||
_mark_signature_tainted(taint_state)
|
||||
return Unhashable()
|
||||
|
||||
budget["remaining"] -= 1
|
||||
if budget["remaining"] < 0:
|
||||
_mark_signature_tainted(taint_state)
|
||||
return Unhashable()
|
||||
|
||||
active.add(obj_id)
|
||||
@ -159,8 +173,8 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti
|
||||
sort_memo = {}
|
||||
sanitized_items = [
|
||||
(
|
||||
_sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget),
|
||||
_sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget),
|
||||
_sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget, taint_state),
|
||||
_sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget, taint_state),
|
||||
)
|
||||
for key, value in items
|
||||
]
|
||||
@ -181,34 +195,40 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti
|
||||
previous_sort_key, previous_item = ordered_items[index - 1]
|
||||
current_sort_key, current_item = ordered_items[index]
|
||||
if previous_sort_key == current_sort_key and previous_item != current_item:
|
||||
_mark_signature_tainted(taint_state)
|
||||
break
|
||||
else:
|
||||
result = {key: value for _, (key, value) in ordered_items}
|
||||
except RuntimeError:
|
||||
_mark_signature_tainted(taint_state)
|
||||
result = Unhashable()
|
||||
elif obj_type is list:
|
||||
try:
|
||||
items = list(obj)
|
||||
result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items]
|
||||
result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items]
|
||||
except RuntimeError:
|
||||
_mark_signature_tainted(taint_state)
|
||||
result = Unhashable()
|
||||
elif obj_type is tuple:
|
||||
try:
|
||||
items = list(obj)
|
||||
result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items)
|
||||
result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items)
|
||||
except RuntimeError:
|
||||
_mark_signature_tainted(taint_state)
|
||||
result = Unhashable()
|
||||
elif obj_type is set:
|
||||
try:
|
||||
items = list(obj)
|
||||
result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items}
|
||||
result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items}
|
||||
except RuntimeError:
|
||||
_mark_signature_tainted(taint_state)
|
||||
result = Unhashable()
|
||||
else:
|
||||
try:
|
||||
items = list(obj)
|
||||
result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items)
|
||||
result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items)
|
||||
except RuntimeError:
|
||||
_mark_signature_tainted(taint_state)
|
||||
result = Unhashable()
|
||||
finally:
|
||||
active.discard(obj_id)
|
||||
@ -377,7 +397,10 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
for ancestor_id in ancestors:
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
signature = _sanitize_signature_input(signature)
|
||||
taint_state = {"tainted": False}
|
||||
signature = _sanitize_signature_input(signature, taint_state=taint_state)
|
||||
if taint_state["tainted"]:
|
||||
return Unhashable()
|
||||
return to_hashable(signature)
|
||||
|
||||
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
|
||||
@ -49,22 +49,6 @@ class _OpaqueValue:
|
||||
"""Hashable opaque object used to exercise fail-closed unordered hashing paths."""
|
||||
|
||||
|
||||
def _contains_unhashable(value, unhashable_type):
|
||||
"""Return whether a nested built-in structure contains an Unhashable sentinel."""
|
||||
if isinstance(value, unhashable_type):
|
||||
return True
|
||||
|
||||
value_type = type(value)
|
||||
if value_type is dict:
|
||||
return any(
|
||||
_contains_unhashable(key, unhashable_type) or _contains_unhashable(item, unhashable_type)
|
||||
for key, item in value.items()
|
||||
)
|
||||
if value_type in (list, tuple, set, frozenset):
|
||||
return any(_contains_unhashable(item, unhashable_type) for item in value)
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def caching_module(monkeypatch):
|
||||
"""Import `comfy_execution.caching` with lightweight stub dependencies."""
|
||||
@ -105,6 +89,43 @@ def test_sanitize_signature_input_handles_shared_builtin_substructures(caching_m
|
||||
assert sanitized[0][1]["value"] == 2
|
||||
|
||||
|
||||
def test_sanitize_signature_input_marks_tainted_on_opaque_values(caching_module):
|
||||
"""Opaque values should mark the containing signature as tainted."""
|
||||
caching, _ = caching_module
|
||||
taint_state = {"tainted": False}
|
||||
|
||||
sanitized = caching._sanitize_signature_input(["safe", object()], taint_state=taint_state)
|
||||
|
||||
assert isinstance(sanitized, list)
|
||||
assert taint_state["tainted"] is True
|
||||
assert isinstance(sanitized[1], caching.Unhashable)
|
||||
|
||||
|
||||
def test_sanitize_signature_input_stops_descending_after_taint(caching_module, monkeypatch):
|
||||
"""Once tainted, later recursive calls should return immediately without deeper descent."""
|
||||
caching, _ = caching_module
|
||||
original = caching._sanitize_signature_input
|
||||
marker = object()
|
||||
marker_seen = False
|
||||
|
||||
def tracking_sanitize(obj, *args, **kwargs):
|
||||
"""Track whether recursion reaches the nested marker after tainting."""
|
||||
nonlocal marker_seen
|
||||
if obj is marker:
|
||||
marker_seen = True
|
||||
return original(obj, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(caching, "_sanitize_signature_input", tracking_sanitize)
|
||||
|
||||
taint_state = {"tainted": False}
|
||||
sanitized = original([object(), [marker]], taint_state=taint_state)
|
||||
|
||||
assert isinstance(sanitized, list)
|
||||
assert taint_state["tainted"] is True
|
||||
assert marker_seen is False
|
||||
assert isinstance(sanitized[1], caching.Unhashable)
|
||||
|
||||
|
||||
def test_sanitize_signature_input_snapshots_list_before_recursing(caching_module, monkeypatch):
|
||||
"""List sanitization should read a point-in-time snapshot before recursive descent."""
|
||||
caching, _ = caching_module
|
||||
@ -241,10 +262,15 @@ def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module,
|
||||
assert isinstance(hashable, caching.Unhashable)
|
||||
|
||||
|
||||
def test_get_node_signature_sanitizes_full_signature(caching_module, monkeypatch):
|
||||
"""Recursive `is_changed` payloads should be sanitized inside the full node signature."""
|
||||
def test_get_node_signature_returns_top_level_unhashable_for_tainted_signature(caching_module, monkeypatch):
|
||||
"""Tainted full signatures should fail closed before `to_hashable()` runs."""
|
||||
caching, nodes_module = caching_module
|
||||
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
|
||||
monkeypatch.setattr(
|
||||
caching,
|
||||
"to_hashable",
|
||||
lambda *_args, **_kwargs: pytest.fail("to_hashable should not run for tainted signatures"),
|
||||
)
|
||||
|
||||
is_changed_value = []
|
||||
is_changed_value.append(is_changed_value)
|
||||
@ -265,5 +291,4 @@ def test_get_node_signature_sanitizes_full_signature(caching_module, monkeypatch
|
||||
|
||||
signature = asyncio.run(key_set.get_node_signature(dynprompt, "node"))
|
||||
|
||||
assert signature[0] == "list"
|
||||
assert _contains_unhashable(signature, caching.Unhashable)
|
||||
assert isinstance(signature, caching.Unhashable)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user