mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-22 09:33:29 +08:00
fail closed on ambiguous container ordering in cache signatures
This commit is contained in:
parent
763089f681
commit
aceaa5e579
@ -60,136 +60,239 @@ class Unhashable:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _sanitized_sort_key(obj, depth=0, max_depth=32):
|
_PRIMITIVE_SIGNATURE_TYPES = (int, float, str, bool, bytes, type(None))
|
||||||
|
_CONTAINER_SIGNATURE_TYPES = (dict, list, tuple, set, frozenset)
|
||||||
|
_MAX_SIGNATURE_DEPTH = 32
|
||||||
|
_MAX_SIGNATURE_CONTAINER_VISITS = 10_000
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None):
|
||||||
"""Return a deterministic ordering key for sanitized built-in container content."""
|
"""Return a deterministic ordering key for sanitized built-in container content."""
|
||||||
if depth >= max_depth:
|
if depth >= max_depth:
|
||||||
return ("MAX_DEPTH",)
|
return ("MAX_DEPTH",)
|
||||||
|
|
||||||
|
if active is None:
|
||||||
|
active = set()
|
||||||
|
if memo is None:
|
||||||
|
memo = {}
|
||||||
|
|
||||||
obj_type = type(obj)
|
obj_type = type(obj)
|
||||||
if obj_type is Unhashable:
|
if obj_type is Unhashable:
|
||||||
return ("UNHASHABLE",)
|
return ("UNHASHABLE",)
|
||||||
elif obj_type in (int, float, str, bool, bytes, type(None)):
|
elif obj_type in _PRIMITIVE_SIGNATURE_TYPES:
|
||||||
return (obj_type.__module__, obj_type.__qualname__, repr(obj))
|
return (obj_type.__module__, obj_type.__qualname__, repr(obj))
|
||||||
elif obj_type is dict:
|
elif obj_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||||
items = [
|
|
||||||
(
|
|
||||||
_sanitized_sort_key(k, depth + 1, max_depth),
|
|
||||||
_sanitized_sort_key(v, depth + 1, max_depth),
|
|
||||||
)
|
|
||||||
for k, v in obj.items()
|
|
||||||
]
|
|
||||||
items.sort()
|
|
||||||
return ("dict", tuple(items))
|
|
||||||
elif obj_type is list:
|
|
||||||
return ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))
|
|
||||||
elif obj_type is tuple:
|
|
||||||
return ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))
|
|
||||||
elif obj_type is set:
|
|
||||||
return ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)))
|
|
||||||
elif obj_type is frozenset:
|
|
||||||
return ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)))
|
|
||||||
else:
|
|
||||||
return (obj_type.__module__, obj_type.__qualname__, "OPAQUE")
|
return (obj_type.__module__, obj_type.__qualname__, "OPAQUE")
|
||||||
|
|
||||||
|
obj_id = id(obj)
|
||||||
|
if obj_id in memo:
|
||||||
|
return memo[obj_id]
|
||||||
|
if obj_id in active:
|
||||||
|
return ("CYCLE",)
|
||||||
|
|
||||||
def _sanitize_signature_input(obj, depth=0, max_depth=32, seen=None):
|
active.add(obj_id)
|
||||||
|
try:
|
||||||
|
if obj_type is dict:
|
||||||
|
items = [
|
||||||
|
(
|
||||||
|
_sanitized_sort_key(k, depth + 1, max_depth, active, memo),
|
||||||
|
_sanitized_sort_key(v, depth + 1, max_depth, active, memo),
|
||||||
|
)
|
||||||
|
for k, v in obj.items()
|
||||||
|
]
|
||||||
|
items.sort()
|
||||||
|
result = ("dict", tuple(items))
|
||||||
|
elif obj_type is list:
|
||||||
|
result = ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))
|
||||||
|
elif obj_type is tuple:
|
||||||
|
result = ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))
|
||||||
|
elif obj_type is set:
|
||||||
|
result = ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)))
|
||||||
|
else:
|
||||||
|
result = ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)))
|
||||||
|
finally:
|
||||||
|
active.discard(obj_id)
|
||||||
|
|
||||||
|
memo[obj_id] = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None):
|
||||||
"""Normalize signature inputs to safe built-in containers.
|
"""Normalize signature inputs to safe built-in containers.
|
||||||
|
|
||||||
Preserves built-in container type, replaces opaque runtime values with
|
Preserves built-in container type, replaces opaque runtime values with
|
||||||
Unhashable(), and stops safely on cycles or excessive depth.
|
Unhashable(), stops safely on cycles or excessive depth, and memoizes
|
||||||
|
repeated built-in substructures so shared DAG-like inputs do not explode
|
||||||
|
into repeated recursive work.
|
||||||
"""
|
"""
|
||||||
if depth >= max_depth:
|
if depth >= max_depth:
|
||||||
return Unhashable()
|
return Unhashable()
|
||||||
|
|
||||||
if seen is None:
|
if active is None:
|
||||||
seen = set()
|
active = set()
|
||||||
|
if memo is None:
|
||||||
|
memo = {}
|
||||||
|
if budget is None:
|
||||||
|
budget = {"remaining": _MAX_SIGNATURE_CONTAINER_VISITS}
|
||||||
|
|
||||||
obj_type = type(obj)
|
obj_type = type(obj)
|
||||||
if obj_type in (dict, list, tuple, set, frozenset):
|
if obj_type in _PRIMITIVE_SIGNATURE_TYPES:
|
||||||
obj_id = id(obj)
|
|
||||||
if obj_id in seen:
|
|
||||||
return Unhashable()
|
|
||||||
next_seen = seen | {obj_id}
|
|
||||||
|
|
||||||
if obj_type in (int, float, str, bool, bytes, type(None)):
|
|
||||||
return obj
|
return obj
|
||||||
elif obj_type is dict:
|
if obj_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||||
sanitized_items = [
|
|
||||||
(
|
|
||||||
_sanitize_signature_input(key, depth + 1, max_depth, next_seen),
|
|
||||||
_sanitize_signature_input(value, depth + 1, max_depth, next_seen),
|
|
||||||
)
|
|
||||||
for key, value in obj.items()
|
|
||||||
]
|
|
||||||
sanitized_items.sort(
|
|
||||||
key=lambda kv: (
|
|
||||||
_sanitized_sort_key(kv[0], depth + 1, max_depth),
|
|
||||||
_sanitized_sort_key(kv[1], depth + 1, max_depth),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return {key: value for key, value in sanitized_items}
|
|
||||||
elif obj_type is list:
|
|
||||||
return [_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj]
|
|
||||||
elif obj_type is tuple:
|
|
||||||
return tuple(_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj)
|
|
||||||
elif obj_type is set:
|
|
||||||
return {_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj}
|
|
||||||
elif obj_type is frozenset:
|
|
||||||
return frozenset(_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj)
|
|
||||||
else:
|
|
||||||
# Execution-cache signatures should be built from prompt-safe values.
|
|
||||||
# If a custom node injects a runtime object here, mark it unhashable so
|
|
||||||
# the node won't reuse stale cache entries across runs, but do not walk
|
|
||||||
# the foreign object and risk crashing on custom container semantics.
|
|
||||||
return Unhashable()
|
return Unhashable()
|
||||||
|
|
||||||
def to_hashable(obj, depth=0, max_depth=32, seen=None):
|
obj_id = id(obj)
|
||||||
|
if obj_id in memo:
|
||||||
|
return memo[obj_id]
|
||||||
|
if obj_id in active:
|
||||||
|
return Unhashable()
|
||||||
|
|
||||||
|
budget["remaining"] -= 1
|
||||||
|
if budget["remaining"] < 0:
|
||||||
|
return Unhashable()
|
||||||
|
|
||||||
|
active.add(obj_id)
|
||||||
|
try:
|
||||||
|
if obj_type is dict:
|
||||||
|
sort_memo = {}
|
||||||
|
sanitized_items = [
|
||||||
|
(
|
||||||
|
_sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget),
|
||||||
|
_sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget),
|
||||||
|
)
|
||||||
|
for key, value in obj.items()
|
||||||
|
]
|
||||||
|
ordered_items = [
|
||||||
|
(
|
||||||
|
(
|
||||||
|
_sanitized_sort_key(key, depth + 1, max_depth, memo=sort_memo),
|
||||||
|
_sanitized_sort_key(value, depth + 1, max_depth, memo=sort_memo),
|
||||||
|
),
|
||||||
|
(key, value),
|
||||||
|
)
|
||||||
|
for key, value in sanitized_items
|
||||||
|
]
|
||||||
|
ordered_items.sort(key=lambda item: item[0])
|
||||||
|
|
||||||
|
result = Unhashable()
|
||||||
|
for index in range(1, len(ordered_items)):
|
||||||
|
previous_sort_key, previous_item = ordered_items[index - 1]
|
||||||
|
current_sort_key, current_item = ordered_items[index]
|
||||||
|
if previous_sort_key == current_sort_key and previous_item != current_item:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
result = {key: value for _, (key, value) in ordered_items}
|
||||||
|
elif obj_type is list:
|
||||||
|
result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj]
|
||||||
|
elif obj_type is tuple:
|
||||||
|
result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj)
|
||||||
|
elif obj_type is set:
|
||||||
|
result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj}
|
||||||
|
else:
|
||||||
|
result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj)
|
||||||
|
finally:
|
||||||
|
active.discard(obj_id)
|
||||||
|
|
||||||
|
memo[obj_id] = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
|
||||||
"""Convert sanitized prompt inputs into a stable hashable representation.
|
"""Convert sanitized prompt inputs into a stable hashable representation.
|
||||||
|
|
||||||
Preserves built-in container type and stops safely on cycles or excessive depth.
|
The input is expected to already be sanitized to plain built-in containers,
|
||||||
|
but this function still fails safe for anything unexpected. Traversal is
|
||||||
|
iterative and memoized so shared built-in substructures do not trigger
|
||||||
|
exponential re-walks during cache-key construction.
|
||||||
"""
|
"""
|
||||||
if depth >= max_depth:
|
|
||||||
return Unhashable()
|
|
||||||
|
|
||||||
if seen is None:
|
|
||||||
seen = set()
|
|
||||||
|
|
||||||
# Restrict recursion to plain built-in containers. Some custom nodes insert
|
|
||||||
# runtime objects into prompt inputs for dynamic graph paths; walking those
|
|
||||||
# objects as generic Mappings / Sequences is unsafe and can destabilize the
|
|
||||||
# cache signature builder.
|
|
||||||
obj_type = type(obj)
|
obj_type = type(obj)
|
||||||
if obj_type in (int, float, str, bool, bytes, type(None)):
|
if obj_type in _PRIMITIVE_SIGNATURE_TYPES or obj_type is Unhashable:
|
||||||
return obj
|
return obj
|
||||||
|
if obj_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||||
if obj_type in (dict, list, tuple, set, frozenset):
|
|
||||||
obj_id = id(obj)
|
|
||||||
if obj_id in seen:
|
|
||||||
return Unhashable()
|
|
||||||
seen = seen | {obj_id}
|
|
||||||
|
|
||||||
if obj_type is dict:
|
|
||||||
return (
|
|
||||||
"dict",
|
|
||||||
frozenset(
|
|
||||||
(
|
|
||||||
to_hashable(k, depth + 1, max_depth, seen),
|
|
||||||
to_hashable(v, depth + 1, max_depth, seen),
|
|
||||||
)
|
|
||||||
for k, v in obj.items()
|
|
||||||
),
|
|
||||||
)
|
|
||||||
elif obj_type is list:
|
|
||||||
return ("list", tuple(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
|
|
||||||
elif obj_type is tuple:
|
|
||||||
return ("tuple", tuple(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
|
|
||||||
elif obj_type is set:
|
|
||||||
return ("set", frozenset(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
|
|
||||||
elif obj_type is frozenset:
|
|
||||||
return ("frozenset", frozenset(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
|
|
||||||
else:
|
|
||||||
return Unhashable()
|
return Unhashable()
|
||||||
|
|
||||||
|
memo = {}
|
||||||
|
active = set()
|
||||||
|
sort_memo = {}
|
||||||
|
processed = 0
|
||||||
|
stack = [(obj, False)]
|
||||||
|
|
||||||
|
def resolve_value(value):
|
||||||
|
"""Resolve a child value from the completed memo table when available."""
|
||||||
|
value_type = type(value)
|
||||||
|
if value_type in _PRIMITIVE_SIGNATURE_TYPES or value_type is Unhashable:
|
||||||
|
return value
|
||||||
|
return memo.get(id(value), Unhashable())
|
||||||
|
|
||||||
|
def resolve_unordered_values(current, container_tag):
|
||||||
|
"""Resolve a set-like container or fail closed if ordering is ambiguous."""
|
||||||
|
ordered_items = [
|
||||||
|
(_sanitized_sort_key(item, memo=sort_memo), resolve_value(item))
|
||||||
|
for item in current
|
||||||
|
]
|
||||||
|
ordered_items.sort(key=lambda item: item[0])
|
||||||
|
|
||||||
|
for index in range(1, len(ordered_items)):
|
||||||
|
previous_key, previous_value = ordered_items[index - 1]
|
||||||
|
current_key, current_value = ordered_items[index]
|
||||||
|
if previous_key == current_key and previous_value != current_value:
|
||||||
|
return Unhashable()
|
||||||
|
|
||||||
|
return (container_tag, tuple(value for _, value in ordered_items))
|
||||||
|
|
||||||
|
while stack:
|
||||||
|
current, expanded = stack.pop()
|
||||||
|
current_type = type(current)
|
||||||
|
|
||||||
|
if current_type in _PRIMITIVE_SIGNATURE_TYPES or current_type is Unhashable:
|
||||||
|
continue
|
||||||
|
if current_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||||
|
memo[id(current)] = Unhashable()
|
||||||
|
continue
|
||||||
|
|
||||||
|
current_id = id(current)
|
||||||
|
if current_id in memo:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if expanded:
|
||||||
|
active.discard(current_id)
|
||||||
|
if current_type is dict:
|
||||||
|
memo[current_id] = (
|
||||||
|
"dict",
|
||||||
|
tuple((resolve_value(k), resolve_value(v)) for k, v in current.items()),
|
||||||
|
)
|
||||||
|
elif current_type is list:
|
||||||
|
memo[current_id] = ("list", tuple(resolve_value(item) for item in current))
|
||||||
|
elif current_type is tuple:
|
||||||
|
memo[current_id] = ("tuple", tuple(resolve_value(item) for item in current))
|
||||||
|
elif current_type is set:
|
||||||
|
memo[current_id] = resolve_unordered_values(current, "set")
|
||||||
|
else:
|
||||||
|
memo[current_id] = resolve_unordered_values(current, "frozenset")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if current_id in active:
|
||||||
|
memo[current_id] = Unhashable()
|
||||||
|
continue
|
||||||
|
|
||||||
|
processed += 1
|
||||||
|
if processed > max_nodes:
|
||||||
|
return Unhashable()
|
||||||
|
|
||||||
|
active.add(current_id)
|
||||||
|
stack.append((current, True))
|
||||||
|
if current_type is dict:
|
||||||
|
items = list(current.items())
|
||||||
|
for key, value in reversed(items):
|
||||||
|
stack.append((value, False))
|
||||||
|
stack.append((key, False))
|
||||||
|
else:
|
||||||
|
items = list(current)
|
||||||
|
for item in reversed(items):
|
||||||
|
stack.append((item, False))
|
||||||
|
|
||||||
|
return memo.get(id(obj), Unhashable())
|
||||||
|
|
||||||
class CacheKeySetID(CacheKeySet):
|
class CacheKeySetID(CacheKeySet):
|
||||||
"""Cache-key strategy that keys nodes by node id and class type."""
|
"""Cache-key strategy that keys nodes by node id and class type."""
|
||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
@ -238,6 +341,7 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||||
for ancestor_id in ancestors:
|
for ancestor_id in ancestors:
|
||||||
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||||
|
signature = _sanitize_signature_input(signature)
|
||||||
return to_hashable(signature)
|
return to_hashable(signature)
|
||||||
|
|
||||||
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||||
|
|||||||
176
tests-unit/execution_test/caching_test.py
Normal file
176
tests-unit/execution_test/caching_test.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
"""Unit tests for cache-signature sanitization and hash conversion hardening."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyNode:
|
||||||
|
"""Minimal node stub used to satisfy cache-signature class lookups."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def INPUT_TYPES():
|
||||||
|
"""Return a minimal empty input schema for unit tests."""
|
||||||
|
return {"required": {}}
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeDynPrompt:
|
||||||
|
"""Small DynamicPrompt stand-in with only the methods these tests need."""
|
||||||
|
|
||||||
|
def __init__(self, nodes_by_id):
|
||||||
|
"""Store test nodes by id."""
|
||||||
|
self._nodes_by_id = nodes_by_id
|
||||||
|
|
||||||
|
def has_node(self, node_id):
|
||||||
|
"""Return whether the fake prompt contains the requested node."""
|
||||||
|
return node_id in self._nodes_by_id
|
||||||
|
|
||||||
|
def get_node(self, node_id):
|
||||||
|
"""Return the stored node payload for the requested id."""
|
||||||
|
return self._nodes_by_id[node_id]
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeIsChangedCache:
|
||||||
|
"""Async stub for `is_changed` lookups used by cache-key generation."""
|
||||||
|
|
||||||
|
def __init__(self, values):
|
||||||
|
"""Store canned `is_changed` responses keyed by node id."""
|
||||||
|
self._values = values
|
||||||
|
|
||||||
|
async def get(self, node_id):
|
||||||
|
"""Return the canned `is_changed` value for a node."""
|
||||||
|
return self._values[node_id]
|
||||||
|
|
||||||
|
|
||||||
|
class _OpaqueValue:
|
||||||
|
"""Hashable opaque object used to exercise fail-closed unordered hashing paths."""
|
||||||
|
|
||||||
|
|
||||||
|
def _contains_unhashable(value, unhashable_type):
|
||||||
|
"""Return whether a nested built-in structure contains an Unhashable sentinel."""
|
||||||
|
if isinstance(value, unhashable_type):
|
||||||
|
return True
|
||||||
|
|
||||||
|
value_type = type(value)
|
||||||
|
if value_type is dict:
|
||||||
|
return any(
|
||||||
|
_contains_unhashable(key, unhashable_type) or _contains_unhashable(item, unhashable_type)
|
||||||
|
for key, item in value.items()
|
||||||
|
)
|
||||||
|
if value_type in (list, tuple, set, frozenset):
|
||||||
|
return any(_contains_unhashable(item, unhashable_type) for item in value)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def caching_module(monkeypatch):
|
||||||
|
"""Import `comfy_execution.caching` with lightweight stub dependencies."""
|
||||||
|
torch_module = types.ModuleType("torch")
|
||||||
|
psutil_module = types.ModuleType("psutil")
|
||||||
|
nodes_module = types.ModuleType("nodes")
|
||||||
|
nodes_module.NODE_CLASS_MAPPINGS = {}
|
||||||
|
graph_module = types.ModuleType("comfy_execution.graph")
|
||||||
|
|
||||||
|
class DynamicPrompt:
|
||||||
|
"""Placeholder graph type so the caching module can import cleanly."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
graph_module.DynamicPrompt = DynamicPrompt
|
||||||
|
|
||||||
|
monkeypatch.setitem(sys.modules, "torch", torch_module)
|
||||||
|
monkeypatch.setitem(sys.modules, "psutil", psutil_module)
|
||||||
|
monkeypatch.setitem(sys.modules, "nodes", nodes_module)
|
||||||
|
monkeypatch.setitem(sys.modules, "comfy_execution.graph", graph_module)
|
||||||
|
monkeypatch.delitem(sys.modules, "comfy_execution.caching", raising=False)
|
||||||
|
|
||||||
|
module = importlib.import_module("comfy_execution.caching")
|
||||||
|
module = importlib.reload(module)
|
||||||
|
return module, nodes_module
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_signature_input_handles_shared_builtin_substructures(caching_module):
|
||||||
|
"""Shared built-in substructures should sanitize without collapsing to Unhashable."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
shared = [{"value": 1}, {"value": 2}]
|
||||||
|
|
||||||
|
sanitized = caching._sanitize_signature_input([shared, shared])
|
||||||
|
|
||||||
|
assert isinstance(sanitized, list)
|
||||||
|
assert sanitized[0] == sanitized[1]
|
||||||
|
assert sanitized[0][0]["value"] == 1
|
||||||
|
assert sanitized[0][1]["value"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_hashable_handles_shared_builtin_substructures(caching_module):
|
||||||
|
"""Repeated sanitized content should hash stably for shared substructures."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
shared = [{"value": 1}, {"value": 2}]
|
||||||
|
|
||||||
|
sanitized = caching._sanitize_signature_input([shared, shared])
|
||||||
|
hashable = caching.to_hashable(sanitized)
|
||||||
|
|
||||||
|
assert hashable[0] == "list"
|
||||||
|
assert hashable[1][0] == hashable[1][1]
|
||||||
|
assert hashable[1][0][0] == "list"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_signature_input_fails_closed_for_ambiguous_dict_ordering(caching_module):
|
||||||
|
"""Ambiguous dict sort ties should fail closed instead of depending on input order."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
ambiguous = {
|
||||||
|
_OpaqueValue(): _OpaqueValue(),
|
||||||
|
_OpaqueValue(): _OpaqueValue(),
|
||||||
|
}
|
||||||
|
|
||||||
|
sanitized = caching._sanitize_signature_input(ambiguous)
|
||||||
|
|
||||||
|
assert isinstance(sanitized, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"container_factory",
|
||||||
|
[
|
||||||
|
set,
|
||||||
|
frozenset,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module, container_factory):
|
||||||
|
"""Ambiguous unordered values should fail closed instead of depending on iteration order."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
container = container_factory({_OpaqueValue(), _OpaqueValue()})
|
||||||
|
|
||||||
|
hashable = caching.to_hashable(container)
|
||||||
|
|
||||||
|
assert isinstance(hashable, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_node_signature_sanitizes_full_signature(caching_module, monkeypatch):
|
||||||
|
"""Recursive `is_changed` payloads should be sanitized inside the full node signature."""
|
||||||
|
caching, nodes_module = caching_module
|
||||||
|
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
|
||||||
|
|
||||||
|
is_changed_value = []
|
||||||
|
is_changed_value.append(is_changed_value)
|
||||||
|
|
||||||
|
dynprompt = _FakeDynPrompt(
|
||||||
|
{
|
||||||
|
"node": {
|
||||||
|
"class_type": "UnitTestNode",
|
||||||
|
"inputs": {"value": 5},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
key_set = caching.CacheKeySetInputSignature(
|
||||||
|
dynprompt,
|
||||||
|
["node"],
|
||||||
|
_FakeIsChangedCache({"node": is_changed_value}),
|
||||||
|
)
|
||||||
|
|
||||||
|
signature = asyncio.run(key_set.get_node_signature(dynprompt, "node"))
|
||||||
|
|
||||||
|
assert signature[0] == "list"
|
||||||
|
assert _contains_unhashable(signature, caching.Unhashable)
|
||||||
Loading…
Reference in New Issue
Block a user