mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-03 10:10:20 +08:00
Revert "Included original cache key set for testing"
This reverts commit f511703343.
This commit is contained in:
parent
f511703343
commit
af4d691d1f
@ -97,111 +97,6 @@ class CacheKeySetID(CacheKeySet):
|
||||
self.keys[node_id] = (node_id, node["class_type"])
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
class CacheKeySetInputSignatureOriginalConstant(CacheKeySet):
|
||||
"""Original CacheKeySet"""
|
||||
def __init__(self, dynprompt, node_ids, is_changed):
|
||||
super().__init__(dynprompt, node_ids, is_changed)
|
||||
self.dynprompt = dynprompt
|
||||
self.is_changed = is_changed
|
||||
self.clean_when = "before"
|
||||
self.node_sig_cache = {}
|
||||
|
||||
def include_node_id_in_input(self) -> bool:
|
||||
return False
|
||||
|
||||
async def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = await self.get_node_signature(node_id)
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
async def get_node_signature(self, node_id):
|
||||
signatures = []
|
||||
ancestors, order_mapping = self.get_ordered_ancestry(node_id)
|
||||
if node_id not in self.node_sig_cache:
|
||||
self.node_sig_cache[node_id] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping))
|
||||
signatures.append(self.node_sig_cache[node_id])
|
||||
for ancestor_id in ancestors:
|
||||
if ancestor_id not in self.node_sig_cache:
|
||||
self.node_sig_cache[ancestor_id] = to_hashable(await self.get_immediate_node_signature(ancestor_id, order_mapping))
|
||||
signatures.append(self.node_sig_cache[ancestor_id])
|
||||
signatures = frozenset(zip(itertools.count(), signatures))
|
||||
logging.debug(f"signature for node {node_id}: {signatures}")
|
||||
return signatures
|
||||
|
||||
async def get_immediate_node_signature(self, node_id, ancestor_order_mapping):
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
# This node doesn't exist -- we can't cache it.
|
||||
return [float("NaN")]
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
signature = [class_type, await self.is_changed.get(node_id)]
|
||||
inputs = node["inputs"]
|
||||
for key in sorted(inputs.keys()):
|
||||
if is_link(inputs[key]):
|
||||
(ancestor_id, ancestor_socket) = inputs[key]
|
||||
ancestor_index = ancestor_order_mapping[ancestor_id]
|
||||
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
|
||||
else:
|
||||
signature.append((key, inputs[key]))
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
|
||||
signature.append(node_id)
|
||||
return signature
|
||||
|
||||
# This function returns a list of all ancestors of the given node. The order of the list is
|
||||
# deterministic based on which specific inputs the ancestor is connected by.
|
||||
def get_ordered_ancestry(self, node_id):
|
||||
ancestors = []
|
||||
order_mapping = {}
|
||||
self.get_ordered_ancestry_internal(node_id, ancestors, order_mapping)
|
||||
return ancestors, order_mapping
|
||||
|
||||
def get_ordered_ancestry_internal(self, node_id, ancestors, order_mapping):
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
return
|
||||
inputs = self.dynprompt.get_node(node_id)["inputs"]
|
||||
input_keys = sorted(inputs.keys())
|
||||
for key in input_keys:
|
||||
if is_link(inputs[key]):
|
||||
ancestor_id = inputs[key][0]
|
||||
if ancestor_id not in order_mapping:
|
||||
ancestors.append(ancestor_id)
|
||||
order_mapping[ancestor_id] = len(ancestors) - 1
|
||||
self.get_ordered_ancestry_internal(ancestor_id, ancestors, order_mapping)
|
||||
|
||||
class CacheKeySetInputSignatureOriginalUpdatable(CacheKeySetInputSignatureOriginalConstant):
|
||||
"""Original constant CacheKeySet modified to be updatable."""
|
||||
def __init__(self, dynprompt, node_ids, is_changed):
|
||||
super().__init__(dynprompt, node_ids, is_changed)
|
||||
self.clean_when = "after"
|
||||
self.updated_node_ids = set()
|
||||
|
||||
async def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = None
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
async def update_cache_key(self, node_id):
|
||||
if node_id in self.updated_node_ids:
|
||||
return
|
||||
if node_id not in self.keys:
|
||||
return
|
||||
self.updated_node_ids.add(node_id)
|
||||
self.keys[node_id] = await self.get_node_signature(node_id)
|
||||
|
||||
def is_key_updated(self, node_id):
|
||||
return node_id in self.updated_node_ids
|
||||
|
||||
class CacheKeySetUpdatableInputSignature(CacheKeySet):
|
||||
def __init__(self, dynprompt, node_ids, is_changed):
|
||||
super().__init__(dynprompt, node_ids, is_changed)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user