Compare commits

...

12 Commits

Author SHA1 Message Date
xmarre
9c8eb96254
Merge 2bea0ee5d7 into 16cd8d8a8f 2026-03-14 11:42:08 +00:00
xmarre
2bea0ee5d7 Simplify Unhashable sentinel implementation 2026-03-14 12:42:04 +01:00
xmarre
17863f603a Add comprehensive docstrings for cache key helpers 2026-03-14 12:26:27 +01:00
xmarre
31ba844624 Add cycle detection to signature input sanitization 2026-03-14 12:04:31 +01:00
xmarre
1451001f64 Add docstrings for cache signature hardening helpers 2026-03-14 10:57:45 +01:00
xmarre
1af99b2e81 Update caching hash recursion 2026-03-14 10:31:07 +01:00
xmarre
3568b82b76 Revert "Add missing docstrings"
This reverts commit 4b431ffc27.
2026-03-14 10:11:35 +01:00
xmarre
6728d4d439 Revert "Harden to_hashable against cycles"
This reverts commit 880b51ac4f.
2026-03-14 10:11:04 +01:00
xmarre
4b431ffc27 Add missing docstrings 2026-03-14 09:57:22 +01:00
xmarre
880b51ac4f Harden to_hashable against cycles 2026-03-14 09:46:27 +01:00
xmarre
4d9516b909 Fix caching sanitization logic 2026-03-14 07:06:39 +01:00
xmarre
39086890e2 Fix sanitize_signature_input 2026-03-14 06:56:49 +01:00
2 changed files with 81 additions and 16 deletions

View File

