Prevent redundant signature rewalk

This commit is contained in:
xmarre 2026-03-16 13:31:02 +01:00
parent bff714dda0
commit 6158cd5820
2 changed files with 77 additions and 17 deletions

View File

@ -461,18 +461,18 @@ class CacheKeySetInputSignature(CacheKeySet):
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
for ancestor_id in ancestors:
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
return _signature_to_hashable(signature)
return tuple(signature)
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
"""Build the immediate cache-signature fragment for a node.
Link inputs are reduced to ancestor references here. Non-link values
are canonicalized or failed closed before being appended so the outer
node-signature pass never recurses into live prompt input containers.
are canonicalized or failed closed before being appended so the final
node signature is assembled from already-hashable fragments.
"""
if not dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it.
return [float("NaN")]
return (float("NaN"),)
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
@ -487,7 +487,7 @@ class CacheKeySetInputSignature(CacheKeySet):
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
else:
signature.append((key, to_hashable(inputs[key])))
return signature
return tuple(signature)
# This function returns a list of all ancestors of the given node. The order of the list is
# deterministic based on which specific inputs the ancestor is connected by.

View File

@ -42,11 +42,11 @@ def test_get_immediate_node_signature_canonicalizes_non_link_inputs(monkeypatch)
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
signature = asyncio.run(keyset.get_immediate_node_signature(dynprompt, "1", {}))
assert signature == [
assert signature == (
"TestCacheNode",
None,
("value", ("list", (1, ("dict", (("nested", ("list", (2, 3))),))))),
]
)
def test_get_immediate_node_signature_fails_closed_for_opaque_non_link_input(monkeypatch):
@ -69,7 +69,7 @@ def test_get_immediate_node_signature_fails_closed_for_opaque_non_link_input(mon
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
signature = asyncio.run(keyset.get_immediate_node_signature(dynprompt, "1", {}))
assert signature[:2] == ["TestCacheNode", None]
assert signature[:2] == ("TestCacheNode", None)
assert signature[2][0] == "value"
assert type(signature[2][1]) is caching.Unhashable
@ -87,17 +87,77 @@ def test_get_node_signature_never_visits_raw_non_link_input(monkeypatch):
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
original_impl = caching._signature_to_hashable_impl
def guarded_impl(obj, *args, **kwargs):
if obj is live_value:
raise AssertionError("raw non-link input reached outer signature canonicalizer")
return original_impl(obj, *args, **kwargs)
monkeypatch.setattr(caching, "_signature_to_hashable_impl", guarded_impl)
monkeypatch.setattr(
caching,
"_signature_to_hashable",
lambda *_args, **_kwargs: (_ for _ in ()).throw(
AssertionError("outer signature canonicalizer should not run")
),
)
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
assert isinstance(signature, tuple)
def test_get_node_signature_keeps_deep_canonicalized_input_fragment(monkeypatch):
live_value = 1
for _ in range(8):
live_value = [live_value]
expected = caching.to_hashable(live_value)
dynprompt = _StubDynPrompt(
{
"1": {
"class_type": "TestCacheNode",
"inputs": {"value": live_value},
}
}
)
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
assert isinstance(signature, tuple)
assert signature[0][2][0] == "value"
assert signature[0][2][1] == expected
def test_get_node_signature_keeps_large_precanonicalized_fragment(monkeypatch):
live_value = object()
canonical_fragment = ("tuple", tuple(("list", (index, index + 1)) for index in range(256)))
dynprompt = _StubDynPrompt(
{
"1": {
"class_type": "TestCacheNode",
"inputs": {"value": live_value},
}
}
)
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
monkeypatch.setattr(
caching,
"to_hashable",
lambda value, max_nodes=caching._MAX_SIGNATURE_CONTAINER_VISITS: (
canonical_fragment if value is live_value else caching.Unhashable()
),
)
monkeypatch.setattr(
caching,
"_signature_to_hashable",
lambda *_args, **_kwargs: (_ for _ in ()).throw(
AssertionError("outer signature canonicalizer should not run")
),
)
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
assert isinstance(signature, tuple)
assert signature[0][2] == ("value", canonical_fragment)