diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index a3f3ac338..751c76e9f 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -123,9 +123,7 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet): if node_id not in self.keys: return self.updated_node_ids.add(node_id) - node = self.dynprompt.get_node(node_id) self.keys[node_id] = await self.get_node_signature(node_id) - self.subcache_keys[node_id] = (node_id, node["class_type"]) def is_key_updated(self, node_id): return node_id in self.updated_node_ids @@ -149,16 +147,17 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet): async def get_node_signature(self, node_id): signatures = [] - ancestors, order_mapping, input_hashes = self.get_ordered_ancestry(node_id) - self.node_sig_cache[node_id] = await self.get_immediate_node_signature(node_id, order_mapping, input_hashes) + ancestors, order_mapping, node_inputs = self.get_ordered_ancestry(node_id) + self.node_sig_cache[node_id] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping, node_inputs)) signatures.append(self.node_sig_cache[node_id]) for ancestor_id in ancestors: assert ancestor_id in self.node_sig_cache signatures.append(self.node_sig_cache[ancestor_id]) - + + signatures = frozenset(zip(itertools.count(), signatures)) logging.debug(f"signature for node {node_id}: {signatures}") - return to_hashable(signatures) + return signatures async def get_immediate_node_signature(self, node_id, ancestor_order_mapping: dict, inputs: dict): if not self.dynprompt.has_node(node_id): @@ -192,14 +191,14 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet): get_ancestors(self.ancestry_cache[ancestor_id], ret) return ret - ancestors, input_hashes = self.get_ordered_ancestry_internal(node_id) + ancestors, node_inputs = self.get_ordered_ancestry_internal(node_id) ancestors = get_ancestors(ancestors) order_mapping = {} for i, ancestor_id in enumerate(ancestors): order_mapping[ancestor_id] = i - return ancestors, order_mapping, input_hashes + return ancestors, order_mapping, node_inputs def get_ordered_ancestry_internal(self, node_id): def get_hashable(obj): @@ -207,14 +206,15 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet): return throw_on_unhashable(obj) except: return Unhashable + ancestors = [] - input_hashes = {} + node_inputs = {} if node_id in self.ancestry_cache: - return self.ancestry_cache[node_id], input_hashes + return self.ancestry_cache[node_id], node_inputs if not self.dynprompt.has_node(node_id): - return + return ancestors, node_inputs input_data_all, _, _ = self.is_changed.get_input_data(node_id) inputs = self.dynprompt.get_node(node_id)["inputs"] @@ -225,22 +225,22 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet): hashable = get_hashable(input_data_all[key]) if hashable is Unhashable or is_link(input_data_all[key]): # Link still needed - input_hashes[key] = inputs[key] + node_inputs[key] = inputs[key] if ancestor_id not in ancestors: ancestors.append(ancestor_id) else: - # Replace link with input's hash - input_hashes[key] = hash(hashable) + # Replace link + node_inputs[key] = input_data_all[key] else: hashable = get_hashable(inputs[key]) if hashable is Unhashable: logging.warning(f"Node {node_id} cannot be cached due to whatever this thing is: {inputs[key]}") - input_hashes[key] = Unhashable() + node_inputs[key] = Unhashable() else: - input_hashes[key] = hash(hashable) + node_inputs[key] = inputs[key] self.ancestry_cache[node_id] = ancestors - return self.ancestry_cache[node_id], input_hashes + return self.ancestry_cache[node_id], node_inputs class BasicCache: def __init__(self, key_class):