@ -16,6 +16,7 @@ NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
def include_unique_id_in_input(class_type: str) -> bool:
"""Return whether a node class includes UNIQUE_ID among its hidden inputs."""
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
@ -23,35 +24,44 @@ def include_unique_id_in_input(class_type: str) -> bool:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class CacheKeySet(ABC):
    """Base helper for building and storing cache keys for prompt nodes.

    Two dictionaries are maintained, both keyed by node id:

    * ``keys`` -- the data cache key of each node.
    * ``subcache_keys`` -- the subcache key of each node.

    Concrete strategies implement :meth:`add_keys` to populate them.
    """

    def __init__(self, dynprompt, node_ids, is_changed_cache):
        """Initialize empty cache-key storage.

        The base class only creates the two mappings. ``dynprompt``,
        ``node_ids`` and ``is_changed_cache`` are accepted but not used
        here; subclasses that need them store them on their own.
        """
        self.keys = {}
        self.subcache_keys = {}

    @abstractmethod
    async def add_keys(self, node_ids):
        """Populate cache keys for the provided node ids.

        Must be overridden; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()

    def all_node_ids(self):
        """Return a new set of the node ids currently tracked in ``keys``."""
        return set(self.keys.keys())

    def get_used_keys(self):
        """Return the computed cache keys currently in use.

        The result is a live ``dict`` values view, not a copy.
        """
        return self.keys.values()

    def get_used_subcache_keys(self):
        """Return the computed subcache keys currently in use.

        The result is a live ``dict`` values view, not a copy.
        """
        return self.subcache_keys.values()

    def get_data_key(self, node_id):
        """Return the cache key for ``node_id``, or ``None`` if absent."""
        return self.keys.get(node_id, None)

    def get_subcache_key(self, node_id):
        """Return the subcache key for ``node_id``, or ``None`` if absent."""
        return self.subcache_keys.get(node_id, None)
class Unhashable:
    """Hashable identity sentinel for values that cannot be represented safely in cache keys.

    No ``__eq__`` or ``__hash__`` is defined, so Python's default
    identity-based comparison and hashing apply: two separate instances
    never compare equal to each other.
    """

    def __init__(self):
        """Create a sentinel whose ``value`` attribute is NaN."""
        # NaN is unequal to everything, including itself.
        self.value = float("NaN")
def _sanitized_sort_key(obj, depth=0, max_depth=32):
"""Return a deterministic ordering key for sanitized built-in container content."""
if depth >= max_depth:
return ("MAX_DEPTH",)
@ -82,18 +92,32 @@ def _sanitized_sort_key(obj, depth=0, max_depth=32):
return (obj_type.__module__, obj_type.__qualname__, "OPAQUE")
def _sanitize_signature_input(obj, depth=0, max_depth=32):
def _sanitize_signature_input(obj, depth=0, max_depth=32, seen=None):
"""Normalize signature inputs to safe built-in containers.
Preserves built-in container type, replaces opaque runtime values with
Unhashable(), and stops safely on cycles or excessive depth.
"""
if depth >= max_depth:
return Unhashable()
if seen is None:
seen = set()
obj_type = type(obj)
if obj_type in (dict, list, tuple, set, frozenset):
obj_id = id(obj)
if obj_id in seen:
return Unhashable()
next_seen = seen | {obj_id}
if obj_type in (int, float, str, bool, bytes, type(None)):
return obj
elif obj_type is dict:
sanitized_items = [
(
_sanitize_signature_input(key, depth + 1, max_depth),
_sanitize_signature_input(value, depth + 1, max_depth),
_sanitize_signature_input(key, depth + 1, max_depth, next_seen),
_sanitize_signature_input(value, depth + 1, max_depth, next_seen),
)
for key, value in obj.items()
]
@ -105,13 +129,13 @@ def _sanitize_signature_input(obj, depth=0, max_depth=32):
)
return {key: value for key, value in sanitized_items}
elif obj_type is list:
return [_sanitize_signature_input(item, depth + 1, max_depth) for item in obj]
return [_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj]
elif obj_type is tuple:
return tuple(_sanitize_signature_input(item, depth + 1, max_depth) for item in obj)
return tuple(_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj)
elif obj_type is set:
return {_sanitize_signature_input(item, depth + 1, max_depth) for item in obj}
return {_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj}
elif obj_type is frozenset:
return frozenset(_sanitize_signature_input(item, depth + 1, max_depth) for item in obj)
return frozenset(_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj)
else:
# Execution-cache signatures should be built from prompt-safe values.
# If a custom node injects a runtime object here, mark it unhashable so
@ -119,7 +143,17 @@ def _sanitize_signature_input(obj, depth=0, max_depth=32):
# the foreign object and risk crashing on custom container semantics.
return Unhashable()
def to_hashable(obj):
def to_hashable(obj, depth=0, max_depth=32, seen=None):
"""Convert sanitized prompt inputs into a stable hashable representation.
Preserves built-in container type and stops safely on cycles or excessive depth.
"""
if depth >= max_depth:
return Unhashable()
if seen is None:
seen = set()
# Restrict recursion to plain built-in containers. Some custom nodes insert
# runtime objects into prompt inputs for dynamic graph paths; walking those
# objects as generic Mappings / Sequences is unsafe and can destabilize the
@ -127,25 +161,44 @@ def to_hashable(obj):
obj_type = type(obj)
if obj_type in (int, float, str, bool, bytes, type(None)):
return obj
elif obj_type is dict:
return ("dict", frozenset((to_hashable(k), to_hashable(v)) for k, v in obj.items()))
if obj_type in (dict, list, tuple, set, frozenset):
obj_id = id(obj)
if obj_id in seen:
return Unhashable()
seen = seen | {obj_id}
if obj_type is dict:
return (
"dict",
frozenset(
(
to_hashable(k, depth + 1, max_depth, seen),
to_hashable(v, depth + 1, max_depth, seen),
)
for k, v in obj.items()
),
)
elif obj_type is list:
return ("list", tuple(to_hashable(i) for i in obj))
return ("list", tuple(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
elif obj_type is tuple:
return ("tuple", tuple(to_hashable(i) for i in obj))
return ("tuple", tuple(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
elif obj_type is set:
return ("set", frozenset(to_hashable(i) for i in obj))
return ("set", frozenset(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
elif obj_type is frozenset:
return ("frozenset", frozenset(to_hashable(i) for i in obj))
return ("frozenset", frozenset(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
else:
return Unhashable()
class CacheKeySetID(CacheKeySet):
"""Cache-key strategy that keys nodes by node id and class type."""
def __init__(self, dynprompt, node_ids, is_changed_cache):
"""Initialize identity-based cache keys for the supplied dynamic prompt."""
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
async def add_keys(self, node_ids):
"""Populate identity-based keys for nodes that exist in the dynamic prompt."""
for node_id in node_ids:
if node_id in self.keys:
continue
@ -156,15 +209,19 @@ class CacheKeySetID(CacheKeySet):
self.subcache_keys[node_id] = (node_id, node["class_type"])
class CacheKeySetInputSignature(CacheKeySet):
"""Cache-key strategy that hashes a node's immediate inputs plus ancestor references."""
def __init__(self, dynprompt, node_ids, is_changed_cache):
"""Initialize input-signature-based cache keys for the supplied dynamic prompt."""
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.is_changed_cache = is_changed_cache
def include_node_id_in_input(self) -> bool:
"""Return whether node ids should be included in computed input signatures."""
return False
async def add_keys(self, node_ids):
"""Populate input-signature-based keys for nodes in the dynamic prompt."""
for node_id in node_ids:
if node_id in self.keys:
continue
@ -175,6 +232,7 @@ class CacheKeySetInputSignature(CacheKeySet):
self.subcache_keys[node_id] = (node_id, node["class_type"])
async def get_node_signature(self, dynprompt, node_id):
"""Build the full cache signature for a node and its ordered ancestors."""
signature = []
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
@ -183,6 +241,10 @@ class CacheKeySetInputSignature(CacheKeySet):
return to_hashable(signature)
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
"""Build the cache-signature fragment for a node's immediate inputs.
Link inputs are reduced to ancestor references, while raw values are sanitized first.
"""
if not dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it.
return [float("NaN")]
@ -205,12 +267,14 @@ class CacheKeySetInputSignature(CacheKeySet):
# This function returns a list of all ancestors of the given node. The order of the list is
# deterministic based on which specific inputs the ancestor is connected by.
def get_ordered_ancestry(self, dynprompt, node_id):
"""Return ancestors in deterministic traversal order and their index mapping."""
ancestors = []
order_mapping = {}
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
return ancestors, order_mapping
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
"""Recursively collect ancestors in input order without revisiting prior nodes."""
if not dynprompt.has_node(node_id):
return
inputs = dynprompt.get_node(node_id)["inputs"]

View File

@ -1,4 +1,5 @@
def is_link(obj):
"""Return whether obj is a plain prompt link of the form [node_id, output_index]."""
# Prompt links produced by the frontend / GraphBuilder are plain Python
# lists in the form [node_id, output_index]. Some custom-node paths can
# inject foreign runtime objects into prompt inputs during on-prompt graph