mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-22 17:43:33 +08:00
Merge fdcc38b9ea into 1a157e1f97
This commit is contained in:
commit
3d091d7797
@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import bisect
|
import bisect
|
||||||
import gc
|
import gc
|
||||||
import itertools
|
|
||||||
import psutil
|
import psutil
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
@ -17,6 +16,7 @@ NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
|
|||||||
|
|
||||||
|
|
||||||
def include_unique_id_in_input(class_type: str) -> bool:
|
def include_unique_id_in_input(class_type: str) -> bool:
|
||||||
|
"""Return whether a node class includes UNIQUE_ID among its hidden inputs."""
|
||||||
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
|
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
|
||||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
@ -24,52 +24,410 @@ def include_unique_id_in_input(class_type: str) -> bool:
|
|||||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||||
|
|
||||||
class CacheKeySet(ABC):
|
class CacheKeySet(ABC):
|
||||||
|
"""Base helper for building and storing cache keys for prompt nodes."""
|
||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
"""Initialize cache-key storage for a dynamic prompt execution pass."""
|
||||||
self.keys = {}
|
self.keys = {}
|
||||||
self.subcache_keys = {}
|
self.subcache_keys = {}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def add_keys(self, node_ids):
|
async def add_keys(self, node_ids):
|
||||||
|
"""Populate cache keys for the provided node ids."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def all_node_ids(self):
|
def all_node_ids(self):
|
||||||
|
"""Return the set of node ids currently tracked by this key set."""
|
||||||
return set(self.keys.keys())
|
return set(self.keys.keys())
|
||||||
|
|
||||||
def get_used_keys(self):
|
def get_used_keys(self):
|
||||||
|
"""Return the computed cache keys currently in use."""
|
||||||
return self.keys.values()
|
return self.keys.values()
|
||||||
|
|
||||||
def get_used_subcache_keys(self):
|
def get_used_subcache_keys(self):
|
||||||
|
"""Return the computed subcache keys currently in use."""
|
||||||
return self.subcache_keys.values()
|
return self.subcache_keys.values()
|
||||||
|
|
||||||
def get_data_key(self, node_id):
|
def get_data_key(self, node_id):
|
||||||
|
"""Return the cache key for a node, if present."""
|
||||||
return self.keys.get(node_id, None)
|
return self.keys.get(node_id, None)
|
||||||
|
|
||||||
def get_subcache_key(self, node_id):
|
def get_subcache_key(self, node_id):
|
||||||
|
"""Return the subcache key for a node, if present."""
|
||||||
return self.subcache_keys.get(node_id, None)
|
return self.subcache_keys.get(node_id, None)
|
||||||
|
|
||||||
class Unhashable:
|
class Unhashable:
|
||||||
def __init__(self):
|
"""Hashable identity sentinel for values that cannot be represented safely in cache keys."""
|
||||||
self.value = float("NaN")
|
pass
|
||||||
|
|
||||||
def to_hashable(obj):
|
|
||||||
# So that we don't infinitely recurse since frozenset and tuples
|
_PRIMITIVE_SIGNATURE_TYPES = (int, float, str, bool, bytes, type(None))
|
||||||
# are Sequences.
|
_CONTAINER_SIGNATURE_TYPES = (dict, list, tuple, set, frozenset)
|
||||||
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
|
_MAX_SIGNATURE_DEPTH = 32
|
||||||
|
_MAX_SIGNATURE_CONTAINER_VISITS = 10_000
|
||||||
|
_FAILED_SIGNATURE = object()
|
||||||
|
|
||||||
|
|
||||||
|
def _shallow_is_changed_signature(value):
|
||||||
|
"""Sanitize execution-time `is_changed` values with a small fail-closed budget."""
|
||||||
|
value_type = type(value)
|
||||||
|
if value_type in _PRIMITIVE_SIGNATURE_TYPES:
|
||||||
|
return value
|
||||||
|
|
||||||
|
canonical = to_hashable(value, max_nodes=64)
|
||||||
|
if type(canonical) is Unhashable:
|
||||||
|
return canonical
|
||||||
|
|
||||||
|
if value_type is list or value_type is tuple:
|
||||||
|
container_tag = "is_changed_list" if value_type is list else "is_changed_tuple"
|
||||||
|
return (container_tag, canonical[1])
|
||||||
|
|
||||||
|
return canonical
|
||||||
|
|
||||||
|
|
||||||
|
def _primitive_signature_sort_key(obj):
|
||||||
|
"""Return a deterministic ordering key for primitive signature values."""
|
||||||
|
obj_type = type(obj)
|
||||||
|
return ("primitive", obj_type.__module__, obj_type.__qualname__, repr(obj))
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None):
|
||||||
|
"""Return a deterministic ordering key for sanitized built-in container content."""
|
||||||
|
if depth >= max_depth:
|
||||||
|
return ("MAX_DEPTH",)
|
||||||
|
|
||||||
|
if active is None:
|
||||||
|
active = set()
|
||||||
|
if memo is None:
|
||||||
|
memo = {}
|
||||||
|
|
||||||
|
obj_type = type(obj)
|
||||||
|
if obj_type is Unhashable:
|
||||||
|
return ("UNHASHABLE",)
|
||||||
|
elif obj_type in _PRIMITIVE_SIGNATURE_TYPES:
|
||||||
|
return (obj_type.__module__, obj_type.__qualname__, repr(obj))
|
||||||
|
elif obj_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||||
|
return (obj_type.__module__, obj_type.__qualname__, "OPAQUE")
|
||||||
|
|
||||||
|
obj_id = id(obj)
|
||||||
|
if obj_id in memo:
|
||||||
|
return memo[obj_id]
|
||||||
|
if obj_id in active:
|
||||||
|
return ("CYCLE",)
|
||||||
|
|
||||||
|
active.add(obj_id)
|
||||||
|
try:
|
||||||
|
if obj_type is dict:
|
||||||
|
items = [
|
||||||
|
(
|
||||||
|
_sanitized_sort_key(k, depth + 1, max_depth, active, memo),
|
||||||
|
_sanitized_sort_key(v, depth + 1, max_depth, active, memo),
|
||||||
|
)
|
||||||
|
for k, v in obj.items()
|
||||||
|
]
|
||||||
|
items.sort()
|
||||||
|
result = ("dict", tuple(items))
|
||||||
|
elif obj_type is list:
|
||||||
|
result = ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))
|
||||||
|
elif obj_type is tuple:
|
||||||
|
result = ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))
|
||||||
|
elif obj_type is set:
|
||||||
|
result = ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)))
|
||||||
|
else:
|
||||||
|
result = ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)))
|
||||||
|
finally:
|
||||||
|
active.discard(obj_id)
|
||||||
|
|
||||||
|
memo[obj_id] = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _signature_to_hashable_impl(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None):
|
||||||
|
"""Canonicalize signature inputs directly into their final hashable form."""
|
||||||
|
if depth >= max_depth:
|
||||||
|
return _FAILED_SIGNATURE
|
||||||
|
|
||||||
|
if active is None:
|
||||||
|
active = set()
|
||||||
|
if memo is None:
|
||||||
|
memo = {}
|
||||||
|
if budget is None:
|
||||||
|
budget = {"remaining": _MAX_SIGNATURE_CONTAINER_VISITS}
|
||||||
|
|
||||||
|
obj_type = type(obj)
|
||||||
|
if obj_type in _PRIMITIVE_SIGNATURE_TYPES:
|
||||||
|
return obj, _primitive_signature_sort_key(obj)
|
||||||
|
if obj_type is Unhashable or obj_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||||
|
return _FAILED_SIGNATURE
|
||||||
|
|
||||||
|
obj_id = id(obj)
|
||||||
|
if obj_id in memo:
|
||||||
|
return memo[obj_id]
|
||||||
|
if obj_id in active:
|
||||||
|
return _FAILED_SIGNATURE
|
||||||
|
|
||||||
|
budget["remaining"] -= 1
|
||||||
|
if budget["remaining"] < 0:
|
||||||
|
return _FAILED_SIGNATURE
|
||||||
|
|
||||||
|
active.add(obj_id)
|
||||||
|
try:
|
||||||
|
if obj_type is dict:
|
||||||
|
try:
|
||||||
|
items = list(obj.items())
|
||||||
|
except RuntimeError:
|
||||||
|
return _FAILED_SIGNATURE
|
||||||
|
|
||||||
|
ordered_items = []
|
||||||
|
for key, value in items:
|
||||||
|
if type(key) not in _PRIMITIVE_SIGNATURE_TYPES:
|
||||||
|
return _FAILED_SIGNATURE
|
||||||
|
key_result = (key, _primitive_signature_sort_key(key))
|
||||||
|
value_result = _signature_to_hashable_impl(value, depth + 1, max_depth, active, memo, budget)
|
||||||
|
if value_result is _FAILED_SIGNATURE:
|
||||||
|
return _FAILED_SIGNATURE
|
||||||
|
key_value, key_sort = key_result
|
||||||
|
value_value, value_sort = value_result
|
||||||
|
ordered_items.append((key_sort, value_sort, key_value, value_value))
|
||||||
|
|
||||||
|
ordered_items.sort(key=lambda item: (item[0], item[1]))
|
||||||
|
for index in range(1, len(ordered_items)):
|
||||||
|
previous_key_sort = ordered_items[index - 1][0]
|
||||||
|
current_key_sort = ordered_items[index][0]
|
||||||
|
if previous_key_sort == current_key_sort:
|
||||||
|
return _FAILED_SIGNATURE
|
||||||
|
|
||||||
|
value = ("dict", tuple((key_value, value_value) for _, _, key_value, value_value in ordered_items))
|
||||||
|
sort_key = ("dict", tuple((key_sort, value_sort) for key_sort, value_sort, _, _ in ordered_items))
|
||||||
|
elif obj_type is list or obj_type is tuple:
|
||||||
|
try:
|
||||||
|
items = list(obj)
|
||||||
|
except RuntimeError:
|
||||||
|
return _FAILED_SIGNATURE
|
||||||
|
|
||||||
|
child_results = []
|
||||||
|
for item in items:
|
||||||
|
child_result = _signature_to_hashable_impl(item, depth + 1, max_depth, active, memo, budget)
|
||||||
|
if child_result is _FAILED_SIGNATURE:
|
||||||
|
return _FAILED_SIGNATURE
|
||||||
|
child_results.append(child_result)
|
||||||
|
|
||||||
|
container_tag = "list" if obj_type is list else "tuple"
|
||||||
|
value = (container_tag, tuple(child for child, _ in child_results))
|
||||||
|
sort_key = (container_tag, tuple(child_sort for _, child_sort in child_results))
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
items = list(obj)
|
||||||
|
except RuntimeError:
|
||||||
|
return _FAILED_SIGNATURE
|
||||||
|
|
||||||
|
ordered_items = []
|
||||||
|
for item in items:
|
||||||
|
child_result = _signature_to_hashable_impl(item, depth + 1, max_depth, active, memo, budget)
|
||||||
|
if child_result is _FAILED_SIGNATURE:
|
||||||
|
return _FAILED_SIGNATURE
|
||||||
|
child_value, child_sort = child_result
|
||||||
|
ordered_items.append((child_sort, child_value))
|
||||||
|
|
||||||
|
ordered_items.sort(key=lambda item: item[0])
|
||||||
|
for index in range(1, len(ordered_items)):
|
||||||
|
previous_sort_key, previous_value = ordered_items[index - 1]
|
||||||
|
current_sort_key, current_value = ordered_items[index]
|
||||||
|
if previous_sort_key == current_sort_key and previous_value != current_value:
|
||||||
|
return _FAILED_SIGNATURE
|
||||||
|
|
||||||
|
container_tag = "set" if obj_type is set else "frozenset"
|
||||||
|
value = (container_tag, tuple(child_value for _, child_value in ordered_items))
|
||||||
|
sort_key = (container_tag, tuple(child_sort for child_sort, _ in ordered_items))
|
||||||
|
finally:
|
||||||
|
active.discard(obj_id)
|
||||||
|
|
||||||
|
memo[obj_id] = (value, sort_key)
|
||||||
|
return memo[obj_id]
|
||||||
|
|
||||||
|
|
||||||
|
def _signature_to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
|
||||||
|
"""Build the final cache-signature representation in one fail-closed pass."""
|
||||||
|
try:
|
||||||
|
result = _signature_to_hashable_impl(obj, budget={"remaining": max_nodes})
|
||||||
|
except RuntimeError:
|
||||||
|
return Unhashable()
|
||||||
|
if result is _FAILED_SIGNATURE:
|
||||||
|
return Unhashable()
|
||||||
|
return result[0]
|
||||||
|
|
||||||
|
|
||||||
|
def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
|
||||||
|
"""Convert sanitized prompt inputs into a stable hashable representation.
|
||||||
|
|
||||||
|
The input is expected to already be sanitized to plain built-in containers,
|
||||||
|
but this function still fails safe for anything unexpected. Traversal is
|
||||||
|
iterative and memoized so shared built-in substructures do not trigger
|
||||||
|
exponential re-walks during cache-key construction.
|
||||||
|
"""
|
||||||
|
obj_type = type(obj)
|
||||||
|
if obj_type in _PRIMITIVE_SIGNATURE_TYPES or obj_type is Unhashable:
|
||||||
return obj
|
return obj
|
||||||
elif isinstance(obj, Mapping):
|
if obj_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||||
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
|
||||||
elif isinstance(obj, Sequence):
|
|
||||||
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
|
|
||||||
else:
|
|
||||||
# TODO - Support other objects like tensors?
|
|
||||||
return Unhashable()
|
return Unhashable()
|
||||||
|
|
||||||
|
memo = {}
|
||||||
|
active = set()
|
||||||
|
snapshots = {}
|
||||||
|
sort_memo = {}
|
||||||
|
processed = 0
|
||||||
|
# Keep traversal state separate from container snapshots/results.
|
||||||
|
work_stack = [(obj, False)]
|
||||||
|
|
||||||
|
def resolve_value(value):
|
||||||
|
"""Resolve a child value from the completed memo table when available."""
|
||||||
|
value_type = type(value)
|
||||||
|
if value_type in _PRIMITIVE_SIGNATURE_TYPES or value_type is Unhashable:
|
||||||
|
return value
|
||||||
|
return memo.get(id(value), Unhashable())
|
||||||
|
|
||||||
|
def is_failed(value):
|
||||||
|
"""Return whether a resolved child value represents failed canonicalization."""
|
||||||
|
return type(value) is Unhashable
|
||||||
|
|
||||||
|
def resolve_unordered_values(current_items, container_tag):
|
||||||
|
"""Resolve a set-like container or fail closed if ordering is ambiguous."""
|
||||||
|
try:
|
||||||
|
ordered_items = [
|
||||||
|
(_sanitized_sort_key(item, memo=sort_memo), resolve_value(item))
|
||||||
|
for item in current_items
|
||||||
|
]
|
||||||
|
if any(is_failed(value) for _, value in ordered_items):
|
||||||
|
return Unhashable()
|
||||||
|
ordered_items.sort(key=lambda item: item[0])
|
||||||
|
except RuntimeError:
|
||||||
|
return Unhashable()
|
||||||
|
|
||||||
|
for index in range(1, len(ordered_items)):
|
||||||
|
previous_key, previous_value = ordered_items[index - 1]
|
||||||
|
current_key, current_value = ordered_items[index]
|
||||||
|
if previous_key == current_key and previous_value != current_value:
|
||||||
|
return Unhashable()
|
||||||
|
|
||||||
|
return (container_tag, tuple(value for _, value in ordered_items))
|
||||||
|
|
||||||
|
while work_stack:
|
||||||
|
entry = work_stack.pop()
|
||||||
|
if len(entry) == 3:
|
||||||
|
_, current_id, current_type = entry
|
||||||
|
current = None
|
||||||
|
expanded = True
|
||||||
|
else:
|
||||||
|
current, expanded = entry
|
||||||
|
current_type = type(current)
|
||||||
|
current_id = id(current)
|
||||||
|
|
||||||
|
if not expanded and (current_type in _PRIMITIVE_SIGNATURE_TYPES or current_type is Unhashable):
|
||||||
|
continue
|
||||||
|
if not expanded and current_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||||
|
memo[current_id] = Unhashable()
|
||||||
|
continue
|
||||||
|
|
||||||
|
if current_id in memo:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if expanded:
|
||||||
|
active.discard(current_id)
|
||||||
|
try:
|
||||||
|
items = snapshots.pop(current_id, None)
|
||||||
|
if items is None:
|
||||||
|
memo[current_id] = Unhashable()
|
||||||
|
continue
|
||||||
|
|
||||||
|
if current_type is dict:
|
||||||
|
ordered_items = [
|
||||||
|
(_sanitized_sort_key(k, memo=sort_memo), k, resolve_value(v))
|
||||||
|
for k, v in items
|
||||||
|
]
|
||||||
|
if any(type(key) not in _PRIMITIVE_SIGNATURE_TYPES or is_failed(value) for _, key, value in ordered_items):
|
||||||
|
memo[current_id] = Unhashable()
|
||||||
|
continue
|
||||||
|
ordered_items.sort(key=lambda item: item[0])
|
||||||
|
for index in range(1, len(ordered_items)):
|
||||||
|
if ordered_items[index - 1][0] == ordered_items[index][0]:
|
||||||
|
memo[current_id] = Unhashable()
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
memo[current_id] = (
|
||||||
|
"dict",
|
||||||
|
tuple((key, value) for _, key, value in ordered_items),
|
||||||
|
)
|
||||||
|
elif current_type is list:
|
||||||
|
resolved_items = tuple(resolve_value(item) for item in items)
|
||||||
|
if any(is_failed(item) for item in resolved_items):
|
||||||
|
memo[current_id] = Unhashable()
|
||||||
|
else:
|
||||||
|
memo[current_id] = ("list", resolved_items)
|
||||||
|
elif current_type is tuple:
|
||||||
|
resolved_items = tuple(resolve_value(item) for item in items)
|
||||||
|
if any(is_failed(item) for item in resolved_items):
|
||||||
|
memo[current_id] = Unhashable()
|
||||||
|
else:
|
||||||
|
memo[current_id] = ("tuple", resolved_items)
|
||||||
|
elif current_type is set:
|
||||||
|
memo[current_id] = resolve_unordered_values(items, "set")
|
||||||
|
else:
|
||||||
|
memo[current_id] = resolve_unordered_values(items, "frozenset")
|
||||||
|
except RuntimeError:
|
||||||
|
memo[current_id] = Unhashable()
|
||||||
|
continue
|
||||||
|
|
||||||
|
if current_id in active:
|
||||||
|
memo[current_id] = Unhashable()
|
||||||
|
continue
|
||||||
|
|
||||||
|
processed += 1
|
||||||
|
if processed > max_nodes:
|
||||||
|
return Unhashable()
|
||||||
|
|
||||||
|
active.add(current_id)
|
||||||
|
if current_type is dict:
|
||||||
|
try:
|
||||||
|
items = list(current.items())
|
||||||
|
snapshots[current_id] = items
|
||||||
|
except RuntimeError:
|
||||||
|
memo[current_id] = Unhashable()
|
||||||
|
active.discard(current_id)
|
||||||
|
continue
|
||||||
|
for key, value in items:
|
||||||
|
if type(key) not in _PRIMITIVE_SIGNATURE_TYPES:
|
||||||
|
snapshots.pop(current_id, None)
|
||||||
|
memo[current_id] = Unhashable()
|
||||||
|
active.discard(current_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
work_stack.append(("EXPANDED", current_id, current_type))
|
||||||
|
for _, value in reversed(items):
|
||||||
|
work_stack.append((value, False))
|
||||||
|
continue
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
items = list(current)
|
||||||
|
snapshots[current_id] = items
|
||||||
|
except RuntimeError:
|
||||||
|
memo[current_id] = Unhashable()
|
||||||
|
active.discard(current_id)
|
||||||
|
continue
|
||||||
|
work_stack.append(("EXPANDED", current_id, current_type))
|
||||||
|
for item in reversed(items):
|
||||||
|
work_stack.append((item, False))
|
||||||
|
|
||||||
|
return memo.get(id(obj), Unhashable())
|
||||||
|
|
||||||
class CacheKeySetID(CacheKeySet):
|
class CacheKeySetID(CacheKeySet):
|
||||||
|
"""Cache-key strategy that keys nodes by node id and class type."""
|
||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
"""Initialize identity-based cache keys for the supplied dynamic prompt."""
|
||||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||||
self.dynprompt = dynprompt
|
self.dynprompt = dynprompt
|
||||||
|
|
||||||
async def add_keys(self, node_ids):
|
async def add_keys(self, node_ids):
|
||||||
|
"""Populate identity-based keys for nodes that exist in the dynamic prompt."""
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
if node_id in self.keys:
|
if node_id in self.keys:
|
||||||
continue
|
continue
|
||||||
@ -80,15 +438,19 @@ class CacheKeySetID(CacheKeySet):
|
|||||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||||
|
|
||||||
class CacheKeySetInputSignature(CacheKeySet):
|
class CacheKeySetInputSignature(CacheKeySet):
|
||||||
|
"""Cache-key strategy that hashes a node's immediate inputs plus ancestor references."""
|
||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
"""Initialize input-signature-based cache keys for the supplied dynamic prompt."""
|
||||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||||
self.dynprompt = dynprompt
|
self.dynprompt = dynprompt
|
||||||
self.is_changed_cache = is_changed_cache
|
self.is_changed_cache = is_changed_cache
|
||||||
|
|
||||||
def include_node_id_in_input(self) -> bool:
|
def include_node_id_in_input(self) -> bool:
|
||||||
|
"""Return whether node ids should be included in computed input signatures."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def add_keys(self, node_ids):
|
async def add_keys(self, node_ids):
|
||||||
|
"""Populate input-signature-based keys for nodes in the dynamic prompt."""
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
if node_id in self.keys:
|
if node_id in self.keys:
|
||||||
continue
|
continue
|
||||||
@ -99,21 +461,37 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||||
|
|
||||||
async def get_node_signature(self, dynprompt, node_id):
|
async def get_node_signature(self, dynprompt, node_id):
|
||||||
|
"""Build the full cache signature for a node and its ordered ancestors."""
|
||||||
signature = []
|
signature = []
|
||||||
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
||||||
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
immediate = await self.get_immediate_node_signature(dynprompt, node_id, order_mapping)
|
||||||
|
if type(immediate) is Unhashable:
|
||||||
|
return immediate
|
||||||
|
signature.append(immediate)
|
||||||
for ancestor_id in ancestors:
|
for ancestor_id in ancestors:
|
||||||
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
immediate = await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)
|
||||||
return to_hashable(signature)
|
if type(immediate) is Unhashable:
|
||||||
|
return immediate
|
||||||
|
signature.append(immediate)
|
||||||
|
return tuple(signature)
|
||||||
|
|
||||||
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||||
|
"""Build the immediate cache-signature fragment for a node.
|
||||||
|
|
||||||
|
Link inputs are reduced to ancestor references here. Non-link values
|
||||||
|
are canonicalized or failed closed before being appended so the final
|
||||||
|
node signature is assembled from already-hashable fragments.
|
||||||
|
"""
|
||||||
if not dynprompt.has_node(node_id):
|
if not dynprompt.has_node(node_id):
|
||||||
# This node doesn't exist -- we can't cache it.
|
# This node doesn't exist -- we can't cache it.
|
||||||
return [float("NaN")]
|
return Unhashable()
|
||||||
node = dynprompt.get_node(node_id)
|
node = dynprompt.get_node(node_id)
|
||||||
class_type = node["class_type"]
|
class_type = node["class_type"]
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
signature = [class_type, await self.is_changed_cache.get(node_id)]
|
is_changed_signature = _shallow_is_changed_signature(await self.is_changed_cache.get(node_id))
|
||||||
|
if type(is_changed_signature) is Unhashable:
|
||||||
|
return is_changed_signature
|
||||||
|
signature = [class_type, is_changed_signature]
|
||||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||||
signature.append(node_id)
|
signature.append(node_id)
|
||||||
inputs = node["inputs"]
|
inputs = node["inputs"]
|
||||||
@ -123,18 +501,23 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
ancestor_index = ancestor_order_mapping[ancestor_id]
|
ancestor_index = ancestor_order_mapping[ancestor_id]
|
||||||
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
|
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
|
||||||
else:
|
else:
|
||||||
signature.append((key, inputs[key]))
|
value_signature = to_hashable(inputs[key])
|
||||||
return signature
|
if type(value_signature) is Unhashable:
|
||||||
|
return value_signature
|
||||||
|
signature.append((key, value_signature))
|
||||||
|
return tuple(signature)
|
||||||
|
|
||||||
# This function returns a list of all ancestors of the given node. The order of the list is
|
# This function returns a list of all ancestors of the given node. The order of the list is
|
||||||
# deterministic based on which specific inputs the ancestor is connected by.
|
# deterministic based on which specific inputs the ancestor is connected by.
|
||||||
def get_ordered_ancestry(self, dynprompt, node_id):
|
def get_ordered_ancestry(self, dynprompt, node_id):
|
||||||
|
"""Return ancestors in deterministic traversal order and their index mapping."""
|
||||||
ancestors = []
|
ancestors = []
|
||||||
order_mapping = {}
|
order_mapping = {}
|
||||||
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
|
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
|
||||||
return ancestors, order_mapping
|
return ancestors, order_mapping
|
||||||
|
|
||||||
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
||||||
|
"""Recursively collect ancestors in input order without revisiting prior nodes."""
|
||||||
if not dynprompt.has_node(node_id):
|
if not dynprompt.has_node(node_id):
|
||||||
return
|
return
|
||||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||||
|
|||||||
@ -1,11 +1,17 @@
|
|||||||
def is_link(obj):
|
def is_link(obj):
|
||||||
if not isinstance(obj, list):
|
"""Return whether obj is a plain prompt link of the form [node_id, output_index]."""
|
||||||
|
# Prompt links produced by the frontend / GraphBuilder are plain Python
|
||||||
|
# lists in the form [node_id, output_index]. Some custom-node paths can
|
||||||
|
# inject foreign runtime objects into prompt inputs during on-prompt graph
|
||||||
|
# rewriting or subgraph construction. Be strict here so cache signature
|
||||||
|
# building never tries to treat list-like proxy objects as links.
|
||||||
|
if type(obj) is not list:
|
||||||
return False
|
return False
|
||||||
if len(obj) != 2:
|
if len(obj) != 2:
|
||||||
return False
|
return False
|
||||||
if not isinstance(obj[0], str):
|
if type(obj[0]) is not str:
|
||||||
return False
|
return False
|
||||||
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
|
if type(obj[1]) is not int:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
473
tests-unit/execution_test/caching_test.py
Normal file
473
tests-unit/execution_test/caching_test.py
Normal file
@ -0,0 +1,473 @@
|
|||||||
|
"""Unit tests for cache-signature canonicalization hardening."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyNode:
|
||||||
|
"""Minimal node stub used to satisfy cache-signature class lookups."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def INPUT_TYPES():
|
||||||
|
"""Return a minimal empty input schema for unit tests."""
|
||||||
|
return {"required": {}}
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeDynPrompt:
|
||||||
|
"""Small DynamicPrompt stand-in with only the methods these tests need."""
|
||||||
|
|
||||||
|
def __init__(self, nodes_by_id):
|
||||||
|
"""Store test nodes by id."""
|
||||||
|
self._nodes_by_id = nodes_by_id
|
||||||
|
|
||||||
|
def has_node(self, node_id):
|
||||||
|
"""Return whether the fake prompt contains the requested node."""
|
||||||
|
return node_id in self._nodes_by_id
|
||||||
|
|
||||||
|
def get_node(self, node_id):
|
||||||
|
"""Return the stored node payload for the requested id."""
|
||||||
|
return self._nodes_by_id[node_id]
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeIsChangedCache:
|
||||||
|
"""Async stub for `is_changed` lookups used by cache-key generation."""
|
||||||
|
|
||||||
|
def __init__(self, values):
|
||||||
|
"""Store canned `is_changed` responses keyed by node id."""
|
||||||
|
self._values = values
|
||||||
|
|
||||||
|
async def get(self, node_id):
|
||||||
|
"""Return the canned `is_changed` value for a node."""
|
||||||
|
return self._values[node_id]
|
||||||
|
|
||||||
|
|
||||||
|
class _OpaqueValue:
|
||||||
|
"""Hashable opaque object used to exercise fail-closed unordered hashing paths."""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def caching_module(monkeypatch):
|
||||||
|
"""Import `comfy_execution.caching` with lightweight stub dependencies."""
|
||||||
|
torch_module = types.ModuleType("torch")
|
||||||
|
psutil_module = types.ModuleType("psutil")
|
||||||
|
nodes_module = types.ModuleType("nodes")
|
||||||
|
nodes_module.NODE_CLASS_MAPPINGS = {}
|
||||||
|
graph_module = types.ModuleType("comfy_execution.graph")
|
||||||
|
|
||||||
|
class DynamicPrompt:
|
||||||
|
"""Placeholder graph type so the caching module can import cleanly."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
graph_module.DynamicPrompt = DynamicPrompt
|
||||||
|
|
||||||
|
monkeypatch.setitem(sys.modules, "torch", torch_module)
|
||||||
|
monkeypatch.setitem(sys.modules, "psutil", psutil_module)
|
||||||
|
monkeypatch.setitem(sys.modules, "nodes", nodes_module)
|
||||||
|
monkeypatch.setitem(sys.modules, "comfy_execution.graph", graph_module)
|
||||||
|
monkeypatch.delitem(sys.modules, "comfy_execution.caching", raising=False)
|
||||||
|
|
||||||
|
module = importlib.import_module("comfy_execution.caching")
|
||||||
|
module = importlib.reload(module)
|
||||||
|
return module, nodes_module
|
||||||
|
|
||||||
|
|
||||||
|
def test_signature_to_hashable_handles_shared_builtin_substructures(caching_module):
|
||||||
|
"""Shared built-in substructures should canonicalize without collapsing to Unhashable."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
shared = [{"value": 1}, {"value": 2}]
|
||||||
|
|
||||||
|
signature = caching._signature_to_hashable([shared, shared])
|
||||||
|
|
||||||
|
assert signature[0] == "list"
|
||||||
|
assert signature[1][0] == signature[1][1]
|
||||||
|
assert signature[1][0][0] == "list"
|
||||||
|
assert signature[1][0][1][0] == ("dict", (("value", 1),))
|
||||||
|
assert signature[1][0][1][1] == ("dict", (("value", 2),))
|
||||||
|
|
||||||
|
|
||||||
|
def test_signature_to_hashable_fails_closed_on_opaque_values(caching_module):
|
||||||
|
"""Opaque values should collapse the full signature to Unhashable immediately."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
|
||||||
|
signature = caching._signature_to_hashable(["safe", object()])
|
||||||
|
|
||||||
|
assert isinstance(signature, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_signature_to_hashable_stops_descending_after_failure(caching_module, monkeypatch):
|
||||||
|
"""Once canonicalization fails, later recursive descent should stop immediately."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
original = caching._signature_to_hashable_impl
|
||||||
|
marker = object()
|
||||||
|
marker_seen = False
|
||||||
|
|
||||||
|
def tracking_canonicalize(obj, *args, **kwargs):
|
||||||
|
"""Track whether recursion reaches the nested marker after failure."""
|
||||||
|
nonlocal marker_seen
|
||||||
|
if obj is marker:
|
||||||
|
marker_seen = True
|
||||||
|
return original(obj, *args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(caching, "_signature_to_hashable_impl", tracking_canonicalize)
|
||||||
|
|
||||||
|
signature = caching._signature_to_hashable([object(), [marker]])
|
||||||
|
|
||||||
|
assert isinstance(signature, caching.Unhashable)
|
||||||
|
assert marker_seen is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_signature_to_hashable_snapshots_list_before_recursing(caching_module, monkeypatch):
|
||||||
|
"""List canonicalization should read a point-in-time snapshot before recursive descent."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
original = caching._signature_to_hashable_impl
|
||||||
|
marker = ("marker",)
|
||||||
|
values = [marker, 2]
|
||||||
|
|
||||||
|
def mutating_canonicalize(obj, *args, **kwargs):
|
||||||
|
"""Mutate the live list during recursion to verify snapshot-based traversal."""
|
||||||
|
if obj is marker:
|
||||||
|
values[1] = 3
|
||||||
|
return original(obj, *args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize)
|
||||||
|
|
||||||
|
signature = caching._signature_to_hashable(values)
|
||||||
|
|
||||||
|
assert signature == ("list", (("tuple", ("marker",)), 2))
|
||||||
|
assert values[1] == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_signature_to_hashable_snapshots_dict_before_recursing(caching_module, monkeypatch):
|
||||||
|
"""Dict canonicalization should read a point-in-time snapshot before recursive descent."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
original = caching._signature_to_hashable_impl
|
||||||
|
marker = ("marker",)
|
||||||
|
values = {"first": marker, "second": 2}
|
||||||
|
|
||||||
|
def mutating_canonicalize(obj, *args, **kwargs):
|
||||||
|
"""Mutate the live dict during recursion to verify snapshot-based traversal."""
|
||||||
|
if obj is marker:
|
||||||
|
values["second"] = 3
|
||||||
|
return original(obj, *args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize)
|
||||||
|
|
||||||
|
signature = caching._signature_to_hashable(values)
|
||||||
|
|
||||||
|
assert signature == ("dict", (("first", ("tuple", ("marker",))), ("second", 2)))
|
||||||
|
assert values["second"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"container_factory",
|
||||||
|
[
|
||||||
|
lambda marker: [marker],
|
||||||
|
lambda marker: (marker,),
|
||||||
|
lambda marker: {marker},
|
||||||
|
lambda marker: frozenset({marker}),
|
||||||
|
lambda marker: {"key": marker},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_signature_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory):
|
||||||
|
"""Traversal RuntimeError should degrade canonicalization to Unhashable."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
original = caching._signature_to_hashable_impl
|
||||||
|
marker = object()
|
||||||
|
|
||||||
|
def raising_canonicalize(obj, *args, **kwargs):
|
||||||
|
"""Raise a traversal RuntimeError for the marker value and delegate otherwise."""
|
||||||
|
if obj is marker:
|
||||||
|
raise RuntimeError("container changed during iteration")
|
||||||
|
return original(obj, *args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(caching, "_signature_to_hashable_impl", raising_canonicalize)
|
||||||
|
|
||||||
|
signature = caching._signature_to_hashable(container_factory(marker))
|
||||||
|
|
||||||
|
assert isinstance(signature, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_hashable_handles_shared_builtin_substructures(caching_module):
|
||||||
|
"""The legacy helper should still hash sanitized built-ins stably when used directly."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
shared = [{"value": 1}, {"value": 2}]
|
||||||
|
|
||||||
|
sanitized = [shared, shared]
|
||||||
|
hashable = caching.to_hashable(sanitized)
|
||||||
|
|
||||||
|
assert hashable[0] == "list"
|
||||||
|
assert hashable[1][0] == hashable[1][1]
|
||||||
|
assert hashable[1][0][0] == "list"
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_hashable_uses_parent_snapshot_during_expanded_phase(caching_module, monkeypatch):
|
||||||
|
"""Expanded-phase assembly should not reread a live parent container after snapshotting."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
original_sort_key = caching._sanitized_sort_key
|
||||||
|
outer = [{"marker"}, 2]
|
||||||
|
|
||||||
|
def mutating_sort_key(obj, *args, **kwargs):
|
||||||
|
"""Mutate the live parent while a child container is being canonicalized."""
|
||||||
|
if obj == "marker":
|
||||||
|
outer[1] = 3
|
||||||
|
return original_sort_key(obj, *args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(caching, "_sanitized_sort_key", mutating_sort_key)
|
||||||
|
|
||||||
|
hashable = caching.to_hashable(outer)
|
||||||
|
|
||||||
|
assert hashable == ("list", (("set", ("marker",)), 2))
|
||||||
|
assert outer[1] == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_hashable_fails_closed_for_ordered_container_with_opaque_child(caching_module):
|
||||||
|
"""Ordered containers should fail closed when a child cannot be canonicalized."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
|
||||||
|
result = caching.to_hashable([object()])
|
||||||
|
|
||||||
|
assert isinstance(result, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_hashable_canonicalizes_dict_insertion_order(caching_module):
|
||||||
|
"""Dicts with the same content should hash identically regardless of insertion order."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
|
||||||
|
first = {"b": 2, "a": 1}
|
||||||
|
second = {"a": 1, "b": 2}
|
||||||
|
|
||||||
|
assert caching.to_hashable(first) == ("dict", (("a", 1), ("b", 2)))
|
||||||
|
assert caching.to_hashable(first) == caching.to_hashable(second)
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_hashable_fails_closed_for_opaque_dict_key(caching_module):
|
||||||
|
"""Opaque dict keys should fail closed instead of being traversed during hashing."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
|
||||||
|
hashable = caching.to_hashable({_OpaqueValue(): 1})
|
||||||
|
|
||||||
|
assert isinstance(hashable, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"container_factory",
|
||||||
|
[
|
||||||
|
set,
|
||||||
|
frozenset,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory):
|
||||||
|
"""Traversal RuntimeError should degrade unordered hash conversion to Unhashable."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
|
||||||
|
def raising_sort_key(obj, *args, **kwargs):
|
||||||
|
"""Raise a traversal RuntimeError while unordered values are canonicalized."""
|
||||||
|
raise RuntimeError("container changed during iteration")
|
||||||
|
|
||||||
|
monkeypatch.setattr(caching, "_sanitized_sort_key", raising_sort_key)
|
||||||
|
|
||||||
|
hashable = caching.to_hashable(container_factory({"value"}))
|
||||||
|
|
||||||
|
assert isinstance(hashable, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module, monkeypatch):
|
||||||
|
"""Ambiguous dict key ordering should fail closed instead of using insertion order."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
original_sort_key = caching._sanitized_sort_key
|
||||||
|
ambiguous = {"a": 1, "b": 1}
|
||||||
|
|
||||||
|
def colliding_sort_key(obj, *args, **kwargs):
|
||||||
|
"""Force two distinct primitive keys to share the same ordering key."""
|
||||||
|
if obj == "a" or obj == "b":
|
||||||
|
return ("COLLIDE",)
|
||||||
|
return original_sort_key(obj, *args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(caching, "_sanitized_sort_key", colliding_sort_key)
|
||||||
|
|
||||||
|
hashable = caching.to_hashable(ambiguous)
|
||||||
|
|
||||||
|
assert isinstance(hashable, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_signature_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module, monkeypatch):
|
||||||
|
"""Ambiguous dict sort ties should fail closed instead of depending on input order."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
original_sort_key = caching._primitive_signature_sort_key
|
||||||
|
ambiguous = {"a": 1, "b": 1}
|
||||||
|
|
||||||
|
def colliding_sort_key(obj):
|
||||||
|
"""Force two distinct primitive keys to share the same ordering key."""
|
||||||
|
if obj == "a" or obj == "b":
|
||||||
|
return ("COLLIDE",)
|
||||||
|
return original_sort_key(obj)
|
||||||
|
|
||||||
|
monkeypatch.setattr(caching, "_primitive_signature_sort_key", colliding_sort_key)
|
||||||
|
|
||||||
|
sanitized = caching._signature_to_hashable(ambiguous)
|
||||||
|
|
||||||
|
assert isinstance(sanitized, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_signature_to_hashable_fails_closed_for_opaque_dict_key(caching_module):
|
||||||
|
"""Opaque dict keys should fail closed instead of being recursively canonicalized."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
|
||||||
|
sanitized = caching._signature_to_hashable({_OpaqueValue(): 1})
|
||||||
|
|
||||||
|
assert isinstance(sanitized, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_signature_to_hashable_fails_closed_on_dict_key_sort_collisions_even_with_distinct_values(caching_module, monkeypatch):
|
||||||
|
"""Different values must not mask dict key-sort collisions during canonicalization."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
original_sort_key = caching._primitive_signature_sort_key
|
||||||
|
|
||||||
|
def colliding_sort_key(obj):
|
||||||
|
"""Force two distinct primitive keys to share the same ordering key."""
|
||||||
|
if obj == "a" or obj == "b":
|
||||||
|
return ("COLLIDE",)
|
||||||
|
return original_sort_key(obj)
|
||||||
|
|
||||||
|
monkeypatch.setattr(caching, "_primitive_signature_sort_key", colliding_sort_key)
|
||||||
|
|
||||||
|
sanitized = caching._signature_to_hashable({"a": 1, "b": 2})
|
||||||
|
|
||||||
|
assert isinstance(sanitized, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"container_factory",
|
||||||
|
[
|
||||||
|
set,
|
||||||
|
frozenset,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module, monkeypatch, container_factory):
|
||||||
|
"""Ambiguous unordered values should fail closed instead of depending on iteration order."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
original_sort_key = caching._sanitized_sort_key
|
||||||
|
container = container_factory({"a", "b"})
|
||||||
|
|
||||||
|
def colliding_sort_key(obj, *args, **kwargs):
|
||||||
|
"""Force two distinct primitive values to share the same ordering key."""
|
||||||
|
if obj == "a" or obj == "b":
|
||||||
|
return ("COLLIDE",)
|
||||||
|
return original_sort_key(obj, *args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(caching, "_sanitized_sort_key", colliding_sort_key)
|
||||||
|
|
||||||
|
hashable = caching.to_hashable(container)
|
||||||
|
|
||||||
|
assert isinstance(hashable, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_node_signature_returns_top_level_unhashable_for_tainted_signature(caching_module, monkeypatch):
|
||||||
|
"""Tainted full signatures should fail closed before `to_hashable()` runs."""
|
||||||
|
caching, nodes_module = caching_module
|
||||||
|
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
caching,
|
||||||
|
"to_hashable",
|
||||||
|
lambda *_args, **_kwargs: pytest.fail("to_hashable should not run for tainted signatures"),
|
||||||
|
)
|
||||||
|
|
||||||
|
is_changed_value = []
|
||||||
|
is_changed_value.append(is_changed_value)
|
||||||
|
|
||||||
|
dynprompt = _FakeDynPrompt(
|
||||||
|
{
|
||||||
|
"node": {
|
||||||
|
"class_type": "UnitTestNode",
|
||||||
|
"inputs": {"value": 5},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
key_set = caching.CacheKeySetInputSignature(
|
||||||
|
dynprompt,
|
||||||
|
["node"],
|
||||||
|
_FakeIsChangedCache({"node": is_changed_value}),
|
||||||
|
)
|
||||||
|
|
||||||
|
signature = asyncio.run(key_set.get_node_signature(dynprompt, "node"))
|
||||||
|
|
||||||
|
assert isinstance(signature, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shallow_is_changed_signature_accepts_primitive_lists(caching_module):
|
||||||
|
"""Primitive-only `is_changed` lists should stay hashable without deep descent."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
|
||||||
|
sanitized = caching._shallow_is_changed_signature([1, "two", None, True])
|
||||||
|
|
||||||
|
assert sanitized == ("is_changed_list", (1, "two", None, True))
|
||||||
|
|
||||||
|
|
||||||
|
def test_shallow_is_changed_signature_accepts_structured_builtin_fingerprint_lists(caching_module):
|
||||||
|
"""Structured built-in `is_changed` fingerprints should remain representable."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
|
||||||
|
sanitized = caching._shallow_is_changed_signature([("seed", 42), {"cfg": 8}])
|
||||||
|
|
||||||
|
assert sanitized == (
|
||||||
|
"is_changed_list",
|
||||||
|
(
|
||||||
|
("tuple", ("seed", 42)),
|
||||||
|
("dict", (("cfg", 8),)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shallow_is_changed_signature_fails_closed_for_opaque_payload(caching_module):
|
||||||
|
"""Opaque `is_changed` payloads should still fail closed."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
|
||||||
|
sanitized = caching._shallow_is_changed_signature([_OpaqueValue()])
|
||||||
|
|
||||||
|
assert isinstance(sanitized, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_immediate_node_signature_fails_closed_for_unhashable_is_changed(caching_module, monkeypatch):
|
||||||
|
"""Recursive `is_changed` payloads should fail the full fragment closed."""
|
||||||
|
caching, nodes_module = caching_module
|
||||||
|
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
|
||||||
|
|
||||||
|
is_changed_value = []
|
||||||
|
is_changed_value.append(is_changed_value)
|
||||||
|
dynprompt = _FakeDynPrompt(
|
||||||
|
{
|
||||||
|
"node": {
|
||||||
|
"class_type": "UnitTestNode",
|
||||||
|
"inputs": {"value": 5},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
key_set = caching.CacheKeySetInputSignature(
|
||||||
|
dynprompt,
|
||||||
|
["node"],
|
||||||
|
_FakeIsChangedCache({"node": is_changed_value}),
|
||||||
|
)
|
||||||
|
|
||||||
|
signature = asyncio.run(key_set.get_immediate_node_signature(dynprompt, "node", {}))
|
||||||
|
|
||||||
|
assert isinstance(signature, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_immediate_node_signature_fails_closed_for_missing_node(caching_module):
|
||||||
|
"""Missing nodes should return the fail-closed sentinel instead of a NaN tuple."""
|
||||||
|
caching, _ = caching_module
|
||||||
|
dynprompt = _FakeDynPrompt({})
|
||||||
|
key_set = caching.CacheKeySetInputSignature(
|
||||||
|
dynprompt,
|
||||||
|
[],
|
||||||
|
_FakeIsChangedCache({}),
|
||||||
|
)
|
||||||
|
|
||||||
|
signature = asyncio.run(key_set.get_immediate_node_signature(dynprompt, "missing", {}))
|
||||||
|
|
||||||
|
assert isinstance(signature, caching.Unhashable)
|
||||||
198
tests/execution/test_caching.py
Normal file
198
tests/execution/test_caching.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
from comfy_execution import caching
|
||||||
|
|
||||||
|
|
||||||
|
class _StubDynPrompt:
|
||||||
|
def __init__(self, nodes):
|
||||||
|
self._nodes = nodes
|
||||||
|
|
||||||
|
def has_node(self, node_id):
|
||||||
|
return node_id in self._nodes
|
||||||
|
|
||||||
|
def get_node(self, node_id):
|
||||||
|
return self._nodes[node_id]
|
||||||
|
|
||||||
|
|
||||||
|
class _StubIsChangedCache:
|
||||||
|
async def get(self, node_id):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _StubNode:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {"required": {}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_immediate_node_signature_canonicalizes_non_link_inputs(monkeypatch):
|
||||||
|
live_value = [1, {"nested": [2, 3]}]
|
||||||
|
dynprompt = _StubDynPrompt(
|
||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"class_type": "TestCacheNode",
|
||||||
|
"inputs": {"value": live_value},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
|
||||||
|
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
|
||||||
|
|
||||||
|
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
|
||||||
|
signature = asyncio.run(keyset.get_immediate_node_signature(dynprompt, "1", {}))
|
||||||
|
|
||||||
|
assert signature == (
|
||||||
|
"TestCacheNode",
|
||||||
|
None,
|
||||||
|
("value", ("list", (1, ("dict", (("nested", ("list", (2, 3))),))))),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_hashable_walks_dicts_without_rebinding_traversal_stack():
|
||||||
|
live_value = {
|
||||||
|
"outer": {"nested": [2, 3]},
|
||||||
|
"items": [{"leaf": 4}],
|
||||||
|
}
|
||||||
|
|
||||||
|
assert caching.to_hashable(live_value) == (
|
||||||
|
"dict",
|
||||||
|
(
|
||||||
|
("items", ("list", (("dict", (("leaf", 4),)),))),
|
||||||
|
("outer", ("dict", (("nested", ("list", (2, 3))),))),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_immediate_node_signature_fails_closed_for_opaque_non_link_input(monkeypatch):
|
||||||
|
class OpaqueRuntimeValue:
|
||||||
|
pass
|
||||||
|
|
||||||
|
live_value = OpaqueRuntimeValue()
|
||||||
|
dynprompt = _StubDynPrompt(
|
||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"class_type": "TestCacheNode",
|
||||||
|
"inputs": {"value": live_value},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
|
||||||
|
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
|
||||||
|
|
||||||
|
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
|
||||||
|
signature = asyncio.run(keyset.get_immediate_node_signature(dynprompt, "1", {}))
|
||||||
|
|
||||||
|
assert isinstance(signature, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_node_signature_propagates_unhashable_immediate_fragment(monkeypatch):
|
||||||
|
class OpaqueRuntimeValue:
|
||||||
|
pass
|
||||||
|
|
||||||
|
dynprompt = _StubDynPrompt(
|
||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"class_type": "TestCacheNode",
|
||||||
|
"inputs": {"value": OpaqueRuntimeValue()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
|
||||||
|
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
|
||||||
|
|
||||||
|
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
|
||||||
|
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
|
||||||
|
|
||||||
|
assert isinstance(signature, caching.Unhashable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_node_signature_never_visits_raw_non_link_input(monkeypatch):
|
||||||
|
live_value = [1, 2, 3]
|
||||||
|
dynprompt = _StubDynPrompt(
|
||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"class_type": "TestCacheNode",
|
||||||
|
"inputs": {"value": live_value},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
|
||||||
|
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
|
||||||
|
monkeypatch.setattr(
|
||||||
|
caching,
|
||||||
|
"_signature_to_hashable",
|
||||||
|
lambda *_args, **_kwargs: (_ for _ in ()).throw(
|
||||||
|
AssertionError("outer signature canonicalizer should not run")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
|
||||||
|
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
|
||||||
|
|
||||||
|
assert isinstance(signature, tuple)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_node_signature_keeps_deep_canonicalized_input_fragment(monkeypatch):
|
||||||
|
live_value = 1
|
||||||
|
for _ in range(8):
|
||||||
|
live_value = [live_value]
|
||||||
|
expected = caching.to_hashable(live_value)
|
||||||
|
|
||||||
|
dynprompt = _StubDynPrompt(
|
||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"class_type": "TestCacheNode",
|
||||||
|
"inputs": {"value": live_value},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
|
||||||
|
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
|
||||||
|
|
||||||
|
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
|
||||||
|
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
|
||||||
|
|
||||||
|
assert isinstance(signature, tuple)
|
||||||
|
assert signature[0][2][0] == "value"
|
||||||
|
assert signature[0][2][1] == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_node_signature_keeps_large_precanonicalized_fragment(monkeypatch):
|
||||||
|
live_value = object()
|
||||||
|
canonical_fragment = ("tuple", tuple(("list", (index, index + 1)) for index in range(256)))
|
||||||
|
dynprompt = _StubDynPrompt(
|
||||||
|
{
|
||||||
|
"1": {
|
||||||
|
"class_type": "TestCacheNode",
|
||||||
|
"inputs": {"value": live_value},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setitem(caching.nodes.NODE_CLASS_MAPPINGS, "TestCacheNode", _StubNode)
|
||||||
|
monkeypatch.setattr(caching, "NODE_CLASS_CONTAINS_UNIQUE_ID", {})
|
||||||
|
monkeypatch.setattr(
|
||||||
|
caching,
|
||||||
|
"to_hashable",
|
||||||
|
lambda value, max_nodes=caching._MAX_SIGNATURE_CONTAINER_VISITS: (
|
||||||
|
canonical_fragment if value is live_value else caching.Unhashable()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
caching,
|
||||||
|
"_signature_to_hashable",
|
||||||
|
lambda *_args, **_kwargs: (_ for _ in ()).throw(
|
||||||
|
AssertionError("outer signature canonicalizer should not run")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
keyset = caching.CacheKeySetInputSignature(dynprompt, [], _StubIsChangedCache())
|
||||||
|
signature = asyncio.run(keyset.get_node_signature(dynprompt, "1"))
|
||||||
|
|
||||||
|
assert isinstance(signature, tuple)
|
||||||
|
assert signature[0][2] == ("value", canonical_fragment)
|
||||||
Loading…
Reference in New Issue
Block a user