Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-08 20:42:32 +08:00)
Added a new type of cache key set.

parent 7ee77ff038
commit 232995856e
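The diff below introduces CacheKeySetUpdatableInputSignature: cache keys are registered as placeholders up front and only computed once a node's inputs, including lazily evaluated ones, can be resolved, so cached outputs can become part of the input signature. A minimal sketch of that pattern, using simplified stand-ins rather than the real ComfyUI classes:

# Sketch of the "updatable" key idea with hypothetical stand-ins (not the
# real ComfyUI classes): keys are registered as None up front and only
# computed once a node's inputs can actually be resolved.
class UpdatableKeySet:
    def __init__(self):
        self.keys = {}           # node_id -> signature or None
        self.updated = set()     # node_ids whose key has been computed

    def add_keys(self, node_ids):
        for node_id in node_ids:
            self.keys.setdefault(node_id, None)  # defer computation

    def update_cache_key(self, node_id, signature):
        # Called once the node's inputs (including lazy ones) are known.
        if node_id not in self.updated:
            self.keys[node_id] = signature
            self.updated.add(node_id)

    def is_key_updated(self, node_id):
        return node_id in self.updated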
comfy_execution/caching.py
@@ -3,6 +3,7 @@ import gc
 import itertools
 import psutil
 import time
+import logging
 import torch
 from typing import Sequence, Mapping, Dict
 from comfy_execution.graph import DynamicPrompt
@@ -14,7 +15,6 @@ from comfy_execution.graph_utils import is_link

 NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}

-
 def include_unique_id_in_input(class_type: str) -> bool:
     if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
         return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
@@ -23,9 +23,10 @@ def include_unique_id_in_input(class_type: str) -> bool:
         return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]

 class CacheKeySet(ABC):
-    def __init__(self, dynprompt, node_ids, is_changed_cache):
+    def __init__(self, dynprompt, node_ids, is_changed):
         self.keys = {}
         self.subcache_keys = {}
+        self.clean_when = None

     @abstractmethod
     async def add_keys(self, node_ids):
@@ -46,6 +47,15 @@ class CacheKeySet(ABC):
     def get_subcache_key(self, node_id):
         return self.subcache_keys.get(node_id, None)

+    async def update_cache_key(self, node_id) -> None:
+        pass
+
+    def is_key_updated(self, node_id) -> bool:
+        return True
+
+    def is_key_updatable(self, node_id) -> bool:
+        return False
+
 class Unhashable:
     def __init__(self):
         self.value = float("NaN")
@@ -63,10 +73,22 @@ def to_hashable(obj):
         # TODO - Support other objects like tensors?
         return Unhashable()

+def throw_on_unhashable(obj):
+    # Same as to_hashable except throwing for unhashables instead.
+    if isinstance(obj, (int, float, str, bool, bytes, type(None))):
+        return obj
+    elif isinstance(obj, Mapping):
+        return frozenset([(throw_on_unhashable(k), throw_on_unhashable(v)) for k, v in sorted(obj.items())])
+    elif isinstance(obj, Sequence):
+        return frozenset(zip(itertools.count(), [throw_on_unhashable(i) for i in obj]))
+    else:
+        raise Exception("Object unhashable.")
+
 class CacheKeySetID(CacheKeySet):
-    def __init__(self, dynprompt, node_ids, is_changed_cache):
-        super().__init__(dynprompt, node_ids, is_changed_cache)
+    def __init__(self, dynprompt, node_ids, is_changed):
+        super().__init__(dynprompt, node_ids, is_changed)
         self.dynprompt = dynprompt
+        self.clean_when = "before"

     async def add_keys(self, node_ids):
         for node_id in node_ids:
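Unlike to_hashable, which degrades unhashable values to an Unhashable() sentinel, throw_on_unhashable raises so callers can fall back to link-based keys (see the ancestry hunk further down). A standalone check of that behavior; the function body is copied from the hunk above:

import itertools
from typing import Mapping, Sequence

def throw_on_unhashable(obj):
    # Copied from the hunk above so this snippet runs on its own.
    if isinstance(obj, (int, float, str, bool, bytes, type(None))):
        return obj
    elif isinstance(obj, Mapping):
        return frozenset([(throw_on_unhashable(k), throw_on_unhashable(v)) for k, v in sorted(obj.items())])
    elif isinstance(obj, Sequence):
        return frozenset(zip(itertools.count(), [throw_on_unhashable(i) for i in obj]))
    else:
        raise Exception("Object unhashable.")

print(hash(throw_on_unhashable({"a": [1, 2], "b": "x"})))  # stable int
try:
    throw_on_unhashable({"t": object()})  # not a plain value -> raises
except Exception as e:
    print(e)  # "Object unhashable."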
@@ -78,73 +100,143 @@ class CacheKeySetID(CacheKeySet):
             self.keys[node_id] = (node_id, node["class_type"])
             self.subcache_keys[node_id] = (node_id, node["class_type"])

-class CacheKeySetInputSignature(CacheKeySet):
-    def __init__(self, dynprompt, node_ids, is_changed_cache):
-        super().__init__(dynprompt, node_ids, is_changed_cache)
-        self.dynprompt = dynprompt
-        self.is_changed_cache = is_changed_cache
+class CacheKeySetUpdatableInputSignature(CacheKeySet):
+    def __init__(self, dynprompt, node_ids, is_changed):
+        super().__init__(dynprompt, node_ids, is_changed)
+        self.dynprompt: DynamicPrompt = dynprompt
+        self.is_changed = is_changed
+        self.clean_when = "after"
+
+        self.updated_node_ids = set()
+        self.node_sig_cache = {}
+        """Nodes' immediate node signatures."""
+        self.ancestry_cache = {}
+        """List of a node's ancestors."""

     def include_node_id_in_input(self) -> bool:
         return False

+    async def update_cache_key(self, node_id):
+        """Update key using cached outputs as part of the input signature."""
+        if node_id in self.updated_node_ids:
+            return
+        if node_id not in self.keys:
+            return
+        self.updated_node_ids.add(node_id)
+        node = self.dynprompt.get_node(node_id)
+        self.keys[node_id] = await self.get_node_signature(node_id)
+        self.subcache_keys[node_id] = (node_id, node["class_type"])
+
+    def is_key_updated(self, node_id):
+        return node_id in self.updated_node_ids
+
+    def is_key_updatable(self, node_id):
+        _, missing_keys, _ = self.is_changed.get_input_data(node_id)
+        if missing_keys:
+            return False
+        return True
+
     async def add_keys(self, node_ids):
+        """Initialize keys."""
         for node_id in node_ids:
             if node_id in self.keys:
                 continue
             if not self.dynprompt.has_node(node_id):
                 continue
             node = self.dynprompt.get_node(node_id)
-            self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
+            self.keys[node_id] = None
             self.subcache_keys[node_id] = (node_id, node["class_type"])

-    async def get_node_signature(self, dynprompt, node_id):
-        signature = []
-        ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
-        signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
-        for ancestor_id in ancestors:
-            signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
-        return to_hashable(signature)
+    async def get_node_signature(self, node_id):
+        signatures = []
+        ancestors, order_mapping, input_hashes = self.get_ordered_ancestry(node_id)
+        self.node_sig_cache[node_id] = await self.get_immediate_node_signature(node_id, order_mapping, input_hashes)
+        signatures.append(self.node_sig_cache[node_id])
+
+        for ancestor_id in ancestors:
+            assert ancestor_id in self.node_sig_cache
+            signatures.append(self.node_sig_cache[ancestor_id])
+
+        logging.debug(f"signature for {node_id}:\n{signatures}")
+        return to_hashable(signatures)

-    async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
-        if not dynprompt.has_node(node_id):
+    async def get_immediate_node_signature(self, node_id, ancestor_order_mapping: dict, inputs: dict):
+        if not self.dynprompt.has_node(node_id):
             # This node doesn't exist -- we can't cache it.
             return [float("NaN")]
-        node = dynprompt.get_node(node_id)
+        node = self.dynprompt.get_node(node_id)
         class_type = node["class_type"]
         class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-        signature = [class_type, await self.is_changed_cache.get(node_id)]
-        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
-            signature.append(node_id)
-        inputs = node["inputs"]
+        signature = [class_type, await self.is_changed.get(node_id)]
         for key in sorted(inputs.keys()):
-            if is_link(inputs[key]):
-                (ancestor_id, ancestor_socket) = inputs[key]
+            input = inputs[key]
+            if is_link(input):
+                (ancestor_id, ancestor_socket) = input
                 ancestor_index = ancestor_order_mapping[ancestor_id]
                 signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
             else:
-                signature.append((key, inputs[key]))
+                signature.append((key, input))
+
+        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
+            signature.append(node_id)
+
         return signature

-    # This function returns a list of all ancestors of the given node. The order of the list is
-    # deterministic based on which specific inputs the ancestor is connected by.
-    def get_ordered_ancestry(self, dynprompt, node_id):
-        ancestors = []
-        order_mapping = {}
-        self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
-        return ancestors, order_mapping
+    def get_ordered_ancestry(self, node_id):
+        def get_ancestors(ancestors, ret: list=[]):
+            for ancestor_id in ancestors:
+                if ancestor_id not in ret:
+                    ret.append(ancestor_id)
+                    get_ancestors(self.ancestry_cache[ancestor_id], ret)
+            return ret

-    def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
-        if not dynprompt.has_node(node_id):
-            return
-        inputs = dynprompt.get_node(node_id)["inputs"]
-        input_keys = sorted(inputs.keys())
-        for key in input_keys:
-            if is_link(inputs[key]):
-                ancestor_id = inputs[key][0]
-                if ancestor_id not in order_mapping:
-                    ancestors.append(ancestor_id)
-                    order_mapping[ancestor_id] = len(ancestors) - 1
-                    self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
+        ancestors, input_hashes = self.get_ordered_ancestry_internal(node_id)
+        ancestors = get_ancestors(ancestors)
+
+        order_mapping = {}
+        for i, ancestor_id in enumerate(ancestors):
+            order_mapping[ancestor_id] = i
+
+        return ancestors, order_mapping, input_hashes
+
+    def get_ordered_ancestry_internal(self, node_id):
+        ancestors = []
+        input_hashes = {}
+
+        if node_id in self.ancestry_cache:
+            return self.ancestry_cache[node_id], input_hashes
+
+        if not self.dynprompt.has_node(node_id):
+            return
+
+        input_data_all, _, _ = self.is_changed.get_input_data(node_id)
+        inputs = self.dynprompt.get_node(node_id)["inputs"]
+        for key in sorted(inputs.keys()):
+            input = inputs[key]
+            if key in input_data_all:
+                if is_link(input):
+                    ancestor_id = input[0]
+                    try:
+                        # Replace link with input's hash
+                        hashable = throw_on_unhashable(input_data_all[key])
+                        input_hashes[key] = hash(hashable)
+                    except:
+                        # Link still needed
+                        input_hashes[key] = input
+                        if ancestor_id not in ancestors:
+                            ancestors.append(ancestor_id)
+                else:
+                    try:
+                        hashable = throw_on_unhashable(input)
+                        input_hashes[key] = hash(hashable)
+                    except:
+                        logging.warning(f"Node {node_id} cannot be cached due to whatever this thing is: {input}")
+                        input_hashes[key] = Unhashable()

+        self.ancestry_cache[node_id] = ancestors
+        return self.ancestry_cache[node_id], input_hashes

 class BasicCache:
     def __init__(self, key_class):
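Taken together, the class above splits key construction into two phases: add_keys registers None placeholders, is_key_updatable checks that no input is still missing, and update_cache_key computes and memoizes the signature (per-node signatures in node_sig_cache, ancestries in ancestry_cache). A toy version of that life cycle, with hypothetical stand-ins rather than the real API:

import asyncio

# Toy life cycle of an updatable key set (hypothetical stand-ins, not the
# real classes): placeholder key -> updatable check -> key update -> lookup.
class ToyUpdatableKeys:
    def __init__(self, resolved_inputs):
        self.resolved_inputs = resolved_inputs  # node_id -> inputs dict or None
        self.keys = {}
        self.updated = set()

    async def add_keys(self, node_ids):
        for node_id in node_ids:
            self.keys[node_id] = None  # defer until inputs resolve

    def is_key_updatable(self, node_id):
        return self.resolved_inputs.get(node_id) is not None

    async def update_cache_key(self, node_id):
        if node_id in self.updated or not self.is_key_updatable(node_id):
            return
        inputs = self.resolved_inputs[node_id]
        self.keys[node_id] = frozenset(sorted(inputs.items()))
        self.updated.add(node_id)

async def main():
    ks = ToyUpdatableKeys({"1": {"seed": 42}, "2": None})
    await ks.add_keys(["1", "2"])
    await ks.update_cache_key("1")
    print(ks.keys["1"])   # frozenset({('seed', 42)})
    print(ks.keys["2"])   # None -- inputs not resolvable yet

asyncio.run(main())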
@@ -154,12 +246,14 @@ class BasicCache:
         self.cache_key_set: CacheKeySet
         self.cache = {}
         self.subcaches = {}
+        self.clean_when = "before"

-    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+    async def set_prompt(self, dynprompt, node_ids, is_changed):
         self.dynprompt = dynprompt
-        self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
+        self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed)
         await self.cache_key_set.add_keys(node_ids)
-        self.is_changed_cache = is_changed_cache
+        self.clean_when = self.cache_key_set.clean_when or "before"
+        self.is_changed = is_changed
         self.initialized = True

     def all_node_ids(self):
@@ -196,16 +290,29 @@ class BasicCache:
     def poll(self, **kwargs):
         pass

+    async def _update_cache_key_immediate(self, node_id):
+        """Update the cache key for the node."""
+        await self.cache_key_set.update_cache_key(node_id)
+
+    def _is_key_updated_immediate(self, node_id):
+        """False if the cache key set is an updatable type and it hasn't been updated yet."""
+        return self.cache_key_set.is_key_updated(node_id)
+
+    def _is_key_updatable_immediate(self, node_id):
+        """True if the cache key set is an updatable type and it can be updated properly."""
+        return self.cache_key_set.is_key_updatable(node_id)
+
     def _set_immediate(self, node_id, value):
         assert self.initialized
         cache_key = self.cache_key_set.get_data_key(node_id)
-        self.cache[cache_key] = value
+        if cache_key is not None:
+            self.cache[cache_key] = value

     def _get_immediate(self, node_id):
         if not self.initialized:
             return None
         cache_key = self.cache_key_set.get_data_key(node_id)
-        if cache_key in self.cache:
+        if cache_key is not None and cache_key in self.cache:
             return self.cache[cache_key]
         else:
             return None
@@ -216,7 +323,7 @@ class BasicCache:
         if subcache is None:
             subcache = BasicCache(self.key_class)
             self.subcaches[subcache_key] = subcache
-        await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
+        await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed)
         return subcache

     def _get_subcache(self, node_id):
@@ -273,9 +380,23 @@ class HierarchicalCache(BasicCache):
         assert cache is not None
         return await cache._ensure_subcache(node_id, children_ids)

-class NullCache:
+    async def update_cache_key(self, node_id):
+        cache = self._get_cache_for(node_id)
+        assert cache is not None
+        await cache._update_cache_key_immediate(node_id)
+
+    def is_key_updated(self, node_id):
+        cache = self._get_cache_for(node_id)
+        assert cache is not None
+        return cache._is_key_updated_immediate(node_id)
+
+    def is_key_updatable(self, node_id):
+        cache = self._get_cache_for(node_id)
+        assert cache is not None
+        return cache._is_key_updatable_immediate(node_id)

-    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+class NullCache:
+    async def set_prompt(self, dynprompt, node_ids, is_changed):
         pass

     def all_node_ids(self):
@@ -296,6 +417,15 @@ class NullCache:
     async def ensure_subcache_for(self, node_id, children_ids):
         return self

+    async def update_cache_key(self, node_id):
+        pass
+
+    def is_key_updated(self, node_id):
+        return True
+
+    def is_key_updatable(self, node_id):
+        return False
+
 class LRUCache(BasicCache):
     def __init__(self, key_class, max_size=100):
         super().__init__(key_class)
@@ -305,8 +435,8 @@ class LRUCache(BasicCache):
         self.used_generation = {}
         self.children = {}

-    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
-        await super().set_prompt(dynprompt, node_ids, is_changed_cache)
+    async def set_prompt(self, dynprompt, node_ids, is_changed):
+        await super().set_prompt(dynprompt, node_ids, is_changed)
         self.generation += 1
         for node_id in node_ids:
             self._mark_used(node_id)
@@ -348,6 +478,18 @@ class LRUCache(BasicCache):
                 self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
         return self

+    async def update_cache_key(self, node_id):
+        self._mark_used(node_id)
+        await self._update_cache_key_immediate(node_id)
+
+    def is_key_updated(self, node_id):
+        self._mark_used(node_id)
+        return self._is_key_updated_immediate(node_id)
+
+    def is_key_updatable(self, node_id):
+        self._mark_used(node_id)
+        return self._is_key_updatable_immediate(node_id)
+

 #Iterating the cache for usage analysis might be expensive, so if we trigger make sure
 #to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
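Note that each of the LRU entry points above calls _mark_used before touching the key, so even probing whether a key is updated counts as a use for eviction purposes. A simplified model of that bookkeeping (toy class, not the real LRUCache):

# Why every key operation refreshes recency: a node probed or updated in
# the current generation survives the next sweep. Hypothetical, simplified.
class ToyLRU:
    def __init__(self):
        self.generation = 0
        self.used_generation = {}

    def _mark_used(self, node_id):
        self.used_generation[node_id] = self.generation

    def new_prompt(self, node_ids):
        self.generation += 1
        for node_id in node_ids:
            self._mark_used(node_id)

    def evict_stale(self):
        stale = [n for n, g in self.used_generation.items()
                 if g < self.generation]
        for n in stale:
            del self.used_generation[n]
        return stale

lru = ToyLRU()
lru.new_prompt(["a", "b"])
lru.new_prompt(["a"])        # "b" not touched this generation
print(lru.evict_stale())     # ['b']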
@@ -365,7 +507,6 @@ RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
 RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

 class RAMPressureCache(LRUCache):
-
     def __init__(self, key_class):
         super().__init__(key_class, 0)
         self.timestamps = {}
execution.py (134 lines changed)
@@ -18,7 +18,7 @@ import nodes
 from comfy_execution.caching import (
     BasicCache,
     CacheKeySetID,
-    CacheKeySetInputSignature,
+    CacheKeySetUpdatableInputSignature,
     NullCache,
     HierarchicalCache,
     LRUCache,
@@ -36,7 +36,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
 from comfy_execution.utils import CurrentNodeContext
 from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
 from comfy_api.latest import io, _io
-
+from server import PromptServer

 class ExecutionResult(Enum):
     SUCCESS = 0
@@ -46,49 +46,40 @@ class ExecutionResult(Enum):
 class DuplicateNodeError(Exception):
     pass

-class IsChangedCache:
-    def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache):
+class IsChanged:
+    def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, execution_list: ExecutionList|None=None, extra_data: dict={}):
         self.prompt_id = prompt_id
         self.dynprompt = dynprompt
-        self.outputs_cache = outputs_cache
-        self.is_changed = {}
+        self.execution_list = execution_list
+        self.extra_data = extra_data

-    async def get(self, node_id):
-        if node_id in self.is_changed:
-            return self.is_changed[node_id]
+    def get_input_data(self, node_id):
+        node = self.dynprompt.get_node(node_id)
+        class_type = node["class_type"]
+        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+        return get_input_data(node["inputs"], class_def, node_id, self.execution_list, self.dynprompt, self.extra_data)

+    async def get(self, node_id):
         node = self.dynprompt.get_node(node_id)
         class_type = node["class_type"]
         class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-        has_is_changed = False
         is_changed_name = None
         if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None:
-            has_is_changed = True
             is_changed_name = "fingerprint_inputs"
         elif hasattr(class_def, "IS_CHANGED"):
-            has_is_changed = True
             is_changed_name = "IS_CHANGED"
-        if not has_is_changed:
-            self.is_changed[node_id] = False
-            return self.is_changed[node_id]
+        if is_changed_name is None:
+            return False

-        if "is_changed" in node:
-            self.is_changed[node_id] = node["is_changed"]
-            return self.is_changed[node_id]
-
-        # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
-        input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
+        input_data_all, _, v3_data = self.get_input_data(node_id)
         try:
             is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
             is_changed = await resolve_map_node_over_list_results(is_changed)
-            node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
+            is_changed = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
         except Exception as e:
             logging.warning("WARNING: {}".format(e))
-            node["is_changed"] = float("NaN")
-        finally:
-            self.is_changed[node_id] = node["is_changed"]
-            return self.is_changed[node_id]
+            is_changed = float("NaN")
+        return is_changed

 class CacheEntry(NamedTuple):
     ui: dict
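The rewrite above drops both the per-prompt memoization dict and the node["is_changed"] side channel: get now recomputes on each call, and get_input_data threads the execution list through, so IS_CHANGED can see resolved upstream inputs rather than constants only. The behavioral difference in miniature (toy classes, not the real ones):

# Toy contrast (not the real classes): the old IsChangedCache memoized per
# node; the new IsChanged recomputes each call, so it can observe inputs
# that only became available mid-execution.
class Memoized:
    def __init__(self, fn):
        self.fn, self.cache = fn, {}
    def get(self, node_id):
        if node_id not in self.cache:
            self.cache[node_id] = self.fn(node_id)
        return self.cache[node_id]

class Recomputed:
    def __init__(self, fn):
        self.fn = fn
    def get(self, node_id):
        return self.fn(node_id)

state = {"n": 0}
probe = lambda _node_id: state["n"]
old, new = Memoized(probe), Recomputed(probe)
print(old.get("x"), new.get("x"))  # 0 0
state["n"] = 1                     # upstream value resolves later
print(old.get("x"), new.get("x"))  # 0 1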
@@ -118,19 +109,19 @@ class CacheSet:
         else:
             self.init_classic_cache()

-        self.all = [self.outputs, self.objects]
+        self.all: list[BasicCache, BasicCache] = [self.outputs, self.objects]

     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):
-        self.outputs = HierarchicalCache(CacheKeySetInputSignature)
+        self.outputs = HierarchicalCache(CacheKeySetUpdatableInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)

     def init_lru_cache(self, cache_size):
-        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.outputs = LRUCache(CacheKeySetUpdatableInputSignature, max_size=cache_size)
         self.objects = HierarchicalCache(CacheKeySetID)

     def init_ram_cache(self, min_headroom):
-        self.outputs = RAMPressureCache(CacheKeySetInputSignature)
+        self.outputs = RAMPressureCache(CacheKeySetUpdatableInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)

     def init_null_cache(self):
@@ -406,7 +397,10 @@ def format_value(x):
     else:
         return str(x)

-async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
+async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheSet,
+                  current_item: str, extra_data: dict, executed: set, prompt_id: str,
+                  execution_list: ExecutionList, pending_subgraph_results: dict,
+                  pending_async_nodes: dict, ui_outputs: dict):
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
     display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -414,16 +408,20 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-    cached = caches.outputs.get(unique_id)
-    if cached is not None:
-        if server.client_id is not None:
-            cached_ui = cached.ui or {}
-            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
-        if cached.ui is not None:
-            ui_outputs[unique_id] = cached.ui
-        get_progress_state().finish_progress(unique_id)
-        execution_list.cache_update(unique_id, cached)
-        return (ExecutionResult.SUCCESS, None, None)
+    if caches.outputs.is_key_updated(unique_id):
+        # Key is updated, the cache can be checked.
+        cached = caches.outputs.get(unique_id)
+        logging.debug(f"execute: {unique_id} cached: {cached is not None}")
+        if cached is not None:
+            if server.client_id is not None:
+                cached_ui = cached.ui or {}
+                server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
+            if cached.ui is not None:
+                ui_outputs[unique_id] = cached.ui
+            get_progress_state().finish_progress(unique_id)
+            execution_list.cache_update(unique_id, cached)
+            return (ExecutionResult.SUCCESS, None, None)

     input_data_all = None
     try:
@@ -464,11 +462,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             del pending_subgraph_results[unique_id]
             has_subgraph = False
         else:
-            get_progress_state().start_progress(unique_id)
+            if caches.outputs.is_key_updated(unique_id):
+                # The key is updated, the node is executing.
+                get_progress_state().start_progress(unique_id)
+                if server.client_id is not None:
+                    server.last_node_id = display_node_id
+                    server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
+
             input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
-            if server.client_id is not None:
-                server.last_node_id = display_node_id
-                server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)

             obj = caches.objects.get(unique_id)
             if obj is None:
@@ -479,6 +480,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
             else:
                 lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
+
             if lazy_status_present:
                 # for check_lazy_status, the returned data should include the original key of the input
                 v3_data_lazy = v3_data.copy()
@@ -494,6 +496,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                         execution_list.make_input_strong_link(unique_id, i)
                     return (ExecutionResult.PENDING, None, None)

+            if not caches.outputs.is_key_updated(unique_id):
+                # Update the cache key after any lazy inputs are evaluated.
+                async def update_cache_key(node_id, unblock):
+                    await caches.outputs.update_cache_key(node_id)
+                    unblock()
+                asyncio.create_task(update_cache_key(unique_id, execution_list.add_external_block(unique_id)))
+                return (ExecutionResult.PENDING, None, None)
+
             def execution_block_cb(block):
                 if block.message is not None:
                     mes = {
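When the key is not yet updated, execute returns PENDING and schedules a task that computes the key and then calls the unblock callback obtained from execution_list.add_external_block, the same block/unblock shape used later for async nodes. A self-contained model of that handshake, using only standard asyncio and toy names:

import asyncio

# Toy model of the block/unblock handshake above: a node stays blocked
# until a background task finishes computing its cache key.
def add_external_block(event: asyncio.Event):
    event.clear()           # node is now blocked
    return event.set        # "unblock" callback

async def update_cache_key(node_id, unblock):
    await asyncio.sleep(0)  # stand-in for real signature computation
    print(f"key updated for {node_id}")
    unblock()

async def main():
    ready = asyncio.Event()
    unblock = add_external_block(ready)
    asyncio.create_task(update_cache_key("13", unblock))
    await ready.wait()      # scheduler resumes the node here
    print("node 13 can execute")

asyncio.run(main())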
@@ -525,6 +535,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 unblock()
             asyncio.create_task(await_completion())
             return (ExecutionResult.PENDING, None, None)
+
         if len(output_ui) > 0:
             ui_outputs[unique_id] = {
                 "meta": {
@@ -537,6 +548,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             }
             if server.client_id is not None:
                 server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
+
         if has_subgraph:
             cached_outputs = []
             new_node_ids = []
@@ -564,7 +576,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             new_node_ids = set(new_node_ids)
             for cache in caches.all:
                 subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
-                subcache.clean_unused()
+                if subcache.clean_when == "before":
+                    subcache.clean_unused()
             for node_id in new_output_ids:
                 execution_list.add_node(node_id)
                 execution_list.cache_link(node_id, unique_id)
@@ -689,25 +702,25 @@ class PromptExecutor:
         dynamic_prompt = DynamicPrompt(prompt)
         reset_progress_state(prompt_id, dynamic_prompt)
         add_progress_handler(WebUIProgressHandler(self.server))
-        is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
+        execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
+        is_changed = IsChanged(prompt_id, dynamic_prompt, execution_list, extra_data)
         for cache in self.caches.all:
-            await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
-            cache.clean_unused()
+            await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed)
+            if cache.clean_when == "before":
+                cache.clean_unused()

-        cached_nodes = []
-        for node_id in prompt:
-            if self.caches.outputs.get(node_id) is not None:
-                cached_nodes.append(node_id)
+        if self.caches.outputs.clean_when == "before":
+            cached_nodes = []
+            for node_id in prompt:
+                if self.caches.outputs.get(node_id) is not None:
+                    cached_nodes.append(node_id)
+            self.add_message("execution_cached", {"nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False)

         comfy.model_management.cleanup_models_gc()
-        self.add_message("execution_cached",
-                        { "nodes": cached_nodes, "prompt_id": prompt_id},
-                        broadcast=False)
         pending_subgraph_results = {}
         pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
         ui_node_outputs = {}
         executed = set()
-        execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
         current_outputs = self.caches.outputs.all_node_ids()
         for node_id in list(execute_outputs):
             execution_list.add_node(node_id)
@@ -746,6 +759,9 @@ class PromptExecutor:
         if comfy.model_management.DISABLE_SMART_MEMORY:
             comfy.model_management.unload_all_models()

+        for cache in self.caches.all:
+            if cache.clean_when == "after":
+                cache.clean_unused()

 async def validate_inputs(prompt_id, prompt, item, validated):
     unique_id = item