mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-20 16:43:45 +08:00
Compare commits
4 Commits
98a7e14aa4
...
7062be136d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7062be136d | ||
|
|
088778c35d | ||
|
|
4c5f82971e | ||
|
|
f1d91a4c8c |
@ -67,6 +67,22 @@ _MAX_SIGNATURE_CONTAINER_VISITS = 10_000
|
||||
_FAILED_SIGNATURE = object()
|
||||
|
||||
|
||||
def _shallow_is_changed_signature(value):
|
||||
"""Sanitize execution-time `is_changed` values without deep recursion."""
|
||||
value_type = type(value)
|
||||
if value_type in _PRIMITIVE_SIGNATURE_TYPES:
|
||||
return value
|
||||
if value_type is list or value_type is tuple:
|
||||
try:
|
||||
items = tuple(value)
|
||||
except RuntimeError:
|
||||
return Unhashable()
|
||||
if all(type(item) in _PRIMITIVE_SIGNATURE_TYPES for item in items):
|
||||
container_tag = "is_changed_list" if value_type is list else "is_changed_tuple"
|
||||
return (container_tag, items)
|
||||
return Unhashable()
|
||||
|
||||
|
||||
def _primitive_signature_sort_key(obj):
|
||||
"""Return a deterministic ordering key for primitive signature values."""
|
||||
obj_type = type(obj)
|
||||
@ -230,7 +246,10 @@ def _signature_to_hashable_impl(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, ac
|
||||
|
||||
def _signature_to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
|
||||
"""Build the final cache-signature representation in one fail-closed pass."""
|
||||
result = _signature_to_hashable_impl(obj, budget={"remaining": max_nodes})
|
||||
try:
|
||||
result = _signature_to_hashable_impl(obj, budget={"remaining": max_nodes})
|
||||
except RuntimeError:
|
||||
return Unhashable()
|
||||
if result is _FAILED_SIGNATURE:
|
||||
return Unhashable()
|
||||
return result[0]
|
||||
@ -264,6 +283,10 @@ def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
|
||||
return value
|
||||
return memo.get(id(value), Unhashable())
|
||||
|
||||
def is_failed(value):
|
||||
"""Return whether a resolved child value represents failed canonicalization."""
|
||||
return type(value) is Unhashable
|
||||
|
||||
def resolve_unordered_values(current_items, container_tag):
|
||||
"""Resolve a set-like container or fail closed if ordering is ambiguous."""
|
||||
try:
|
||||
@ -271,6 +294,8 @@ def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
|
||||
(_sanitized_sort_key(item, memo=sort_memo), resolve_value(item))
|
||||
for item in current_items
|
||||
]
|
||||
if any(is_failed(value) for _, value in ordered_items):
|
||||
return Unhashable()
|
||||
ordered_items.sort(key=lambda item: item[0])
|
||||
except RuntimeError:
|
||||
return Unhashable()
|
||||
@ -304,20 +329,41 @@ def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
|
||||
items = snapshots.pop(current_id, None)
|
||||
if items is None:
|
||||
items = list(current.items())
|
||||
memo[current_id] = (
|
||||
"dict",
|
||||
tuple((resolve_value(k), resolve_value(v)) for k, v in items),
|
||||
)
|
||||
ordered_items = [
|
||||
(_sanitized_sort_key(k, memo=sort_memo), resolve_value(k), resolve_value(v))
|
||||
for k, v in items
|
||||
]
|
||||
if any(is_failed(key) or is_failed(value) for _, key, value in ordered_items):
|
||||
memo[current_id] = Unhashable()
|
||||
continue
|
||||
ordered_items.sort(key=lambda item: item[0])
|
||||
for index in range(1, len(ordered_items)):
|
||||
if ordered_items[index - 1][0] == ordered_items[index][0]:
|
||||
memo[current_id] = Unhashable()
|
||||
break
|
||||
else:
|
||||
memo[current_id] = (
|
||||
"dict",
|
||||
tuple((key, value) for _, key, value in ordered_items),
|
||||
)
|
||||
elif current_type is list:
|
||||
items = snapshots.pop(current_id, None)
|
||||
if items is None:
|
||||
items = list(current)
|
||||
memo[current_id] = ("list", tuple(resolve_value(item) for item in items))
|
||||
resolved_items = tuple(resolve_value(item) for item in items)
|
||||
if any(is_failed(item) for item in resolved_items):
|
||||
memo[current_id] = Unhashable()
|
||||
else:
|
||||
memo[current_id] = ("list", resolved_items)
|
||||
elif current_type is tuple:
|
||||
items = snapshots.pop(current_id, None)
|
||||
if items is None:
|
||||
items = list(current)
|
||||
memo[current_id] = ("tuple", tuple(resolve_value(item) for item in items))
|
||||
resolved_items = tuple(resolve_value(item) for item in items)
|
||||
if any(is_failed(item) for item in resolved_items):
|
||||
memo[current_id] = Unhashable()
|
||||
else:
|
||||
memo[current_id] = ("tuple", resolved_items)
|
||||
elif current_type is set:
|
||||
items = snapshots.pop(current_id, None)
|
||||
if items is None:
|
||||
@ -429,7 +475,7 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
node = dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
signature = [class_type, await self.is_changed_cache.get(node_id)]
|
||||
signature = [class_type, _shallow_is_changed_signature(await self.is_changed_cache.get(node_id))]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
inputs = node["inputs"]
|
||||
|
||||
@ -205,6 +205,26 @@ def test_to_hashable_handles_shared_builtin_substructures(caching_module):
|
||||
assert hashable[1][0][0] == "list"
|
||||
|
||||
|
||||
def test_to_hashable_fails_closed_for_ordered_container_with_opaque_child(caching_module):
|
||||
"""Ordered containers should fail closed when a child cannot be canonicalized."""
|
||||
caching, _ = caching_module
|
||||
|
||||
result = caching.to_hashable([object()])
|
||||
|
||||
assert isinstance(result, caching.Unhashable)
|
||||
|
||||
|
||||
def test_to_hashable_canonicalizes_dict_insertion_order(caching_module):
|
||||
"""Dicts with the same content should hash identically regardless of insertion order."""
|
||||
caching, _ = caching_module
|
||||
|
||||
first = {"b": 2, "a": 1}
|
||||
second = {"a": 1, "b": 2}
|
||||
|
||||
assert caching.to_hashable(first) == ("dict", (("a", 1), ("b", 2)))
|
||||
assert caching.to_hashable(first) == caching.to_hashable(second)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"container_factory",
|
||||
[
|
||||
@ -227,6 +247,19 @@ def test_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, c
|
||||
assert isinstance(hashable, caching.Unhashable)
|
||||
|
||||
|
||||
def test_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module):
|
||||
"""Ambiguous dict key ordering should fail closed instead of using insertion order."""
|
||||
caching, _ = caching_module
|
||||
ambiguous = {
|
||||
_OpaqueValue(): 1,
|
||||
_OpaqueValue(): 2,
|
||||
}
|
||||
|
||||
hashable = caching.to_hashable(ambiguous)
|
||||
|
||||
assert isinstance(hashable, caching.Unhashable)
|
||||
|
||||
|
||||
def test_signature_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module):
|
||||
"""Ambiguous dict sort ties should fail closed instead of depending on input order."""
|
||||
caching, _ = caching_module
|
||||
@ -309,3 +342,47 @@ def test_get_node_signature_returns_top_level_unhashable_for_tainted_signature(c
|
||||
signature = asyncio.run(key_set.get_node_signature(dynprompt, "node"))
|
||||
|
||||
assert isinstance(signature, caching.Unhashable)
|
||||
|
||||
|
||||
def test_shallow_is_changed_signature_accepts_primitive_lists(caching_module):
|
||||
"""Primitive-only `is_changed` lists should stay hashable without deep descent."""
|
||||
caching, _ = caching_module
|
||||
|
||||
sanitized = caching._shallow_is_changed_signature([1, "two", None, True])
|
||||
|
||||
assert sanitized == ("is_changed_list", (1, "two", None, True))
|
||||
|
||||
|
||||
def test_shallow_is_changed_signature_fails_closed_on_nested_containers(caching_module):
|
||||
"""Nested containers from `is_changed` should be rejected immediately."""
|
||||
caching, _ = caching_module
|
||||
|
||||
sanitized = caching._shallow_is_changed_signature([1, ["nested"]])
|
||||
|
||||
assert isinstance(sanitized, caching.Unhashable)
|
||||
|
||||
|
||||
def test_get_immediate_node_signature_marks_recursive_is_changed_unhashable(caching_module, monkeypatch):
|
||||
"""Recursive `is_changed` payloads should be cut off before signature canonicalization."""
|
||||
caching, nodes_module = caching_module
|
||||
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
|
||||
|
||||
is_changed_value = []
|
||||
is_changed_value.append(is_changed_value)
|
||||
dynprompt = _FakeDynPrompt(
|
||||
{
|
||||
"node": {
|
||||
"class_type": "UnitTestNode",
|
||||
"inputs": {"value": 5},
|
||||
}
|
||||
}
|
||||
)
|
||||
key_set = caching.CacheKeySetInputSignature(
|
||||
dynprompt,
|
||||
["node"],
|
||||
_FakeIsChangedCache({"node": is_changed_value}),
|
||||
)
|
||||
|
||||
signature = asyncio.run(key_set.get_immediate_node_signature(dynprompt, "node", {}))
|
||||
|
||||
assert isinstance(signature[1], caching.Unhashable)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user