diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 78212bde3..cbf2e9de1 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,7 +1,6 @@ import asyncio import bisect import gc -import itertools import psutil import time import torch @@ -17,6 +16,7 @@ NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {} def include_unique_id_in_input(class_type: str) -> bool: + """Return whether a node class includes UNIQUE_ID among its hidden inputs.""" if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID: return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -24,52 +24,181 @@ def include_unique_id_in_input(class_type: str) -> bool: return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] class CacheKeySet(ABC): + """Base helper for building and storing cache keys for prompt nodes.""" def __init__(self, dynprompt, node_ids, is_changed_cache): + """Initialize cache-key storage for a dynamic prompt execution pass.""" self.keys = {} self.subcache_keys = {} @abstractmethod async def add_keys(self, node_ids): + """Populate cache keys for the provided node ids.""" raise NotImplementedError() def all_node_ids(self): + """Return the set of node ids currently tracked by this key set.""" return set(self.keys.keys()) def get_used_keys(self): + """Return the computed cache keys currently in use.""" return self.keys.values() def get_used_subcache_keys(self): + """Return the computed subcache keys currently in use.""" return self.subcache_keys.values() def get_data_key(self, node_id): + """Return the cache key for a node, if present.""" return self.keys.get(node_id, None) def get_subcache_key(self, node_id): + """Return the subcache key for a node, if present.""" return self.subcache_keys.get(node_id, None) class Unhashable: - def __init__(self): - self.value = float("NaN") + """Hashable identity sentinel for values that cannot be represented safely in cache keys.""" + pass -def to_hashable(obj): - # So that we don't infinitely recurse since frozenset and tuples - # are Sequences. - if isinstance(obj, (int, float, str, bool, bytes, type(None))): - return obj - elif isinstance(obj, Mapping): - return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())]) - elif isinstance(obj, Sequence): - return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj])) + +def _sanitized_sort_key(obj, depth=0, max_depth=32): + """Return a deterministic ordering key for sanitized built-in container content.""" + if depth >= max_depth: + return ("MAX_DEPTH",) + + obj_type = type(obj) + if obj_type is Unhashable: + return ("UNHASHABLE",) + elif obj_type in (int, float, str, bool, bytes, type(None)): + return (obj_type.__module__, obj_type.__qualname__, repr(obj)) + elif obj_type is dict: + items = [ + ( + _sanitized_sort_key(k, depth + 1, max_depth), + _sanitized_sort_key(v, depth + 1, max_depth), + ) + for k, v in obj.items() + ] + items.sort() + return ("dict", tuple(items)) + elif obj_type is list: + return ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)) + elif obj_type is tuple: + return ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)) + elif obj_type is set: + return ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))) + elif obj_type is frozenset: + return ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))) + else: + return (obj_type.__module__, obj_type.__qualname__, "OPAQUE") + + +def _sanitize_signature_input(obj, depth=0, max_depth=32, seen=None): + """Normalize signature inputs to safe built-in containers. + + Preserves built-in container type, replaces opaque runtime values with + Unhashable(), and stops safely on cycles or excessive depth. + """ + if depth >= max_depth: + return Unhashable() + + if seen is None: + seen = set() + + obj_type = type(obj) + if obj_type in (dict, list, tuple, set, frozenset): + obj_id = id(obj) + if obj_id in seen: + return Unhashable() + next_seen = seen | {obj_id} + + if obj_type in (int, float, str, bool, bytes, type(None)): + return obj + elif obj_type is dict: + sanitized_items = [ + ( + _sanitize_signature_input(key, depth + 1, max_depth, next_seen), + _sanitize_signature_input(value, depth + 1, max_depth, next_seen), + ) + for key, value in obj.items() + ] + sanitized_items.sort( + key=lambda kv: ( + _sanitized_sort_key(kv[0], depth + 1, max_depth), + _sanitized_sort_key(kv[1], depth + 1, max_depth), + ) + ) + return {key: value for key, value in sanitized_items} + elif obj_type is list: + return [_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj] + elif obj_type is tuple: + return tuple(_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj) + elif obj_type is set: + return {_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj} + elif obj_type is frozenset: + return frozenset(_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj) + else: + # Execution-cache signatures should be built from prompt-safe values. + # If a custom node injects a runtime object here, mark it unhashable so + # the node won't reuse stale cache entries across runs, but do not walk + # the foreign object and risk crashing on custom container semantics. + return Unhashable() + +def to_hashable(obj, depth=0, max_depth=32, seen=None): + """Convert sanitized prompt inputs into a stable hashable representation. + + Preserves built-in container type and stops safely on cycles or excessive depth. + """ + if depth >= max_depth: + return Unhashable() + + if seen is None: + seen = set() + + # Restrict recursion to plain built-in containers. Some custom nodes insert + # runtime objects into prompt inputs for dynamic graph paths; walking those + # objects as generic Mappings / Sequences is unsafe and can destabilize the + # cache signature builder. + obj_type = type(obj) + if obj_type in (int, float, str, bool, bytes, type(None)): + return obj + + if obj_type in (dict, list, tuple, set, frozenset): + obj_id = id(obj) + if obj_id in seen: + return Unhashable() + seen = seen | {obj_id} + + if obj_type is dict: + return ( + "dict", + frozenset( + ( + to_hashable(k, depth + 1, max_depth, seen), + to_hashable(v, depth + 1, max_depth, seen), + ) + for k, v in obj.items() + ), + ) + elif obj_type is list: + return ("list", tuple(to_hashable(i, depth + 1, max_depth, seen) for i in obj)) + elif obj_type is tuple: + return ("tuple", tuple(to_hashable(i, depth + 1, max_depth, seen) for i in obj)) + elif obj_type is set: + return ("set", frozenset(to_hashable(i, depth + 1, max_depth, seen) for i in obj)) + elif obj_type is frozenset: + return ("frozenset", frozenset(to_hashable(i, depth + 1, max_depth, seen) for i in obj)) else: - # TODO - Support other objects like tensors? return Unhashable() class CacheKeySetID(CacheKeySet): + """Cache-key strategy that keys nodes by node id and class type.""" def __init__(self, dynprompt, node_ids, is_changed_cache): + """Initialize identity-based cache keys for the supplied dynamic prompt.""" super().__init__(dynprompt, node_ids, is_changed_cache) self.dynprompt = dynprompt async def add_keys(self, node_ids): + """Populate identity-based keys for nodes that exist in the dynamic prompt.""" for node_id in node_ids: if node_id in self.keys: continue @@ -80,15 +209,19 @@ class CacheKeySetID(CacheKeySet): self.subcache_keys[node_id] = (node_id, node["class_type"]) class CacheKeySetInputSignature(CacheKeySet): + """Cache-key strategy that hashes a node's immediate inputs plus ancestor references.""" def __init__(self, dynprompt, node_ids, is_changed_cache): + """Initialize input-signature-based cache keys for the supplied dynamic prompt.""" super().__init__(dynprompt, node_ids, is_changed_cache) self.dynprompt = dynprompt self.is_changed_cache = is_changed_cache def include_node_id_in_input(self) -> bool: + """Return whether node ids should be included in computed input signatures.""" return False async def add_keys(self, node_ids): + """Populate input-signature-based keys for nodes in the dynamic prompt.""" for node_id in node_ids: if node_id in self.keys: continue @@ -99,6 +232,7 @@ class CacheKeySetInputSignature(CacheKeySet): self.subcache_keys[node_id] = (node_id, node["class_type"]) async def get_node_signature(self, dynprompt, node_id): + """Build the full cache signature for a node and its ordered ancestors.""" signature = [] ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) @@ -107,6 +241,10 @@ class CacheKeySetInputSignature(CacheKeySet): return to_hashable(signature) async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + """Build the cache-signature fragment for a node's immediate inputs. + + Link inputs are reduced to ancestor references, while raw values are sanitized first. + """ if not dynprompt.has_node(node_id): # This node doesn't exist -- we can't cache it. return [float("NaN")] @@ -123,18 +261,20 @@ class CacheKeySetInputSignature(CacheKeySet): ancestor_index = ancestor_order_mapping[ancestor_id] signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) else: - signature.append((key, inputs[key])) + signature.append((key, _sanitize_signature_input(inputs[key]))) return signature # This function returns a list of all ancestors of the given node. The order of the list is # deterministic based on which specific inputs the ancestor is connected by. def get_ordered_ancestry(self, dynprompt, node_id): + """Return ancestors in deterministic traversal order and their index mapping.""" ancestors = [] order_mapping = {} self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping) return ancestors, order_mapping def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping): + """Recursively collect ancestors in input order without revisiting prior nodes.""" if not dynprompt.has_node(node_id): return inputs = dynprompt.get_node(node_id)["inputs"] diff --git a/comfy_execution/graph_utils.py b/comfy_execution/graph_utils.py index 496d2c634..57b4ef36e 100644 --- a/comfy_execution/graph_utils.py +++ b/comfy_execution/graph_utils.py @@ -1,11 +1,17 @@ def is_link(obj): - if not isinstance(obj, list): + """Return whether obj is a plain prompt link of the form [node_id, output_index].""" + # Prompt links produced by the frontend / GraphBuilder are plain Python + # lists in the form [node_id, output_index]. Some custom-node paths can + # inject foreign runtime objects into prompt inputs during on-prompt graph + # rewriting or subgraph construction. Be strict here so cache signature + # building never tries to treat list-like proxy objects as links. + if type(obj) is not list: return False if len(obj) != 2: return False - if not isinstance(obj[0], str): + if type(obj[0]) is not str: return False - if not isinstance(obj[1], int) and not isinstance(obj[1], float): + if type(obj[1]) is not int: return False return True