Added a new type of cache key set.

This commit is contained in:
Deluxe233 2026-01-26 01:28:43 -05:00
parent 7ee77ff038
commit 232995856e
2 changed files with 272 additions and 115 deletions

View File

@ -3,6 +3,7 @@ import gc
import itertools import itertools
import psutil import psutil
import time import time
import logging
import torch import torch
from typing import Sequence, Mapping, Dict from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt from comfy_execution.graph import DynamicPrompt
@ -14,7 +15,6 @@ from comfy_execution.graph_utils import is_link
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {} NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
def include_unique_id_in_input(class_type: str) -> bool: def include_unique_id_in_input(class_type: str) -> bool:
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID: if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
@ -23,9 +23,10 @@ def include_unique_id_in_input(class_type: str) -> bool:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class CacheKeySet(ABC): class CacheKeySet(ABC):
def __init__(self, dynprompt, node_ids, is_changed_cache): def __init__(self, dynprompt, node_ids, is_changed):
self.keys = {} self.keys = {}
self.subcache_keys = {} self.subcache_keys = {}
self.clean_when = None
@abstractmethod @abstractmethod
async def add_keys(self, node_ids): async def add_keys(self, node_ids):
@ -46,6 +47,15 @@ class CacheKeySet(ABC):
def get_subcache_key(self, node_id): def get_subcache_key(self, node_id):
return self.subcache_keys.get(node_id, None) return self.subcache_keys.get(node_id, None)
async def update_cache_key(self, node_id) -> None:
pass
def is_key_updated(self, node_id) -> bool:
return True
def is_key_updatable(self, node_id) -> bool:
return False
class Unhashable: class Unhashable:
def __init__(self): def __init__(self):
self.value = float("NaN") self.value = float("NaN")
@ -63,10 +73,22 @@ def to_hashable(obj):
# TODO - Support other objects like tensors? # TODO - Support other objects like tensors?
return Unhashable() return Unhashable()
def throw_on_unhashable(obj):
# Same as to_hashable except throwing for unhashables instead.
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
return obj
elif isinstance(obj, Mapping):
return frozenset([(throw_on_unhashable(k), throw_on_unhashable(v)) for k, v in sorted(obj.items())])
elif isinstance(obj, Sequence):
return frozenset(zip(itertools.count(), [throw_on_unhashable(i) for i in obj]))
else:
raise Exception("Object unhashable.")
class CacheKeySetID(CacheKeySet): class CacheKeySetID(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache): def __init__(self, dynprompt, node_ids, is_changed):
super().__init__(dynprompt, node_ids, is_changed_cache) super().__init__(dynprompt, node_ids, is_changed)
self.dynprompt = dynprompt self.dynprompt = dynprompt
self.clean_when = "before"
async def add_keys(self, node_ids): async def add_keys(self, node_ids):
for node_id in node_ids: for node_id in node_ids:
@ -78,73 +100,143 @@ class CacheKeySetID(CacheKeySet):
self.keys[node_id] = (node_id, node["class_type"]) self.keys[node_id] = (node_id, node["class_type"])
self.subcache_keys[node_id] = (node_id, node["class_type"]) self.subcache_keys[node_id] = (node_id, node["class_type"])
class CacheKeySetInputSignature(CacheKeySet): class CacheKeySetUpdatableInputSignature(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache): def __init__(self, dynprompt, node_ids, is_changed):
super().__init__(dynprompt, node_ids, is_changed_cache) super().__init__(dynprompt, node_ids, is_changed)
self.dynprompt = dynprompt self.dynprompt: DynamicPrompt = dynprompt
self.is_changed_cache = is_changed_cache self.is_changed = is_changed
self.clean_when = "after"
self.updated_node_ids = set()
self.node_sig_cache = {}
"""Nodes' immediate node signatures."""
self.ancestry_cache = {}
"""List of a node's ancestors."""
def include_node_id_in_input(self) -> bool: def include_node_id_in_input(self) -> bool:
return False return False
async def update_cache_key(self, node_id):
"""Update key using cached outputs as part of the input signature."""
if node_id in self.updated_node_ids:
return
if node_id not in self.keys:
return
self.updated_node_ids.add(node_id)
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = await self.get_node_signature(node_id)
self.subcache_keys[node_id] = (node_id, node["class_type"])
def is_key_updated(self, node_id):
return node_id in self.updated_node_ids
def is_key_updatable(self, node_id):
_, missing_keys, _ = self.is_changed.get_input_data(node_id)
if missing_keys:
return False
return True
async def add_keys(self, node_ids): async def add_keys(self, node_ids):
"""Initialize keys."""
for node_id in node_ids: for node_id in node_ids:
if node_id in self.keys: if node_id in self.keys:
continue continue
if not self.dynprompt.has_node(node_id): if not self.dynprompt.has_node(node_id):
continue continue
node = self.dynprompt.get_node(node_id) node = self.dynprompt.get_node(node_id)
self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id) self.keys[node_id] = None
self.subcache_keys[node_id] = (node_id, node["class_type"]) self.subcache_keys[node_id] = (node_id, node["class_type"])
async def get_node_signature(self, dynprompt, node_id): async def get_node_signature(self, node_id):
signature = [] signatures = []
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) ancestors, order_mapping, input_hashes = self.get_ordered_ancestry(node_id)
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) self.node_sig_cache[node_id] = await self.get_immediate_node_signature(node_id, order_mapping, input_hashes)
for ancestor_id in ancestors: signatures.append(self.node_sig_cache[node_id])
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
return to_hashable(signature)
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): for ancestor_id in ancestors:
if not dynprompt.has_node(node_id): assert ancestor_id in self.node_sig_cache
signatures.append(self.node_sig_cache[ancestor_id])
logging.debug(f"signature for {node_id}:\n{signatures}")
return to_hashable(signatures)
async def get_immediate_node_signature(self, node_id, ancestor_order_mapping: dict, inputs: dict):
if not self.dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it. # This node doesn't exist -- we can't cache it.
return [float("NaN")] return [float("NaN")]
node = dynprompt.get_node(node_id) node = self.dynprompt.get_node(node_id)
class_type = node["class_type"] class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, await self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): signature = [class_type, await self.is_changed.get(node_id)]
signature.append(node_id)
inputs = node["inputs"]
for key in sorted(inputs.keys()): for key in sorted(inputs.keys()):
if is_link(inputs[key]): input = inputs[key]
(ancestor_id, ancestor_socket) = inputs[key] if is_link(input):
(ancestor_id, ancestor_socket) = input
ancestor_index = ancestor_order_mapping[ancestor_id] ancestor_index = ancestor_order_mapping[ancestor_id]
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
else: else:
signature.append((key, inputs[key])) signature.append((key, input))
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
return signature return signature
# This function returns a list of all ancestors of the given node. The order of the list is def get_ordered_ancestry(self, node_id):
# deterministic based on which specific inputs the ancestor is connected by. def get_ancestors(ancestors, ret: list=[]):
def get_ordered_ancestry(self, dynprompt, node_id): for ancestor_id in ancestors:
ancestors = [] if ancestor_id not in ret:
order_mapping = {} ret.append(ancestor_id)
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping) get_ancestors(self.ancestry_cache[ancestor_id], ret)
return ancestors, order_mapping return ret
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping): ancestors, input_hashes = self.get_ordered_ancestry_internal(node_id)
if not dynprompt.has_node(node_id): ancestors = get_ancestors(ancestors)
order_mapping = {}
for i, ancestor_id in enumerate(ancestors):
order_mapping[ancestor_id] = i
return ancestors, order_mapping, input_hashes
def get_ordered_ancestry_internal(self, node_id):
ancestors = []
input_hashes = {}
if node_id in self.ancestry_cache:
return self.ancestry_cache[node_id], input_hashes
if not self.dynprompt.has_node(node_id):
return return
inputs = dynprompt.get_node(node_id)["inputs"]
input_keys = sorted(inputs.keys()) input_data_all, _, _ = self.is_changed.get_input_data(node_id)
for key in input_keys: inputs = self.dynprompt.get_node(node_id)["inputs"]
if is_link(inputs[key]): for key in sorted(inputs.keys()):
ancestor_id = inputs[key][0] input = inputs[key]
if ancestor_id not in order_mapping: if key in input_data_all:
ancestors.append(ancestor_id) if is_link(input):
order_mapping[ancestor_id] = len(ancestors) - 1 ancestor_id = input[0]
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) try:
# Replace link with input's hash
hashable = throw_on_unhashable(input_data_all[key])
input_hashes[key] = hash(hashable)
except:
# Link still needed
input_hashes[key] = input
if ancestor_id not in ancestors:
ancestors.append(ancestor_id)
else:
try:
hashable = throw_on_unhashable(input)
input_hashes[key] = hash(hashable)
except:
logging.warning(f"Node {node_id} cannot be cached due to whatever this thing is: {input}")
input_hashes[key] = Unhashable()
self.ancestry_cache[node_id] = ancestors
return self.ancestry_cache[node_id], input_hashes
class BasicCache: class BasicCache:
def __init__(self, key_class): def __init__(self, key_class):
@ -154,12 +246,14 @@ class BasicCache:
self.cache_key_set: CacheKeySet self.cache_key_set: CacheKeySet
self.cache = {} self.cache = {}
self.subcaches = {} self.subcaches = {}
self.clean_when = "before"
async def set_prompt(self, dynprompt, node_ids, is_changed_cache): async def set_prompt(self, dynprompt, node_ids, is_changed):
self.dynprompt = dynprompt self.dynprompt = dynprompt
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed)
await self.cache_key_set.add_keys(node_ids) await self.cache_key_set.add_keys(node_ids)
self.is_changed_cache = is_changed_cache self.clean_when = self.cache_key_set.clean_when or "before"
self.is_changed = is_changed
self.initialized = True self.initialized = True
def all_node_ids(self): def all_node_ids(self):
@ -196,16 +290,29 @@ class BasicCache:
def poll(self, **kwargs): def poll(self, **kwargs):
pass pass
async def _update_cache_key_immediate(self, node_id):
"""Update the cache key for the node."""
await self.cache_key_set.update_cache_key(node_id)
def _is_key_updated_immediate(self, node_id):
"""False if the cache key set is an updatable type and it hasn't been updated yet."""
return self.cache_key_set.is_key_updated(node_id)
def _is_key_updatable_immediate(self, node_id):
"""True if the cache key set is an updatable type and it can be updated properly."""
return self.cache_key_set.is_key_updatable(node_id)
def _set_immediate(self, node_id, value): def _set_immediate(self, node_id, value):
assert self.initialized assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value if cache_key is not None:
self.cache[cache_key] = value
def _get_immediate(self, node_id): def _get_immediate(self, node_id):
if not self.initialized: if not self.initialized:
return None return None
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache: if cache_key is not None and cache_key in self.cache:
return self.cache[cache_key] return self.cache[cache_key]
else: else:
return None return None
@ -216,7 +323,7 @@ class BasicCache:
if subcache is None: if subcache is None:
subcache = BasicCache(self.key_class) subcache = BasicCache(self.key_class)
self.subcaches[subcache_key] = subcache self.subcaches[subcache_key] = subcache
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed)
return subcache return subcache
def _get_subcache(self, node_id): def _get_subcache(self, node_id):
@ -273,9 +380,23 @@ class HierarchicalCache(BasicCache):
assert cache is not None assert cache is not None
return await cache._ensure_subcache(node_id, children_ids) return await cache._ensure_subcache(node_id, children_ids)
class NullCache: async def update_cache_key(self, node_id):
cache = self._get_cache_for(node_id)
assert cache is not None
await cache._update_cache_key_immediate(node_id)
async def set_prompt(self, dynprompt, node_ids, is_changed_cache): def is_key_updated(self, node_id):
cache = self._get_cache_for(node_id)
assert cache is not None
return cache._is_key_updated_immediate(node_id)
def is_key_updatable(self, node_id):
cache = self._get_cache_for(node_id)
assert cache is not None
return cache._is_key_updatable_immediate(node_id)
class NullCache:
async def set_prompt(self, dynprompt, node_ids, is_changed):
pass pass
def all_node_ids(self): def all_node_ids(self):
@ -296,6 +417,15 @@ class NullCache:
async def ensure_subcache_for(self, node_id, children_ids): async def ensure_subcache_for(self, node_id, children_ids):
return self return self
async def update_cache_key(self, node_id):
pass
def is_key_updated(self, node_id):
return True
def is_key_updatable(self, node_id):
return False
class LRUCache(BasicCache): class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100): def __init__(self, key_class, max_size=100):
super().__init__(key_class) super().__init__(key_class)
@ -305,8 +435,8 @@ class LRUCache(BasicCache):
self.used_generation = {} self.used_generation = {}
self.children = {} self.children = {}
async def set_prompt(self, dynprompt, node_ids, is_changed_cache): async def set_prompt(self, dynprompt, node_ids, is_changed):
await super().set_prompt(dynprompt, node_ids, is_changed_cache) await super().set_prompt(dynprompt, node_ids, is_changed)
self.generation += 1 self.generation += 1
for node_id in node_ids: for node_id in node_ids:
self._mark_used(node_id) self._mark_used(node_id)
@ -348,6 +478,18 @@ class LRUCache(BasicCache):
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self return self
async def update_cache_key(self, node_id):
self._mark_used(node_id)
await self._update_cache_key_immediate(node_id)
def is_key_updated(self, node_id):
self._mark_used(node_id)
return self._is_key_updated_immediate(node_id)
def is_key_updatable(self, node_id):
self._mark_used(node_id)
return self._is_key_updatable_immediate(node_id)
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure #Iterating the cache for usage analysis might be expensive, so if we trigger make sure
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows. #to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
@ -365,7 +507,6 @@ RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3 RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache): class RAMPressureCache(LRUCache):
def __init__(self, key_class): def __init__(self, key_class):
super().__init__(key_class, 0) super().__init__(key_class, 0)
self.timestamps = {} self.timestamps = {}

