mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-22 09:33:29 +08:00
Change signature cache to bail early
This commit is contained in:
parent
fadd79ad48
commit
9feb26928c
@ -66,6 +66,12 @@ _MAX_SIGNATURE_DEPTH = 32
|
|||||||
_MAX_SIGNATURE_CONTAINER_VISITS = 10_000
|
_MAX_SIGNATURE_CONTAINER_VISITS = 10_000
|
||||||
|
|
||||||
|
|
||||||
|
def _mark_signature_tainted(taint_state):
|
||||||
|
"""Record that signature sanitization hit a fail-closed condition."""
|
||||||
|
if taint_state is not None:
|
||||||
|
taint_state["tainted"] = True
|
||||||
|
|
||||||
|
|
||||||
def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None):
|
def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None):
|
||||||
"""Return a deterministic ordering key for sanitized built-in container content."""
|
"""Return a deterministic ordering key for sanitized built-in container content."""
|
||||||
if depth >= max_depth:
|
if depth >= max_depth:
|
||||||
@ -117,15 +123,20 @@ def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=Non
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None):
|
def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None, taint_state=None):
|
||||||
"""Normalize signature inputs to safe built-in containers.
|
"""Normalize signature inputs to safe built-in containers.
|
||||||
|
|
||||||
Preserves built-in container type, replaces opaque runtime values with
|
Preserves built-in container type, replaces opaque runtime values with
|
||||||
Unhashable(), stops safely on cycles or excessive depth, and memoizes
|
Unhashable(), stops safely on cycles or excessive depth, memoizes repeated
|
||||||
repeated built-in substructures so shared DAG-like inputs do not explode
|
built-in substructures so shared DAG-like inputs do not explode into
|
||||||
into repeated recursive work.
|
repeated recursive work, and optionally records when sanitization had to
|
||||||
|
fail closed anywhere in the traversed structure.
|
||||||
"""
|
"""
|
||||||
|
if taint_state is not None and taint_state.get("tainted"):
|
||||||
|
return Unhashable()
|
||||||
|
|
||||||
if depth >= max_depth:
|
if depth >= max_depth:
|
||||||
|
_mark_signature_tainted(taint_state)
|
||||||
return Unhashable()
|
return Unhashable()
|
||||||
|
|
||||||
if active is None:
|
if active is None:
|
||||||
@ -139,16 +150,19 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti
|
|||||||
if obj_type in _PRIMITIVE_SIGNATURE_TYPES:
|
if obj_type in _PRIMITIVE_SIGNATURE_TYPES:
|
||||||
return obj
|
return obj
|
||||||
if obj_type not in _CONTAINER_SIGNATURE_TYPES:
|
if obj_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||||
|
_mark_signature_tainted(taint_state)
|
||||||
return Unhashable()
|
return Unhashable()
|
||||||
|
|
||||||
obj_id = id(obj)
|
obj_id = id(obj)
|
||||||
if obj_id in memo:
|
if obj_id in memo:
|
||||||
return memo[obj_id]
|
return memo[obj_id]
|
||||||
if obj_id in active:
|
if obj_id in active:
|
||||||
|
_mark_signature_tainted(taint_state)
|
||||||
return Unhashable()
|
return Unhashable()
|
||||||
|
|
||||||
budget["remaining"] -= 1
|
budget["remaining"] -= 1
|
||||||
if budget["remaining"] < 0:
|
if budget["remaining"] < 0:
|
||||||
|
_mark_signature_tainted(taint_state)
|
||||||
return Unhashable()
|
return Unhashable()
|
||||||
|
|
||||||
active.add(obj_id)
|
active.add(obj_id)
|
||||||
@ -159,8 +173,8 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti
|
|||||||
sort_memo = {}
|
sort_memo = {}
|
||||||
sanitized_items = [
|
sanitized_items = [
|
||||||
(
|
(
|
||||||
_sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget),
|
_sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget, taint_state),
|
||||||
_sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget),
|
_sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget, taint_state),
|
||||||
)
|
)
|
||||||
for key, value in items
|
for key, value in items
|
||||||
]
|
]
|
||||||
@ -181,34 +195,40 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti
|
|||||||
previous_sort_key, previous_item = ordered_items[index - 1]
|
previous_sort_key, previous_item = ordered_items[index - 1]
|
||||||
current_sort_key, current_item = ordered_items[index]
|
current_sort_key, current_item = ordered_items[index]
|
||||||
if previous_sort_key == current_sort_key and previous_item != current_item:
|
if previous_sort_key == current_sort_key and previous_item != current_item:
|
||||||
|
_mark_signature_tainted(taint_state)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
result = {key: value for _, (key, value) in ordered_items}
|
result = {key: value for _, (key, value) in ordered_items}
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
|
_mark_signature_tainted(taint_state)
|
||||||
result = Unhashable()
|
result = Unhashable()
|
||||||
elif obj_type is list:
|
elif obj_type is list:
|
||||||
try:
|
try:
|
||||||
items = list(obj)
|
items = list(obj)
|
||||||
result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items]
|
result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items]
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
|
_mark_signature_tainted(taint_state)
|
||||||
result = Unhashable()
|
result = Unhashable()
|
||||||
elif obj_type is tuple:
|
elif obj_type is tuple:
|
||||||
try:
|
try:
|
||||||
items = list(obj)
|
items = list(obj)
|
||||||
result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items)
|
result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
|
_mark_signature_tainted(taint_state)
|
||||||
result = Unhashable()
|
result = Unhashable()
|
||||||
elif obj_type is set:
|
elif obj_type is set:
|
||||||
try:
|
try:
|
||||||
items = list(obj)
|
items = list(obj)
|
||||||
result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items}
|
result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items}
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
|
_mark_signature_tainted(taint_state)
|
||||||
result = Unhashable()
|
result = Unhashable()
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
items = list(obj)
|
items = list(obj)
|
||||||
result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items)
|
result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget, taint_state) for item in items)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
|
_mark_signature_tainted(taint_state)
|
||||||
result = Unhashable()
|
result = Unhashable()
|
||||||
finally:
|
finally:
|
||||||
active.discard(obj_id)
|
active.discard(obj_id)
|
||||||
@ -377,7 +397,10 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||||
for ancestor_id in ancestors:
|
for ancestor_id in ancestors:
|
||||||
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||||
signature = _sanitize_signature_input(signature)
|
taint_state = {"tainted": False}
|
||||||
|
signature = _sanitize_signature_input(signature, taint_state=taint_state)
|
||||||
|
if taint_state["tainted"]:
|
||||||
|
return Unhashable()
|
||||||
return to_hashable(signature)
|
return to_hashable(signature)
|
||||||
|
|
||||||
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||||
|
|||||||
@ -49,22 +49,6 @@ class _OpaqueValue:
|
|||||||
"""Hashable opaque object used to exercise fail-closed unordered hashing paths."""
|
"""Hashable opaque object used to exercise fail-closed unordered hashing paths."""
|
||||||
|
|
||||||
|
|
||||||
def _contains_unhashable(value, unhashable_type):
|
|
||||||
"""Return whether a nested built-in structure contains an Unhashable sentinel."""
|
|
||||||
if isinstance(value, unhashable_type):
|
|
||||||
return True
|
|
||||||
|
|
||||||
value_type = type(value)
|
|
||||||
if value_type is dict:
|
|
||||||
return any(
|
|
||||||
_contains_unhashable(key, unhashable_type) or _contains_unhashable(item, unhashable_type)
|
|
||||||
for key, item in value.items()
|
|
||||||
)
|
|
||||||
if value_type in (list, tuple, set, frozenset):
|
|
||||||
return any(_contains_unhashable(item, unhashable_type) for item in value)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def caching_module(monkeypatch):
|
def caching_module(monkeypatch):
|
||||||
"""Import `comfy_execution.caching` with lightweight stub dependencies."""
|
"""Import `comfy_execution.caching` with lightweight stub dependencies."""
|
||||||
@ -105,6 +89,43 @@ def test_sanitize_signature_input_handles_shared_builtin_substructures(caching_m
|
|||||||
assert sanitized[0][1]["value"] == 2
|
assert sanitized[0][1]["value"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_signature_input_marks_tainted_on_opaque_values(caching_module):
|
||||||
|
"""Opaque values should mark the containing signature as tainted."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
taint_state = {"tainted": False}
|
||||||
|
|
||||||
|
sanitized = caching._sanitize_signature_input(["safe", object()], taint_state=taint_state)
|
||||||
|
|
||||||
|
assert isinstance(sanitized, list)
|
||||||
|
assert taint_state["tainted"] is True
|
||||||
|
assert isinstance(sanitized[1], caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_signature_input_stops_descending_after_taint(caching_module, monkeypatch):
|
||||||
|
"""Once tainted, later recursive calls should return immediately without deeper descent."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
original = caching._sanitize_signature_input
|
||||||
|
marker = object()
|
||||||
|
marker_seen = False
|
||||||
|
|
||||||
|
def tracking_sanitize(obj, *args, **kwargs):
|
||||||
|
"""Track whether recursion reaches the nested marker after tainting."""
|
||||||
|
nonlocal marker_seen
|
||||||
|
if obj is marker:
|
||||||
|
marker_seen = True
|
||||||
|
return original(obj, *args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(caching, "_sanitize_signature_input", tracking_sanitize)
|
||||||
|
|
||||||
|
taint_state = {"tainted": False}
|
||||||
|
sanitized = original([object(), [marker]], taint_state=taint_state)
|
||||||
|
|
||||||
|
assert isinstance(sanitized, list)
|
||||||
|
assert taint_state["tainted"] is True
|
||||||
|
assert marker_seen is False
|
||||||
|
assert isinstance(sanitized[1], caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_signature_input_snapshots_list_before_recursing(caching_module, monkeypatch):
|
def test_sanitize_signature_input_snapshots_list_before_recursing(caching_module, monkeypatch):
|
||||||
"""List sanitization should read a point-in-time snapshot before recursive descent."""
|
"""List sanitization should read a point-in-time snapshot before recursive descent."""
|
||||||
caching, _ = caching_module
|
caching, _ = caching_module
|
||||||
@ -241,10 +262,15 @@ def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module,
|
|||||||
assert isinstance(hashable, caching.Unhashable)
|
assert isinstance(hashable, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
def test_get_node_signature_sanitizes_full_signature(caching_module, monkeypatch):
|
def test_get_node_signature_returns_top_level_unhashable_for_tainted_signature(caching_module, monkeypatch):
|
||||||
"""Recursive `is_changed` payloads should be sanitized inside the full node signature."""
|
"""Tainted full signatures should fail closed before `to_hashable()` runs."""
|
||||||
caching, nodes_module = caching_module
|
caching, nodes_module = caching_module
|
||||||
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
|
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
caching,
|
||||||
|
"to_hashable",
|
||||||
|
lambda *_args, **_kwargs: pytest.fail("to_hashable should not run for tainted signatures"),
|
||||||
|
)
|
||||||
|
|
||||||
is_changed_value = []
|
is_changed_value = []
|
||||||
is_changed_value.append(is_changed_value)
|
is_changed_value.append(is_changed_value)
|
||||||
@ -265,5 +291,4 @@ def test_get_node_signature_sanitizes_full_signature(caching_module, monkeypatch
|
|||||||
|
|
||||||
signature = asyncio.run(key_set.get_node_signature(dynprompt, "node"))
|
signature = asyncio.run(key_set.get_node_signature(dynprompt, "node"))
|
||||||
|
|
||||||
assert signature[0] == "list"
|
assert isinstance(signature, caching.Unhashable)
|
||||||
assert _contains_unhashable(signature, caching.Unhashable)
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user