From f511703343072ad1a2b4029c267b36f7f4addec7 Mon Sep 17 00:00:00 2001 From: Deluxe233 Date: Tue, 27 Jan 2026 13:51:41 -0500 Subject: [PATCH] Included original cache key set for testing --- comfy_execution/caching.py | 105 +++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 8fcd19b00..375d6a36b 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -97,6 +97,111 @@ class CacheKeySetID(CacheKeySet): self.keys[node_id] = (node_id, node["class_type"]) self.subcache_keys[node_id] = (node_id, node["class_type"]) +class CacheKeySetInputSignatureOriginalConstant(CacheKeySet): + """Original CacheKeySet""" + def __init__(self, dynprompt, node_ids, is_changed): + super().__init__(dynprompt, node_ids, is_changed) + self.dynprompt = dynprompt + self.is_changed = is_changed + self.clean_when = "before" + self.node_sig_cache = {} + + def include_node_id_in_input(self) -> bool: + return False + + async def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + if not self.dynprompt.has_node(node_id): + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = await self.get_node_signature(node_id) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + + async def get_node_signature(self, node_id): + signatures = [] + ancestors, order_mapping = self.get_ordered_ancestry(node_id) + if node_id not in self.node_sig_cache: + self.node_sig_cache[node_id] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping)) + signatures.append(self.node_sig_cache[node_id]) + for ancestor_id in ancestors: + if ancestor_id not in self.node_sig_cache: + self.node_sig_cache[ancestor_id] = to_hashable(await self.get_immediate_node_signature(ancestor_id, order_mapping)) + signatures.append(self.node_sig_cache[ancestor_id]) + signatures = frozenset(zip(itertools.count(), signatures)) + logging.debug(f"signature for node {node_id}: {signatures}") + return signatures + + async def get_immediate_node_signature(self, node_id, ancestor_order_mapping): + if not self.dynprompt.has_node(node_id): + # This node doesn't exist -- we can't cache it. + return [float("NaN")] + node = self.dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + signature = [class_type, await self.is_changed.get(node_id)] + inputs = node["inputs"] + for key in sorted(inputs.keys()): + if is_link(inputs[key]): + (ancestor_id, ancestor_socket) = inputs[key] + ancestor_index = ancestor_order_mapping[ancestor_id] + signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) + else: + signature.append((key, inputs[key])) + if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): + signature.append(node_id) + return signature + + # This function returns a list of all ancestors of the given node. The order of the list is + # deterministic based on which specific inputs the ancestor is connected by. + def get_ordered_ancestry(self, node_id): + ancestors = [] + order_mapping = {} + self.get_ordered_ancestry_internal(node_id, ancestors, order_mapping) + return ancestors, order_mapping + + def get_ordered_ancestry_internal(self, node_id, ancestors, order_mapping): + if not self.dynprompt.has_node(node_id): + return + inputs = self.dynprompt.get_node(node_id)["inputs"] + input_keys = sorted(inputs.keys()) + for key in input_keys: + if is_link(inputs[key]): + ancestor_id = inputs[key][0] + if ancestor_id not in order_mapping: + ancestors.append(ancestor_id) + order_mapping[ancestor_id] = len(ancestors) - 1 + self.get_ordered_ancestry_internal(ancestor_id, ancestors, order_mapping) + +class CacheKeySetInputSignatureOriginalUpdatable(CacheKeySetInputSignatureOriginalConstant): + """Original constant CacheKeySet modified to be updatable.""" + def __init__(self, dynprompt, node_ids, is_changed): + super().__init__(dynprompt, node_ids, is_changed) + self.clean_when = "after" + self.updated_node_ids = set() + + async def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + if not self.dynprompt.has_node(node_id): + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = None + self.subcache_keys[node_id] = (node_id, node["class_type"]) + + async def update_cache_key(self, node_id): + if node_id in self.updated_node_ids: + return + if node_id not in self.keys: + return + self.updated_node_ids.add(node_id) + self.keys[node_id] = await self.get_node_signature(node_id) + + def is_key_updated(self, node_id): + return node_id in self.updated_node_ids + class CacheKeySetUpdatableInputSignature(CacheKeySet): def __init__(self, dynprompt, node_ids, is_changed): super().__init__(dynprompt, node_ids, is_changed)