diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index b0591c08e..acba8e90a 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -101,8 +101,6 @@ class CacheKeySetInputSignature(CacheKeySet): self.dynprompt = dynprompt self.is_changed = is_changed self.updated_node_ids = set() - self.node_sig_cache = {} - self.ancestry_cache = {} def include_node_id_in_input(self) -> bool: return False @@ -131,11 +129,13 @@ class CacheKeySetInputSignature(CacheKeySet): async def get_node_signature(self, node_id): signatures = [] ancestors, order_mapping, node_inputs = self.get_ordered_ancestry(node_id) - self.node_sig_cache[node_id] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping, node_inputs)) - signatures.append(self.node_sig_cache[node_id]) + node = self.dynprompt.get_node(node_id) + node["signature"] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping, node_inputs)) + signatures.append(node["signature"]) for ancestor_id in ancestors: - assert ancestor_id in self.node_sig_cache - signatures.append(self.node_sig_cache[ancestor_id]) + ancestor_node = self.dynprompt.get_node(ancestor_id) + assert "signature" in ancestor_node + signatures.append(ancestor_node["signature"]) signatures = frozenset(zip(itertools.count(), signatures)) return signatures @@ -163,7 +163,8 @@ class CacheKeySetInputSignature(CacheKeySet): for ancestor_id in ancestors: if ancestor_id not in ret: ret.append(ancestor_id) - get_ancestors(self.ancestry_cache[ancestor_id], ret) + ancestor_node = self.dynprompt.get_node(ancestor_id) + get_ancestors(ancestor_node["ancestors"], ret) return ret ancestors, node_inputs = self.get_ordered_ancestry_internal(node_id) @@ -185,12 +186,13 @@ class CacheKeySetInputSignature(CacheKeySet): ancestors = [] node_inputs = {} - if node_id in self.ancestry_cache: - return self.ancestry_cache[node_id], node_inputs - if not self.dynprompt.has_node(node_id): return ancestors, node_inputs + node = self.dynprompt.get_node(node_id) + if "ancestors" in node: + return node["ancestors"], node_inputs + input_data_all, _, _ = self.is_changed.get_input_data(node_id) inputs = self.dynprompt.get_node(node_id)["inputs"] for key in sorted(inputs.keys()): @@ -213,8 +215,8 @@ class CacheKeySetInputSignature(CacheKeySet): else: node_inputs[key] = [inputs[key]] - self.ancestry_cache[node_id] = ancestors - return self.ancestry_cache[node_id], node_inputs + node["ancestors"] = ancestors + return node["ancestors"], node_inputs class BasicCache: def __init__(self, key_class): @@ -224,14 +226,10 @@ class BasicCache: self.cache_key_set: CacheKeySet self.cache = {} self.subcaches = {} - self.node_sig_cache = {} - self.ancestry_cache = {} async def set_prompt(self, dynprompt, node_ids, is_changed): self.dynprompt = dynprompt self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed) - self.cache_key_set.node_sig_cache = self.node_sig_cache - self.cache_key_set.ancestry_cache = self.ancestry_cache await self.cache_key_set.add_keys(node_ids) self.is_changed = is_changed self.initialized = True @@ -266,8 +264,6 @@ class BasicCache: def clean_unused(self): assert self.initialized - self.node_sig_cache.clear() - self.ancestry_cache.clear() self._clean_cache() self._clean_subcaches() @@ -300,8 +296,6 @@ class BasicCache: subcache = self.subcaches.get(subcache_key, None) if subcache is None: subcache = BasicCache(self.key_class) - subcache.node_sig_cache = self.node_sig_cache - subcache.ancestry_cache = self.ancestry_cache self.subcaches[subcache_key] = subcache await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed) return subcache @@ -415,8 +409,6 @@ class LRUCache(BasicCache): self._mark_used(node_id) def clean_unused(self): - self.node_sig_cache.clear() - self.ancestry_cache.clear() while len(self.cache) > self.max_size and self.min_generation < self.generation: self.min_generation += 1 to_remove = [key for key in self.cache if key not in self.used_generation or self.used_generation[key] < self.min_generation] @@ -484,8 +476,6 @@ class RAMPressureCache(LRUCache): self.timestamps = {} def clean_unused(self): - self.node_sig_cache.clear() - self.ancestry_cache.clear() self._clean_subcaches() def set(self, node_id, value):