View File

@ -18,7 +18,7 @@ import nodes
from comfy_execution.caching import ( from comfy_execution.caching import (
BasicCache, BasicCache,
CacheKeySetID, CacheKeySetID,
CacheKeySetInputSignature, CacheKeySetUpdatableInputSignature,
NullCache, NullCache,
HierarchicalCache, HierarchicalCache,
LRUCache, LRUCache,
@ -36,7 +36,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
from comfy_execution.utils import CurrentNodeContext from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io from comfy_api.latest import io, _io
from server import PromptServer
class ExecutionResult(Enum): class ExecutionResult(Enum):
SUCCESS = 0 SUCCESS = 0
@ -46,49 +46,40 @@ class ExecutionResult(Enum):
class DuplicateNodeError(Exception): class DuplicateNodeError(Exception):
pass pass
class IsChangedCache: class IsChanged:
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache): def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, execution_list: ExecutionList|None=None, extra_data: dict={}):
self.prompt_id = prompt_id self.prompt_id = prompt_id
self.dynprompt = dynprompt self.dynprompt = dynprompt
self.outputs_cache = outputs_cache self.execution_list = execution_list
self.is_changed = {} self.extra_data = extra_data
async def get(self, node_id):
if node_id in self.is_changed:
return self.is_changed[node_id]
def get_input_data(self, node_id):
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return get_input_data(node["inputs"], class_def, node_id, self.execution_list, self.dynprompt, self.extra_data)
async def get(self, node_id):
node = self.dynprompt.get_node(node_id) node = self.dynprompt.get_node(node_id)
class_type = node["class_type"] class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
has_is_changed = False
is_changed_name = None is_changed_name = None
if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None: if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None:
has_is_changed = True
is_changed_name = "fingerprint_inputs" is_changed_name = "fingerprint_inputs"
elif hasattr(class_def, "IS_CHANGED"): elif hasattr(class_def, "IS_CHANGED"):
has_is_changed = True
is_changed_name = "IS_CHANGED" is_changed_name = "IS_CHANGED"
if not has_is_changed: if is_changed_name is None:
self.is_changed[node_id] = False return False
return self.is_changed[node_id]
if "is_changed" in node: input_data_all, _, v3_data = self.get_input_data(node_id)
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
try: try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data) is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
is_changed = await resolve_map_node_over_list_results(is_changed) is_changed = await resolve_map_node_over_list_results(is_changed)
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] is_changed = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e: except Exception as e:
logging.warning("WARNING: {}".format(e)) logging.warning("WARNING: {}".format(e))
node["is_changed"] = float("NaN") is_changed = float("NaN")
finally: return is_changed
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
class CacheEntry(NamedTuple): class CacheEntry(NamedTuple):
ui: dict ui: dict
@ -118,19 +109,19 @@ class CacheSet:
else: else:
self.init_classic_cache() self.init_classic_cache()
self.all = [self.outputs, self.objects] self.all: list[BasicCache, BasicCache] = [self.outputs, self.objects]
# Performs like the old cache -- dump data ASAP # Performs like the old cache -- dump data ASAP
def init_classic_cache(self): def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature) self.outputs = HierarchicalCache(CacheKeySetUpdatableInputSignature)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size): def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.outputs = LRUCache(CacheKeySetUpdatableInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom): def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature) self.outputs = RAMPressureCache(CacheKeySetUpdatableInputSignature)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self): def init_null_cache(self):
@ -406,7 +397,10 @@ def format_value(x):
else: else:
return str(x) return str(x)
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheSet,
current_item: str, extra_data: dict, executed: set, prompt_id: str,
execution_list: ExecutionList, pending_subgraph_results: dict,
pending_async_nodes: dict, ui_outputs: dict):
unique_id = current_item unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id) real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id)
@ -414,16 +408,20 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs'] inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type'] class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
cached = caches.outputs.get(unique_id)
if cached is not None: if caches.outputs.is_key_updated(unique_id):
if server.client_id is not None: # Key is updated, the cache can be checked.
cached_ui = cached.ui or {} cached = caches.outputs.get(unique_id)
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id) logging.debug(f"execute: {unique_id} cached: {cached is not None}")
if cached.ui is not None: if cached is not None:
ui_outputs[unique_id] = cached.ui if server.client_id is not None:
get_progress_state().finish_progress(unique_id) cached_ui = cached.ui or {}
execution_list.cache_update(unique_id, cached) server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
return (ExecutionResult.SUCCESS, None, None) if cached.ui is not None:
ui_outputs[unique_id] = cached.ui
get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, cached)
return (ExecutionResult.SUCCESS, None, None)
input_data_all = None input_data_all = None
try: try:
@ -464,11 +462,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
del pending_subgraph_results[unique_id] del pending_subgraph_results[unique_id]
has_subgraph = False has_subgraph = False
else: else:
get_progress_state().start_progress(unique_id) if caches.outputs.is_key_updated(unique_id):
# The key is updated, the node is executing.
get_progress_state().start_progress(unique_id)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
obj = caches.objects.get(unique_id) obj = caches.objects.get(unique_id)
if obj is None: if obj is None:
@ -479,6 +480,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
else: else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present: if lazy_status_present:
# for check_lazy_status, the returned data should include the original key of the input # for check_lazy_status, the returned data should include the original key of the input
v3_data_lazy = v3_data.copy() v3_data_lazy = v3_data.copy()
@ -494,6 +496,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
execution_list.make_input_strong_link(unique_id, i) execution_list.make_input_strong_link(unique_id, i)
return (ExecutionResult.PENDING, None, None) return (ExecutionResult.PENDING, None, None)
if not caches.outputs.is_key_updated(unique_id):
# Update the cache key after any lazy inputs are evaluated.
async def update_cache_key(node_id, unblock):
await caches.outputs.update_cache_key(node_id)
unblock()
asyncio.create_task(update_cache_key(unique_id, execution_list.add_external_block(unique_id)))
return (ExecutionResult.PENDING, None, None)
def execution_block_cb(block): def execution_block_cb(block):
if block.message is not None: if block.message is not None:
mes = { mes = {
@ -525,6 +535,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
unblock() unblock()
asyncio.create_task(await_completion()) asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None) return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0: if len(output_ui) > 0:
ui_outputs[unique_id] = { ui_outputs[unique_id] = {
"meta": { "meta": {
@ -537,6 +548,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
} }
if server.client_id is not None: if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph: if has_subgraph:
cached_outputs = [] cached_outputs = []
new_node_ids = [] new_node_ids = []
@ -564,7 +576,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
new_node_ids = set(new_node_ids) new_node_ids = set(new_node_ids)
for cache in caches.all: for cache in caches.all:
subcache = await cache.ensure_subcache_for(unique_id, new_node_ids) subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
subcache.clean_unused() if subcache.clean_when == "before":
subcache.clean_unused()
for node_id in new_output_ids: for node_id in new_output_ids:
execution_list.add_node(node_id) execution_list.add_node(node_id)
execution_list.cache_link(node_id, unique_id) execution_list.cache_link(node_id, unique_id)
@ -689,25 +702,25 @@ class PromptExecutor:
dynamic_prompt = DynamicPrompt(prompt) dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt) reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server)) add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
is_changed = IsChanged(prompt_id, dynamic_prompt, execution_list, extra_data)
for cache in self.caches.all: for cache in self.caches.all:
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed)
cache.clean_unused() if cache.clean_when == "before":
cache.clean_unused()
cached_nodes = [] if self.caches.outputs.clean_when == "before":
for node_id in prompt: cached_nodes = []
if self.caches.outputs.get(node_id) is not None: for node_id in prompt:
cached_nodes.append(node_id) if self.caches.outputs.get(node_id) is not None:
cached_nodes.append(node_id)
self.add_message("execution_cached", {"nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False)
comfy.model_management.cleanup_models_gc() comfy.model_management.cleanup_models_gc()
self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {} pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {} ui_node_outputs = {}
executed = set() executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids() current_outputs = self.caches.outputs.all_node_ids()
for node_id in list(execute_outputs): for node_id in list(execute_outputs):
execution_list.add_node(node_id) execution_list.add_node(node_id)
@ -746,6 +759,9 @@ class PromptExecutor:
if comfy.model_management.DISABLE_SMART_MEMORY: if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models() comfy.model_management.unload_all_models()
for cache in self.caches.all:
if cache.clean_when == "after":
cache.clean_unused()
async def validate_inputs(prompt_id, prompt, item, validated): async def validate_inputs(prompt_id, prompt, item, validated):
unique_id = item unique_id = item