Update caching.py

This commit is contained in:
Deluxe233 2026-01-26 09:33:00 -05:00
parent 38ab4e3c76
commit 4683136740

View File

@ -123,9 +123,7 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet):
if node_id not in self.keys: if node_id not in self.keys:
return return
self.updated_node_ids.add(node_id) self.updated_node_ids.add(node_id)
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = await self.get_node_signature(node_id) self.keys[node_id] = await self.get_node_signature(node_id)
self.subcache_keys[node_id] = (node_id, node["class_type"])
def is_key_updated(self, node_id): def is_key_updated(self, node_id):
return node_id in self.updated_node_ids return node_id in self.updated_node_ids
@ -149,16 +147,17 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet):
async def get_node_signature(self, node_id): async def get_node_signature(self, node_id):
signatures = [] signatures = []
ancestors, order_mapping, input_hashes = self.get_ordered_ancestry(node_id) ancestors, order_mapping, node_inputs = self.get_ordered_ancestry(node_id)
self.node_sig_cache[node_id] = await self.get_immediate_node_signature(node_id, order_mapping, input_hashes) self.node_sig_cache[node_id] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping, node_inputs))
signatures.append(self.node_sig_cache[node_id]) signatures.append(self.node_sig_cache[node_id])
for ancestor_id in ancestors: for ancestor_id in ancestors:
assert ancestor_id in self.node_sig_cache assert ancestor_id in self.node_sig_cache
signatures.append(self.node_sig_cache[ancestor_id]) signatures.append(self.node_sig_cache[ancestor_id])
signatures = frozenset(zip(itertools.count(), signatures))
logging.debug(f"signature for node {node_id}: {signatures}") logging.debug(f"signature for node {node_id}: {signatures}")
return to_hashable(signatures) return signatures
async def get_immediate_node_signature(self, node_id, ancestor_order_mapping: dict, inputs: dict): async def get_immediate_node_signature(self, node_id, ancestor_order_mapping: dict, inputs: dict):
if not self.dynprompt.has_node(node_id): if not self.dynprompt.has_node(node_id):
@ -192,14 +191,14 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet):
get_ancestors(self.ancestry_cache[ancestor_id], ret) get_ancestors(self.ancestry_cache[ancestor_id], ret)
return ret return ret
ancestors, input_hashes = self.get_ordered_ancestry_internal(node_id) ancestors, node_inputs = self.get_ordered_ancestry_internal(node_id)
ancestors = get_ancestors(ancestors) ancestors = get_ancestors(ancestors)
order_mapping = {} order_mapping = {}
for i, ancestor_id in enumerate(ancestors): for i, ancestor_id in enumerate(ancestors):
order_mapping[ancestor_id] = i order_mapping[ancestor_id] = i
return ancestors, order_mapping, input_hashes return ancestors, order_mapping, node_inputs
def get_ordered_ancestry_internal(self, node_id): def get_ordered_ancestry_internal(self, node_id):
def get_hashable(obj): def get_hashable(obj):
@ -207,14 +206,15 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet):
return throw_on_unhashable(obj) return throw_on_unhashable(obj)
except: except:
return Unhashable return Unhashable
ancestors = [] ancestors = []
input_hashes = {} node_inputs = {}
if node_id in self.ancestry_cache: if node_id in self.ancestry_cache:
return self.ancestry_cache[node_id], input_hashes return self.ancestry_cache[node_id], node_inputs
if not self.dynprompt.has_node(node_id): if not self.dynprompt.has_node(node_id):
return return ancestors, node_inputs
input_data_all, _, _ = self.is_changed.get_input_data(node_id) input_data_all, _, _ = self.is_changed.get_input_data(node_id)
inputs = self.dynprompt.get_node(node_id)["inputs"] inputs = self.dynprompt.get_node(node_id)["inputs"]
@ -225,22 +225,22 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet):
hashable = get_hashable(input_data_all[key]) hashable = get_hashable(input_data_all[key])
if hashable is Unhashable or is_link(input_data_all[key]): if hashable is Unhashable or is_link(input_data_all[key]):
# Link still needed # Link still needed
input_hashes[key] = inputs[key] node_inputs[key] = inputs[key]
if ancestor_id not in ancestors: if ancestor_id not in ancestors:
ancestors.append(ancestor_id) ancestors.append(ancestor_id)
else: else:
# Replace link with input's hash # Replace link
input_hashes[key] = hash(hashable) node_inputs[key] = input_data_all[key]
else: else:
hashable = get_hashable(inputs[key]) hashable = get_hashable(inputs[key])
if hashable is Unhashable: if hashable is Unhashable:
logging.warning(f"Node {node_id} cannot be cached due to whatever this thing is: {inputs[key]}") logging.warning(f"Node {node_id} cannot be cached due to whatever this thing is: {inputs[key]}")
input_hashes[key] = Unhashable() node_inputs[key] = Unhashable()
else: else:
input_hashes[key] = hash(hashable) node_inputs[key] = inputs[key]
self.ancestry_cache[node_id] = ancestors self.ancestry_cache[node_id] = ancestors
return self.ancestry_cache[node_id], input_hashes return self.ancestry_cache[node_id], node_inputs
class BasicCache: class BasicCache:
def __init__(self, key_class): def __init__(self, key_class):