From 232995856e533df32e13c14ac00f3c74476fb9f7 Mon Sep 17 00:00:00 2001 From: Deluxe233 Date: Mon, 26 Jan 2026 01:28:43 -0500 Subject: [PATCH 01/13] Added a new type of cache key set. --- comfy_execution/caching.py | 251 +++++++++++++++++++++++++++++-------- execution.py | 136 +++++++++++--------- 2 files changed, 272 insertions(+), 115 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 326a279fc..212a244b0 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -3,6 +3,7 @@ import gc import itertools import psutil import time +import logging import torch from typing import Sequence, Mapping, Dict from comfy_execution.graph import DynamicPrompt @@ -14,7 +15,6 @@ from comfy_execution.graph_utils import is_link NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {} - def include_unique_id_in_input(class_type: str) -> bool: if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID: return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] @@ -23,9 +23,10 @@ def include_unique_id_in_input(class_type: str) -> bool: return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] class CacheKeySet(ABC): - def __init__(self, dynprompt, node_ids, is_changed_cache): + def __init__(self, dynprompt, node_ids, is_changed): self.keys = {} self.subcache_keys = {} + self.clean_when = None @abstractmethod async def add_keys(self, node_ids): @@ -45,6 +46,15 @@ class CacheKeySet(ABC): def get_subcache_key(self, node_id): return self.subcache_keys.get(node_id, None) + + async def update_cache_key(self, node_id) -> None: + pass + + def is_key_updated(self, node_id) -> bool: + return True + + def is_key_updatable(self, node_id) -> bool: + return False class Unhashable: def __init__(self): @@ -62,11 +72,23 @@ def to_hashable(obj): else: # TODO - Support other objects like tensors? return Unhashable() + +def throw_on_unhashable(obj): + # Same as to_hashable except throwing for unhashables instead. + if isinstance(obj, (int, float, str, bool, bytes, type(None))): + return obj + elif isinstance(obj, Mapping): + return frozenset([(throw_on_unhashable(k), throw_on_unhashable(v)) for k, v in sorted(obj.items())]) + elif isinstance(obj, Sequence): + return frozenset(zip(itertools.count(), [throw_on_unhashable(i) for i in obj])) + else: + raise Exception("Object unhashable.") class CacheKeySetID(CacheKeySet): - def __init__(self, dynprompt, node_ids, is_changed_cache): - super().__init__(dynprompt, node_ids, is_changed_cache) + def __init__(self, dynprompt, node_ids, is_changed): + super().__init__(dynprompt, node_ids, is_changed) self.dynprompt = dynprompt + self.clean_when = "before" async def add_keys(self, node_ids): for node_id in node_ids: @@ -78,73 +100,143 @@ class CacheKeySetID(CacheKeySet): self.keys[node_id] = (node_id, node["class_type"]) self.subcache_keys[node_id] = (node_id, node["class_type"]) -class CacheKeySetInputSignature(CacheKeySet): - def __init__(self, dynprompt, node_ids, is_changed_cache): - super().__init__(dynprompt, node_ids, is_changed_cache) - self.dynprompt = dynprompt - self.is_changed_cache = is_changed_cache +class CacheKeySetUpdatableInputSignature(CacheKeySet): + def __init__(self, dynprompt, node_ids, is_changed): + super().__init__(dynprompt, node_ids, is_changed) + self.dynprompt: DynamicPrompt = dynprompt + self.is_changed = is_changed + self.clean_when = "after" + + self.updated_node_ids = set() + self.node_sig_cache = {} + """Nodes' immediate node signatures.""" + self.ancestry_cache = {} + """List of a node's ancestors.""" def include_node_id_in_input(self) -> bool: return False + + async def update_cache_key(self, node_id): + """Update key using cached outputs as part of the input signature.""" + if node_id in self.updated_node_ids: + return + if node_id not in self.keys: + return + self.updated_node_ids.add(node_id) + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = await self.get_node_signature(node_id) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + + def is_key_updated(self, node_id): + return node_id in self.updated_node_ids + + def is_key_updatable(self, node_id): + _, missing_keys, _ = self.is_changed.get_input_data(node_id) + if missing_keys: + return False + return True async def add_keys(self, node_ids): + """Initialize keys.""" for node_id in node_ids: if node_id in self.keys: continue if not self.dynprompt.has_node(node_id): continue node = self.dynprompt.get_node(node_id) - self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id) + self.keys[node_id] = None self.subcache_keys[node_id] = (node_id, node["class_type"]) - async def get_node_signature(self, dynprompt, node_id): - signature = [] - ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) - signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) - for ancestor_id in ancestors: - signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) - return to_hashable(signature) + async def get_node_signature(self, node_id): + signatures = [] + ancestors, order_mapping, input_hashes = self.get_ordered_ancestry(node_id) + self.node_sig_cache[node_id] = await self.get_immediate_node_signature(node_id, order_mapping, input_hashes) + signatures.append(self.node_sig_cache[node_id]) - async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): - if not dynprompt.has_node(node_id): + for ancestor_id in ancestors: + assert ancestor_id in self.node_sig_cache + signatures.append(self.node_sig_cache[ancestor_id]) + + logging.debug(f"signature for {node_id}:\n{signatures}") + return to_hashable(signatures) + + async def get_immediate_node_signature(self, node_id, ancestor_order_mapping: dict, inputs: dict): + if not self.dynprompt.has_node(node_id): # This node doesn't exist -- we can't cache it. return [float("NaN")] - node = dynprompt.get_node(node_id) + node = self.dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - signature = [class_type, await self.is_changed_cache.get(node_id)] - if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): - signature.append(node_id) - inputs = node["inputs"] + + signature = [class_type, await self.is_changed.get(node_id)] + for key in sorted(inputs.keys()): - if is_link(inputs[key]): - (ancestor_id, ancestor_socket) = inputs[key] + input = inputs[key] + if is_link(input): + (ancestor_id, ancestor_socket) = input ancestor_index = ancestor_order_mapping[ancestor_id] signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) else: - signature.append((key, inputs[key])) + signature.append((key, input)) + + if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): + signature.append(node_id) + return signature + + def get_ordered_ancestry(self, node_id): + def get_ancestors(ancestors, ret: list=[]): + for ancestor_id in ancestors: + if ancestor_id not in ret: + ret.append(ancestor_id) + get_ancestors(self.ancestry_cache[ancestor_id], ret) + return ret + + ancestors, input_hashes = self.get_ordered_ancestry_internal(node_id) + ancestors = get_ancestors(ancestors) - # This function returns a list of all ancestors of the given node. The order of the list is - # deterministic based on which specific inputs the ancestor is connected by. - def get_ordered_ancestry(self, dynprompt, node_id): - ancestors = [] order_mapping = {} - self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping) - return ancestors, order_mapping + for i, ancestor_id in enumerate(ancestors): + order_mapping[ancestor_id] = i + + return ancestors, order_mapping, input_hashes - def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping): - if not dynprompt.has_node(node_id): + def get_ordered_ancestry_internal(self, node_id): + ancestors = [] + input_hashes = {} + + if node_id in self.ancestry_cache: + return self.ancestry_cache[node_id], input_hashes + + if not self.dynprompt.has_node(node_id): return - inputs = dynprompt.get_node(node_id)["inputs"] - input_keys = sorted(inputs.keys()) - for key in input_keys: - if is_link(inputs[key]): - ancestor_id = inputs[key][0] - if ancestor_id not in order_mapping: - ancestors.append(ancestor_id) - order_mapping[ancestor_id] = len(ancestors) - 1 - self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) + + input_data_all, _, _ = self.is_changed.get_input_data(node_id) + inputs = self.dynprompt.get_node(node_id)["inputs"] + for key in sorted(inputs.keys()): + input = inputs[key] + if key in input_data_all: + if is_link(input): + ancestor_id = input[0] + try: + # Replace link with input's hash + hashable = throw_on_unhashable(input_data_all[key]) + input_hashes[key] = hash(hashable) + except: + # Link still needed + input_hashes[key] = input + if ancestor_id not in ancestors: + ancestors.append(ancestor_id) + else: + try: + hashable = throw_on_unhashable(input) + input_hashes[key] = hash(hashable) + except: + logging.warning(f"Node {node_id} cannot be cached due to whatever this thing is: {input}") + input_hashes[key] = Unhashable() + + self.ancestry_cache[node_id] = ancestors + return self.ancestry_cache[node_id], input_hashes class BasicCache: def __init__(self, key_class): @@ -154,12 +246,14 @@ class BasicCache: self.cache_key_set: CacheKeySet self.cache = {} self.subcaches = {} + self.clean_when = "before" - async def set_prompt(self, dynprompt, node_ids, is_changed_cache): + async def set_prompt(self, dynprompt, node_ids, is_changed): self.dynprompt = dynprompt - self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) + self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed) await self.cache_key_set.add_keys(node_ids) - self.is_changed_cache = is_changed_cache + self.clean_when = self.cache_key_set.clean_when or "before" + self.is_changed = is_changed self.initialized = True def all_node_ids(self): @@ -196,16 +290,29 @@ class BasicCache: def poll(self, **kwargs): pass + async def _update_cache_key_immediate(self, node_id): + """Update the cache key for the node.""" + await self.cache_key_set.update_cache_key(node_id) + + def _is_key_updated_immediate(self, node_id): + """False if the cache key set is an updatable type and it hasn't been updated yet.""" + return self.cache_key_set.is_key_updated(node_id) + + def _is_key_updatable_immediate(self, node_id): + """True if the cache key set is an updatable type and it can be updated properly.""" + return self.cache_key_set.is_key_updatable(node_id) + def _set_immediate(self, node_id, value): assert self.initialized cache_key = self.cache_key_set.get_data_key(node_id) - self.cache[cache_key] = value + if cache_key is not None: + self.cache[cache_key] = value def _get_immediate(self, node_id): if not self.initialized: return None cache_key = self.cache_key_set.get_data_key(node_id) - if cache_key in self.cache: + if cache_key is not None and cache_key in self.cache: return self.cache[cache_key] else: return None @@ -216,7 +323,7 @@ class BasicCache: if subcache is None: subcache = BasicCache(self.key_class) self.subcaches[subcache_key] = subcache - await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) + await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed) return subcache def _get_subcache(self, node_id): @@ -272,10 +379,24 @@ class HierarchicalCache(BasicCache): cache = self._get_cache_for(node_id) assert cache is not None return await cache._ensure_subcache(node_id, children_ids) + + async def update_cache_key(self, node_id): + cache = self._get_cache_for(node_id) + assert cache is not None + await cache._update_cache_key_immediate(node_id) + + def is_key_updated(self, node_id): + cache = self._get_cache_for(node_id) + assert cache is not None + return cache._is_key_updated_immediate(node_id) + + def is_key_updatable(self, node_id): + cache = self._get_cache_for(node_id) + assert cache is not None + return cache._is_key_updatable_immediate(node_id) class NullCache: - - async def set_prompt(self, dynprompt, node_ids, is_changed_cache): + async def set_prompt(self, dynprompt, node_ids, is_changed): pass def all_node_ids(self): @@ -295,6 +416,15 @@ class NullCache: async def ensure_subcache_for(self, node_id, children_ids): return self + + async def update_cache_key(self, node_id): + pass + + def is_key_updated(self, node_id): + return True + + def is_key_updatable(self, node_id): + return False class LRUCache(BasicCache): def __init__(self, key_class, max_size=100): @@ -305,8 +435,8 @@ class LRUCache(BasicCache): self.used_generation = {} self.children = {} - async def set_prompt(self, dynprompt, node_ids, is_changed_cache): - await super().set_prompt(dynprompt, node_ids, is_changed_cache) + async def set_prompt(self, dynprompt, node_ids, is_changed): + await super().set_prompt(dynprompt, node_ids, is_changed) self.generation += 1 for node_id in node_ids: self._mark_used(node_id) @@ -347,6 +477,18 @@ class LRUCache(BasicCache): self._mark_used(child_id) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) return self + + async def update_cache_key(self, node_id): + self._mark_used(node_id) + await self._update_cache_key_immediate(node_id) + + def is_key_updated(self, node_id): + self._mark_used(node_id) + return self._is_key_updated_immediate(node_id) + + def is_key_updatable(self, node_id): + self._mark_used(node_id) + return self._is_key_updatable_immediate(node_id) #Iterating the cache for usage analysis might be expensive, so if we trigger make sure @@ -365,7 +507,6 @@ RAM_CACHE_DEFAULT_RAM_USAGE = 0.1 RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3 class RAMPressureCache(LRUCache): - def __init__(self, key_class): super().__init__(key_class, 0) self.timestamps = {} diff --git a/execution.py b/execution.py index 4b4f63c80..5b21a4228 100644 --- a/execution.py +++ b/execution.py @@ -18,7 +18,7 @@ import nodes from comfy_execution.caching import ( BasicCache, CacheKeySetID, - CacheKeySetInputSignature, + CacheKeySetUpdatableInputSignature, NullCache, HierarchicalCache, LRUCache, @@ -36,7 +36,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io - +from server import PromptServer class ExecutionResult(Enum): SUCCESS = 0 @@ -46,49 +46,40 @@ class ExecutionResult(Enum): class DuplicateNodeError(Exception): pass -class IsChangedCache: - def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache): +class IsChanged: + def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, execution_list: ExecutionList|None=None, extra_data: dict={}): self.prompt_id = prompt_id self.dynprompt = dynprompt - self.outputs_cache = outputs_cache - self.is_changed = {} - - async def get(self, node_id): - if node_id in self.is_changed: - return self.is_changed[node_id] + self.execution_list = execution_list + self.extra_data = extra_data + def get_input_data(self, node_id): + node = self.dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + return get_input_data(node["inputs"], class_def, node_id, self.execution_list, self.dynprompt, self.extra_data) + + async def get(self, node_id): node = self.dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - has_is_changed = False is_changed_name = None if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None: - has_is_changed = True is_changed_name = "fingerprint_inputs" elif hasattr(class_def, "IS_CHANGED"): - has_is_changed = True is_changed_name = "IS_CHANGED" - if not has_is_changed: - self.is_changed[node_id] = False - return self.is_changed[node_id] + if is_changed_name is None: + return False - if "is_changed" in node: - self.is_changed[node_id] = node["is_changed"] - return self.is_changed[node_id] - - # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED - input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _, v3_data = self.get_input_data(node_id) try: is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data) is_changed = await resolve_map_node_over_list_results(is_changed) - node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] + is_changed = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] except Exception as e: logging.warning("WARNING: {}".format(e)) - node["is_changed"] = float("NaN") - finally: - self.is_changed[node_id] = node["is_changed"] - return self.is_changed[node_id] - + is_changed = float("NaN") + return is_changed class CacheEntry(NamedTuple): ui: dict @@ -118,19 +109,19 @@ class CacheSet: else: self.init_classic_cache() - self.all = [self.outputs, self.objects] + self.all: list[BasicCache, BasicCache] = [self.outputs, self.objects] # Performs like the old cache -- dump data ASAP def init_classic_cache(self): - self.outputs = HierarchicalCache(CacheKeySetInputSignature) + self.outputs = HierarchicalCache(CacheKeySetUpdatableInputSignature) self.objects = HierarchicalCache(CacheKeySetID) def init_lru_cache(self, cache_size): - self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.outputs = LRUCache(CacheKeySetUpdatableInputSignature, max_size=cache_size) self.objects = HierarchicalCache(CacheKeySetID) def init_ram_cache(self, min_headroom): - self.outputs = RAMPressureCache(CacheKeySetInputSignature) + self.outputs = RAMPressureCache(CacheKeySetUpdatableInputSignature) self.objects = HierarchicalCache(CacheKeySetID) def init_null_cache(self): @@ -406,7 +397,10 @@ def format_value(x): else: return str(x) -async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): +async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheSet, + current_item: str, extra_data: dict, executed: set, prompt_id: str, + execution_list: ExecutionList, pending_subgraph_results: dict, + pending_async_nodes: dict, ui_outputs: dict): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -414,16 +408,20 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - cached = caches.outputs.get(unique_id) - if cached is not None: - if server.client_id is not None: - cached_ui = cached.ui or {} - server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id) - if cached.ui is not None: - ui_outputs[unique_id] = cached.ui - get_progress_state().finish_progress(unique_id) - execution_list.cache_update(unique_id, cached) - return (ExecutionResult.SUCCESS, None, None) + + if caches.outputs.is_key_updated(unique_id): + # Key is updated, the cache can be checked. + cached = caches.outputs.get(unique_id) + logging.debug(f"execute: {unique_id} cached: {cached is not None}") + if cached is not None: + if server.client_id is not None: + cached_ui = cached.ui or {} + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id) + if cached.ui is not None: + ui_outputs[unique_id] = cached.ui + get_progress_state().finish_progress(unique_id) + execution_list.cache_update(unique_id, cached) + return (ExecutionResult.SUCCESS, None, None) input_data_all = None try: @@ -464,11 +462,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, del pending_subgraph_results[unique_id] has_subgraph = False else: - get_progress_state().start_progress(unique_id) + if caches.outputs.is_key_updated(unique_id): + # The key is updated, the node is executing. + get_progress_state().start_progress(unique_id) + if server.client_id is not None: + server.last_node_id = display_node_id + server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) + input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) - if server.client_id is not None: - server.last_node_id = display_node_id - server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) obj = caches.objects.get(unique_id) if obj is None: @@ -479,6 +480,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None else: lazy_status_present = getattr(obj, "check_lazy_status", None) is not None + if lazy_status_present: # for check_lazy_status, the returned data should include the original key of the input v3_data_lazy = v3_data.copy() @@ -494,6 +496,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, execution_list.make_input_strong_link(unique_id, i) return (ExecutionResult.PENDING, None, None) + if not caches.outputs.is_key_updated(unique_id): + # Update the cache key after any lazy inputs are evaluated. + async def update_cache_key(node_id, unblock): + await caches.outputs.update_cache_key(node_id) + unblock() + asyncio.create_task(update_cache_key(unique_id, execution_list.add_external_block(unique_id))) + return (ExecutionResult.PENDING, None, None) + def execution_block_cb(block): if block.message is not None: mes = { @@ -525,6 +535,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, unblock() asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) + if len(output_ui) > 0: ui_outputs[unique_id] = { "meta": { @@ -537,6 +548,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, } if server.client_id is not None: server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + if has_subgraph: cached_outputs = [] new_node_ids = [] @@ -564,7 +576,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, new_node_ids = set(new_node_ids) for cache in caches.all: subcache = await cache.ensure_subcache_for(unique_id, new_node_ids) - subcache.clean_unused() + if subcache.clean_when == "before": + subcache.clean_unused() for node_id in new_output_ids: execution_list.add_node(node_id) execution_list.cache_link(node_id, unique_id) @@ -689,25 +702,25 @@ class PromptExecutor: dynamic_prompt = DynamicPrompt(prompt) reset_progress_state(prompt_id, dynamic_prompt) add_progress_handler(WebUIProgressHandler(self.server)) - is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) + execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) + is_changed = IsChanged(prompt_id, dynamic_prompt, execution_list, extra_data) for cache in self.caches.all: - await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) - cache.clean_unused() + await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed) + if cache.clean_when == "before": + cache.clean_unused() - cached_nodes = [] - for node_id in prompt: - if self.caches.outputs.get(node_id) is not None: - cached_nodes.append(node_id) + if self.caches.outputs.clean_when == "before": + cached_nodes = [] + for node_id in prompt: + if self.caches.outputs.get(node_id) is not None: + cached_nodes.append(node_id) + self.add_message("execution_cached", {"nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False) comfy.model_management.cleanup_models_gc() - self.add_message("execution_cached", - { "nodes": cached_nodes, "prompt_id": prompt_id}, - broadcast=False) pending_subgraph_results = {} pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results ui_node_outputs = {} executed = set() - execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) current_outputs = self.caches.outputs.all_node_ids() for node_id in list(execute_outputs): execution_list.add_node(node_id) @@ -745,7 +758,10 @@ class PromptExecutor: self.server.last_node_id = None if comfy.model_management.DISABLE_SMART_MEMORY: comfy.model_management.unload_all_models() - + + for cache in self.caches.all: + if cache.clean_when == "after": + cache.clean_unused() async def validate_inputs(prompt_id, prompt, item, validated): unique_id = item From 38ab4e3c76767bbc573a048e6073139033611893 Mon Sep 17 00:00:00 2001 From: Deluxe233 Date: Mon, 26 Jan 2026 03:50:50 -0500 Subject: [PATCH 02/13] Fixed not taking rawLink into account. --- comfy_execution/caching.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 212a244b0..a3f3ac338 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -157,7 +157,7 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet): assert ancestor_id in self.node_sig_cache signatures.append(self.node_sig_cache[ancestor_id]) - logging.debug(f"signature for {node_id}:\n{signatures}") + logging.debug(f"signature for node {node_id}: {signatures}") return to_hashable(signatures) async def get_immediate_node_signature(self, node_id, ancestor_order_mapping: dict, inputs: dict): @@ -202,6 +202,11 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet): return ancestors, order_mapping, input_hashes def get_ordered_ancestry_internal(self, node_id): + def get_hashable(obj): + try: + return throw_on_unhashable(obj) + except: + return Unhashable ancestors = [] input_hashes = {} @@ -214,26 +219,25 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet): input_data_all, _, _ = self.is_changed.get_input_data(node_id) inputs = self.dynprompt.get_node(node_id)["inputs"] for key in sorted(inputs.keys()): - input = inputs[key] if key in input_data_all: - if is_link(input): - ancestor_id = input[0] - try: - # Replace link with input's hash - hashable = throw_on_unhashable(input_data_all[key]) - input_hashes[key] = hash(hashable) - except: + if is_link(inputs[key]): + ancestor_id = inputs[key][0] + hashable = get_hashable(input_data_all[key]) + if hashable is Unhashable or is_link(input_data_all[key]): # Link still needed - input_hashes[key] = input + input_hashes[key] = inputs[key] if ancestor_id not in ancestors: ancestors.append(ancestor_id) - else: - try: - hashable = throw_on_unhashable(input) + else: + # Replace link with input's hash input_hashes[key] = hash(hashable) - except: - logging.warning(f"Node {node_id} cannot be cached due to whatever this thing is: {input}") + else: + hashable = get_hashable(inputs[key]) + if hashable is Unhashable: + logging.warning(f"Node {node_id} cannot be cached due to whatever this thing is: {inputs[key]}") input_hashes[key] = Unhashable() + else: + input_hashes[key] = hash(hashable) self.ancestry_cache[node_id] = ancestors return self.ancestry_cache[node_id], input_hashes From 4683136740a048ecd6b6d721481bdaf55f83b07d Mon Sep 17 00:00:00 2001 From: Deluxe233 Date: Mon, 26 Jan 2026 09:33:00 -0500 Subject: [PATCH 03/13] Update caching.py --- comfy_execution/caching.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index a3f3ac338..751c76e9f 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -123,9 +123,7 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet): if node_id not in self.keys: return self.updated_node_ids.add(node_id) - node = self.dynprompt.get_node(node_id) self.keys[node_id] = await self.get_node_signature(node_id) - self.subcache_keys[node_id] = (node_id, node["class_type"]) def is_key_updated(self, node_id): return node_id in self.updated_node_ids @@ -149,16 +147,17 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet): async def get_node_signature(self, node_id): signatures = [] - ancestors, order_mapping, input_hashes = self.get_ordered_ancestry(node_id) - self.node_sig_cache[node_id] = await self.get_immediate_node_signature(node_id, order_mapping, input_hashes) + ancestors, order_mapping, node_inputs = self.get_ordered_ancestry(node_id) + self.node_sig_cache[node_id] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping, node_inputs)) signatures.append(self.node_sig_cache[node_id]) for ancestor_id in ancestors: assert ancestor_id in self.node_sig_cache signatures.append(self.node_sig_cache[ancestor_id]) - + + signatures = frozenset(zip(itertools.count(), signatures)) logging.debug(f"signature for node {node_id}: {signatures}") - return to_hashable(signatures) + return signatures async def get_immediate_node_signature(self, node_id, ancestor_order_mapping: dict, inputs: dict): if not self.dynprompt.has_node(node_id): @@ -192,14 +191,14 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet): get_ancestors(self.ancestry_cache[ancestor_id], ret) return ret - ancestors, input_hashes = self.get_ordered_ancestry_internal(node_id) + ancestors, node_inputs = self.get_ordered_ancestry_internal(node_id) ancestors = get_ancestors(ancestors) order_mapping = {} for i, ancestor_id in enumerate(ancestors): order_mapping[ancestor_id] = i - return ancestors, order_mapping, input_hashes + return ancestors, order_mapping, node_inputs def get_ordered_ancestry_internal(self, node_id): def get_hashable(obj): @@ -207,14 +206,15 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet): return throw_on_unhashable(obj) except: return Unhashable + ancestors = [] - input_hashes = {} + node_inputs = {} if node_id in self.ancestry_cache: - return self.ancestry_cache[node_id], input_hashes + return self.ancestry_cache[node_id], node_inputs if not self.dynprompt.has_node(node_id): - return + return ancestors, node_inputs input_data_all, _, _ = self.is_changed.get_input_data(node_id) inputs = self.dynprompt.get_node(node_id)["inputs"] @@ -225,22 +225,22 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet): hashable = get_hashable(input_data_all[key]) if hashable is Unhashable or is_link(input_data_all[key]): # Link still needed - input_hashes[key] = inputs[key] + node_inputs[key] = inputs[key] if ancestor_id not in ancestors: ancestors.append(ancestor_id) else: - # Replace link with input's hash - input_hashes[key] = hash(hashable) + # Replace link + node_inputs[key] = input_data_all[key] else: hashable = get_hashable(inputs[key]) if hashable is Unhashable: logging.warning(f"Node {node_id} cannot be cached due to whatever this thing is: {inputs[key]}") - input_hashes[key] = Unhashable() + node_inputs[key] = Unhashable() else: - input_hashes[key] = hash(hashable) + node_inputs[key] = inputs[key] self.ancestry_cache[node_id] = ancestors - return self.ancestry_cache[node_id], input_hashes + return self.ancestry_cache[node_id], node_inputs class BasicCache: def __init__(self, key_class): From 1107f4322b154b18dc5f828df4f201705320179c Mon Sep 17 00:00:00 2001 From: Deluxe233 Date: Mon, 26 Jan 2026 09:40:45 -0500 Subject: [PATCH 04/13] Removed unused method --- comfy_execution/caching.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 751c76e9f..8fcd19b00 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -52,9 +52,6 @@ class CacheKeySet(ABC): def is_key_updated(self, node_id) -> bool: return True - - def is_key_updatable(self, node_id) -> bool: - return False class Unhashable: def __init__(self): @@ -128,12 +125,6 @@ class CacheKeySetUpdatableInputSignature(CacheKeySet): def is_key_updated(self, node_id): return node_id in self.updated_node_ids - def is_key_updatable(self, node_id): - _, missing_keys, _ = self.is_changed.get_input_data(node_id) - if missing_keys: - return False - return True - async def add_keys(self, node_ids): """Initialize keys.""" for node_id in node_ids: @@ -301,10 +292,6 @@ class BasicCache: def _is_key_updated_immediate(self, node_id): """False if the cache key set is an updatable type and it hasn't been updated yet.""" return self.cache_key_set.is_key_updated(node_id) - - def _is_key_updatable_immediate(self, node_id): - """True if the cache key set is an updatable type and it can be updated properly.""" - return self.cache_key_set.is_key_updatable(node_id) def _set_immediate(self, node_id, value): assert self.initialized @@ -394,11 +381,6 @@ class HierarchicalCache(BasicCache): assert cache is not None return cache._is_key_updated_immediate(node_id) - def is_key_updatable(self, node_id): - cache = self._get_cache_for(node_id) - assert cache is not None - return cache._is_key_updatable_immediate(node_id) - class NullCache: async def set_prompt(self, dynprompt, node_ids, is_changed): pass @@ -427,9 +409,6 @@ class NullCache: def is_key_updated(self, node_id): return True - def is_key_updatable(self, node_id): - return False - class LRUCache(BasicCache): def __init__(self, key_class, max_size=100): super().__init__(key_class) @@ -490,10 +469,6 @@ class LRUCache(BasicCache): self._mark_used(node_id) return self._is_key_updated_immediate(node_id) - def is_key_updatable(self, node_id): - self._mark_used(node_id) - return self._is_key_updatable_immediate(node_id) - #Iterating the cache for usage analysis might be expensive, so if we trigger make sure #to take a chunk out to give breathing space on high-node / low-ram-per-node flows. From f511703343072ad1a2b4029c267b36f7f4addec7 Mon Sep 17 00:00:00 2001 From: Deluxe233 Date: Tue, 27 Jan 2026 13:51:41 -0500 Subject: [PATCH 05/13] Included original cache key set for testing --- comfy_execution/caching.py | 105 +++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 8fcd19b00..375d6a36b 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -97,6 +97,111 @@ class CacheKeySetID(CacheKeySet): self.keys[node_id] = (node_id, node["class_type"]) self.subcache_keys[node_id] = (node_id, node["class_type"]) +class CacheKeySetInputSignatureOriginalConstant(CacheKeySet): + """Original CacheKeySet""" + def __init__(self, dynprompt, node_ids, is_changed): + super().__init__(dynprompt, node_ids, is_changed) + self.dynprompt = dynprompt + self.is_changed = is_changed + self.clean_when = "before" + self.node_sig_cache = {} + + def include_node_id_in_input(self) -> bool: + return False + + async def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + if not self.dynprompt.has_node(node_id): + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = await self.get_node_signature(node_id) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + + async def get_node_signature(self, node_id): + signatures = [] + ancestors, order_mapping = self.get_ordered_ancestry(node_id) + if node_id not in self.node_sig_cache: + self.node_sig_cache[node_id] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping)) + signatures.append(self.node_sig_cache[node_id]) + for ancestor_id in ancestors: + if ancestor_id not in self.node_sig_cache: + self.node_sig_cache[ancestor_id] = to_hashable(await self.get_immediate_node_signature(ancestor_id, order_mapping)) + signatures.append(self.node_sig_cache[ancestor_id]) + signatures = frozenset(zip(itertools.count(), signatures)) + logging.debug(f"signature for node {node_id}: {signatures}") + return signatures + + async def get_immediate_node_signature(self, node_id, ancestor_order_mapping): + if not self.dynprompt.has_node(node_id): + # This node doesn't exist -- we can't cache it. + return [float("NaN")] + node = self.dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + signature = [class_type, await self.is_changed.get(node_id)] + inputs = node["inputs"] + for key in sorted(inputs.keys()): + if is_link(inputs[key]): + (ancestor_id, ancestor_socket) = inputs[key] + ancestor_index = ancestor_order_mapping[ancestor_id] + signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) + else: + signature.append((key, inputs[key])) + if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): + signature.append(node_id) + return signature + + # This function returns a list of all ancestors of the given node. The order of the list is + # deterministic based on which specific inputs the ancestor is connected by. + def get_ordered_ancestry(self, node_id): + ancestors = [] + order_mapping = {} + self.get_ordered_ancestry_internal(node_id, ancestors, order_mapping) + return ancestors, order_mapping + + def get_ordered_ancestry_internal(self, node_id, ancestors, order_mapping): + if not self.dynprompt.has_node(node_id): + return + inputs = self.dynprompt.get_node(node_id)["inputs"] + input_keys = sorted(inputs.keys()) + for key in input_keys: + if is_link(inputs[key]): + ancestor_id = inputs[key][0] + if ancestor_id not in order_mapping: + ancestors.append(ancestor_id) + order_mapping[ancestor_id] = len(ancestors) - 1 + self.get_ordered_ancestry_internal(ancestor_id, ancestors, order_mapping) + +class CacheKeySetInputSignatureOriginalUpdatable(CacheKeySetInputSignatureOriginalConstant): + """Original constant CacheKeySet modified to be updatable.""" + def __init__(self, dynprompt, node_ids, is_changed): + super().__init__(dynprompt, node_ids, is_changed) + self.clean_when = "after" + self.updated_node_ids = set() + + async def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + if not self.dynprompt.has_node(node_id): + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = None + self.subcache_keys[node_id] = (node_id, node["class_type"]) + + async def update_cache_key(self, node_id): + if node_id in self.updated_node_ids: + return + if node_id not in self.keys: + return + self.updated_node_ids.add(node_id) + self.keys[node_id] = await self.get_node_signature(node_id) + + def is_key_updated(self, node_id): + return node_id in self.updated_node_ids + class CacheKeySetUpdatableInputSignature(CacheKeySet): def __init__(self, dynprompt, node_ids, is_changed): super().__init__(dynprompt, node_ids, is_changed) From af4d691d1ffacb722b9d671782f46d93d4ccec8e Mon Sep 17 00:00:00 2001 From: Deluxe233 Date: Tue, 27 Jan 2026 20:57:44 -0500 Subject: [PATCH 06/13] Revert "Included original cache key set for testing" This reverts commit f511703343072ad1a2b4029c267b36f7f4addec7. --- comfy_execution/caching.py | 105 ------------------------------------- 1 file changed, 105 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 375d6a36b..8fcd19b00 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -97,111 +97,6 @@ class CacheKeySetID(CacheKeySet): self.keys[node_id] = (node_id, node["class_type"]) self.subcache_keys[node_id] = (node_id, node["class_type"]) -class CacheKeySetInputSignatureOriginalConstant(CacheKeySet): - """Original CacheKeySet""" - def __init__(self, dynprompt, node_ids, is_changed): - super().__init__(dynprompt, node_ids, is_changed) - self.dynprompt = dynprompt - self.is_changed = is_changed - self.clean_when = "before" - self.node_sig_cache = {} - - def include_node_id_in_input(self) -> bool: - return False - - async def add_keys(self, node_ids): - for node_id in node_ids: - if node_id in self.keys: - continue - if not self.dynprompt.has_node(node_id): - continue - node = self.dynprompt.get_node(node_id) - self.keys[node_id] = await self.get_node_signature(node_id) - self.subcache_keys[node_id] = (node_id, node["class_type"]) - - async def get_node_signature(self, node_id): - signatures = [] - ancestors, order_mapping = self.get_ordered_ancestry(node_id) - if node_id not in self.node_sig_cache: - self.node_sig_cache[node_id] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping)) - signatures.append(self.node_sig_cache[node_id]) - for ancestor_id in ancestors: - if ancestor_id not in self.node_sig_cache: - self.node_sig_cache[ancestor_id] = to_hashable(await self.get_immediate_node_signature(ancestor_id, order_mapping)) - signatures.append(self.node_sig_cache[ancestor_id]) - signatures = frozenset(zip(itertools.count(), signatures)) - logging.debug(f"signature for node {node_id}: {signatures}") - return signatures - - async def get_immediate_node_signature(self, node_id, ancestor_order_mapping): - if not self.dynprompt.has_node(node_id): - # This node doesn't exist -- we can't cache it. - return [float("NaN")] - node = self.dynprompt.get_node(node_id) - class_type = node["class_type"] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - signature = [class_type, await self.is_changed.get(node_id)] - inputs = node["inputs"] - for key in sorted(inputs.keys()): - if is_link(inputs[key]): - (ancestor_id, ancestor_socket) = inputs[key] - ancestor_index = ancestor_order_mapping[ancestor_id] - signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) - else: - signature.append((key, inputs[key])) - if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): - signature.append(node_id) - return signature - - # This function returns a list of all ancestors of the given node. The order of the list is - # deterministic based on which specific inputs the ancestor is connected by. - def get_ordered_ancestry(self, node_id): - ancestors = [] - order_mapping = {} - self.get_ordered_ancestry_internal(node_id, ancestors, order_mapping) - return ancestors, order_mapping - - def get_ordered_ancestry_internal(self, node_id, ancestors, order_mapping): - if not self.dynprompt.has_node(node_id): - return - inputs = self.dynprompt.get_node(node_id)["inputs"] - input_keys = sorted(inputs.keys()) - for key in input_keys: - if is_link(inputs[key]): - ancestor_id = inputs[key][0] - if ancestor_id not in order_mapping: - ancestors.append(ancestor_id) - order_mapping[ancestor_id] = len(ancestors) - 1 - self.get_ordered_ancestry_internal(ancestor_id, ancestors, order_mapping) - -class CacheKeySetInputSignatureOriginalUpdatable(CacheKeySetInputSignatureOriginalConstant): - """Original constant CacheKeySet modified to be updatable.""" - def __init__(self, dynprompt, node_ids, is_changed): - super().__init__(dynprompt, node_ids, is_changed) - self.clean_when = "after" - self.updated_node_ids = set() - - async def add_keys(self, node_ids): - for node_id in node_ids: - if node_id in self.keys: - continue - if not self.dynprompt.has_node(node_id): - continue - node = self.dynprompt.get_node(node_id) - self.keys[node_id] = None - self.subcache_keys[node_id] = (node_id, node["class_type"]) - - async def update_cache_key(self, node_id): - if node_id in self.updated_node_ids: - return - if node_id not in self.keys: - return - self.updated_node_ids.add(node_id) - self.keys[node_id] = await self.get_node_signature(node_id) - - def is_key_updated(self, node_id): - return node_id in self.updated_node_ids - class CacheKeySetUpdatableInputSignature(CacheKeySet): def __init__(self, dynprompt, node_ids, is_changed): super().__init__(dynprompt, node_ids, is_changed) From b951181123c59d43a3b798fd8859f9d30dd02b7c Mon Sep 17 00:00:00 2001 From: Deluxe233 Date: Wed, 28 Jan 2026 10:37:10 -0500 Subject: [PATCH 07/13] Added tests + cleanup --- comfy_execution/caching.py | 7 +--- execution.py | 24 +++-------- tests/execution/test_execution.py | 40 ++++++++++++++----- .../testing-pack/specific_tests.py | 35 ++++++++++++---- 4 files changed, 65 insertions(+), 41 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 8fcd19b00..a579982a0 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -26,7 +26,6 @@ class CacheKeySet(ABC): def __init__(self, dynprompt, node_ids, is_changed): self.keys = {} self.subcache_keys = {} - self.clean_when = None @abstractmethod async def add_keys(self, node_ids): @@ -85,7 +84,6 @@ class CacheKeySetID(CacheKeySet): def __init__(self, dynprompt, node_ids, is_changed): super().__init__(dynprompt, node_ids, is_changed) self.dynprompt = dynprompt - self.clean_when = "before" async def add_keys(self, node_ids): for node_id in node_ids: @@ -97,12 +95,11 @@ class CacheKeySetID(CacheKeySet): self.keys[node_id] = (node_id, node["class_type"]) self.subcache_keys[node_id] = (node_id, node["class_type"]) -class CacheKeySetUpdatableInputSignature(CacheKeySet): +class CacheKeySetInputSignature(CacheKeySet): def __init__(self, dynprompt, node_ids, is_changed): super().__init__(dynprompt, node_ids, is_changed) self.dynprompt: DynamicPrompt = dynprompt self.is_changed = is_changed - self.clean_when = "after" self.updated_node_ids = set() self.node_sig_cache = {} @@ -241,13 +238,11 @@ class BasicCache: self.cache_key_set: CacheKeySet self.cache = {} self.subcaches = {} - self.clean_when = "before" async def set_prompt(self, dynprompt, node_ids, is_changed): self.dynprompt = dynprompt self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed) await self.cache_key_set.add_keys(node_ids) - self.clean_when = self.cache_key_set.clean_when or "before" self.is_changed = is_changed self.initialized = True diff --git a/execution.py b/execution.py index 5b21a4228..4337272b0 100644 --- a/execution.py +++ b/execution.py @@ -18,7 +18,7 @@ import nodes from comfy_execution.caching import ( BasicCache, CacheKeySetID, - CacheKeySetUpdatableInputSignature, + CacheKeySetInputSignature, NullCache, HierarchicalCache, LRUCache, @@ -113,15 +113,15 @@ class CacheSet: # Performs like the old cache -- dump data ASAP def init_classic_cache(self): - self.outputs = HierarchicalCache(CacheKeySetUpdatableInputSignature) + self.outputs = HierarchicalCache(CacheKeySetInputSignature) self.objects = HierarchicalCache(CacheKeySetID) def init_lru_cache(self, cache_size): - self.outputs = LRUCache(CacheKeySetUpdatableInputSignature, max_size=cache_size) + self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.objects = HierarchicalCache(CacheKeySetID) def init_ram_cache(self, min_headroom): - self.outputs = RAMPressureCache(CacheKeySetUpdatableInputSignature) + self.outputs = RAMPressureCache(CacheKeySetInputSignature) self.objects = HierarchicalCache(CacheKeySetID) def init_null_cache(self): @@ -575,9 +575,7 @@ async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheS cached_outputs.append((True, node_outputs)) new_node_ids = set(new_node_ids) for cache in caches.all: - subcache = await cache.ensure_subcache_for(unique_id, new_node_ids) - if subcache.clean_when == "before": - subcache.clean_unused() + await cache.ensure_subcache_for(unique_id, new_node_ids) for node_id in new_output_ids: execution_list.add_node(node_id) execution_list.cache_link(node_id, unique_id) @@ -706,15 +704,6 @@ class PromptExecutor: is_changed = IsChanged(prompt_id, dynamic_prompt, execution_list, extra_data) for cache in self.caches.all: await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed) - if cache.clean_when == "before": - cache.clean_unused() - - if self.caches.outputs.clean_when == "before": - cached_nodes = [] - for node_id in prompt: - if self.caches.outputs.get(node_id) is not None: - cached_nodes.append(node_id) - self.add_message("execution_cached", {"nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False) comfy.model_management.cleanup_models_gc() pending_subgraph_results = {} @@ -760,8 +749,7 @@ class PromptExecutor: comfy.model_management.unload_all_models() for cache in self.caches.all: - if cache.clean_when == "after": - cache.clean_unused() + cache.clean_unused() async def validate_inputs(prompt_id, prompt, item, validated): unique_id = item diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index f73ca7e3c..10cb4216a 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -552,27 +552,47 @@ class TestExecution: assert len(images1) == 1, "Should have 1 image" assert len(images2) == 1, "Should have 1 image" - # This tests that only constant outputs are used in the call to `IS_CHANGED` - def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server): + def test_is_changed_passed_cached_outputs(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1) - test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5) - + test_node = g.node("TestIsChangedWithAllInputs", image=input1.out(0), value=0.5) output = g.node("PreviewImage", images=test_node.out(0)) - result = client.run(g) - images = result.get_images(output) + result1 = client.run(g) + images = result1.get_images(output) assert len(images) == 1, "Should have 1 image" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" - result = client.run(g) - images = result.get_images(output) + result2 = client.run(g) + images = result2.get_images(output) assert len(images) == 1, "Should have 1 image" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" + if server["should_cache_results"]: - assert not result.did_run(test_node), "The execution should have been cached" + assert not result2.did_run(test_node), "Test node should not have run again" else: - assert result.did_run(test_node), "The execution should have been re-run" + assert result2.did_run(test_node), "Test node should always run here" + + def test_dont_always_run_downstream(self, client: ComfyClient, builder: GraphBuilder, server): + g = builder + float1 = g.node("TestDontAlwaysRunDownstream", float=0.5) # IS_CHANGED returns float("NaN") + image1 = g.node("StubConstantImage", value=float1.out(0), height=512, width=512, batch_size=1) + output = g.node("PreviewImage", images=image1.out(0)) + + result1 = client.run(g) + images = result1.get_images(output) + assert len(images) == 1, "Should have 1 image" + assert numpy.array(images[0]).min() == 127 and numpy.array(images[0]).max() == 127, "Image should have value 0.50" + + result2 = client.run(g) + images = result2.get_images(output) + assert len(images) == 1, "Should have 1 image" + assert numpy.array(images[0]).min() == 127 and numpy.array(images[0]).max() == 127, "Image should have value 0.50" + + if server["should_cache_results"]: + assert not result2.did_run(output), "Output node should not have run the second time" + else: + assert result2.did_run(output), "Output node should always run here" def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): diff --git a/tests/execution/testing_nodes/testing-pack/specific_tests.py b/tests/execution/testing_nodes/testing-pack/specific_tests.py index 4f8f01ae4..5c74905f5 100644 --- a/tests/execution/testing_nodes/testing-pack/specific_tests.py +++ b/tests/execution/testing_nodes/testing-pack/specific_tests.py @@ -100,7 +100,7 @@ class TestCustomIsChanged: else: return False -class TestIsChangedWithConstants: +class TestIsChangedWithAllInputs: @classmethod def INPUT_TYPES(cls): return { @@ -120,10 +120,29 @@ class TestIsChangedWithConstants: @classmethod def IS_CHANGED(cls, image, value): - if image is None: - return value - else: - return image.mean().item() * value + # if image is None then an exception is thrown and is_changed becomes float("NaN") + return image.mean().item() * value + +class TestDontAlwaysRunDownstream: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "float": ("FLOAT",), + }, + } + + RETURN_TYPES = ("FLOAT",) + FUNCTION = "always_run" + + CATEGORY = "Testing/Nodes" + + def always_run(self, float): + return (float,) + + @classmethod + def IS_CHANGED(cls, *args, **kwargs): + return float("NaN") class TestCustomValidation1: @classmethod @@ -486,7 +505,8 @@ TEST_NODE_CLASS_MAPPINGS = { "TestLazyMixImages": TestLazyMixImages, "TestVariadicAverage": TestVariadicAverage, "TestCustomIsChanged": TestCustomIsChanged, - "TestIsChangedWithConstants": TestIsChangedWithConstants, + "TestIsChangedWithAllInputs": TestIsChangedWithAllInputs, + "TestDontAlwaysRunDownstream": TestDontAlwaysRunDownstream, "TestCustomValidation1": TestCustomValidation1, "TestCustomValidation2": TestCustomValidation2, "TestCustomValidation3": TestCustomValidation3, @@ -504,7 +524,8 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestLazyMixImages": "Lazy Mix Images", "TestVariadicAverage": "Variadic Average", "TestCustomIsChanged": "Custom IsChanged", - "TestIsChangedWithConstants": "IsChanged With Constants", + "TestIsChangedWithAllInputs": "IsChanged With All Inputs", + "TestDontAlwaysRunDownstream": "Dont Always Run Downstream", "TestCustomValidation1": "Custom Validation 1", "TestCustomValidation2": "Custom Validation 2", "TestCustomValidation3": "Custom Validation 3", From 5cf4115f50a790f6b01eb3544e88d39676b38db7 Mon Sep 17 00:00:00 2001 From: Deluxe233 Date: Thu, 29 Jan 2026 01:22:33 -0500 Subject: [PATCH 08/13] Added "execution_cached" message back in --- execution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/execution.py b/execution.py index 4337272b0..d2532552d 100644 --- a/execution.py +++ b/execution.py @@ -416,6 +416,7 @@ async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheS if cached is not None: if server.client_id is not None: cached_ui = cached.ui or {} + server.send_sync("execution_cached", { "nodes": [unique_id], "prompt_id": prompt_id}, server.client_id) server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id) if cached.ui is not None: ui_outputs[unique_id] = cached.ui From 96e9a81cdf493b1183dbb16ebc2c30d5fef85be9 Mon Sep 17 00:00:00 2001 From: Deluxe233 Date: Thu, 29 Jan 2026 03:57:41 -0500 Subject: [PATCH 09/13] Fix not taking rawLink into account Forgot that input_data_all puts everything in a list. --- comfy_execution/caching.py | 8 +------- execution.py | 2 -- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index a579982a0..17217b1c6 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -103,15 +103,12 @@ class CacheKeySetInputSignature(CacheKeySet): self.updated_node_ids = set() self.node_sig_cache = {} - """Nodes' immediate node signatures.""" self.ancestry_cache = {} - """List of a node's ancestors.""" def include_node_id_in_input(self) -> bool: return False async def update_cache_key(self, node_id): - """Update key using cached outputs as part of the input signature.""" if node_id in self.updated_node_ids: return if node_id not in self.keys: @@ -123,7 +120,6 @@ class CacheKeySetInputSignature(CacheKeySet): return node_id in self.updated_node_ids async def add_keys(self, node_ids): - """Initialize keys.""" for node_id in node_ids: if node_id in self.keys: continue @@ -211,7 +207,7 @@ class CacheKeySetInputSignature(CacheKeySet): if is_link(inputs[key]): ancestor_id = inputs[key][0] hashable = get_hashable(input_data_all[key]) - if hashable is Unhashable or is_link(input_data_all[key]): + if hashable is Unhashable or is_link(input_data_all[key][0]): # Link still needed node_inputs[key] = inputs[key] if ancestor_id not in ancestors: @@ -281,11 +277,9 @@ class BasicCache: pass async def _update_cache_key_immediate(self, node_id): - """Update the cache key for the node.""" await self.cache_key_set.update_cache_key(node_id) def _is_key_updated_immediate(self, node_id): - """False if the cache key set is an updatable type and it hasn't been updated yet.""" return self.cache_key_set.is_key_updated(node_id) def _set_immediate(self, node_id, value): diff --git a/execution.py b/execution.py index d2532552d..1c4f851e5 100644 --- a/execution.py +++ b/execution.py @@ -536,7 +536,6 @@ async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheS unblock() asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) - if len(output_ui) > 0: ui_outputs[unique_id] = { "meta": { @@ -549,7 +548,6 @@ async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheS } if server.client_id is not None: server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) - if has_subgraph: cached_outputs = [] new_node_ids = [] From 3770dc0ec4065f57683ed4c85721d405e8cf27a1 Mon Sep 17 00:00:00 2001 From: Deluxe233 Date: Thu, 29 Jan 2026 05:39:23 -0500 Subject: [PATCH 10/13] tweak test --- tests/execution/test_execution.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index 10cb4216a..8b018614e 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -589,10 +589,13 @@ class TestExecution: assert len(images) == 1, "Should have 1 image" assert numpy.array(images[0]).min() == 127 and numpy.array(images[0]).max() == 127, "Image should have value 0.50" + assert result2.did_run(float1), "Float node should always run" if server["should_cache_results"]: - assert not result2.did_run(output), "Output node should not have run the second time" + assert not result2.did_run(image1), "Image node should not have run again" + assert not result2.did_run(output), "Output node should not have run again" else: - assert result2.did_run(output), "Output node should always run here" + assert result2.did_run(image1), "Image node should have run again" + assert result2.did_run(output), "Output node should have run again" def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): From c6b6128b2b79458326b21b6d0c4ed5fdb67f6cea Mon Sep 17 00:00:00 2001 From: Deluxe233 Date: Fri, 30 Jan 2026 13:21:02 -0500 Subject: [PATCH 11/13] Fix issue with subcache's cache --- comfy_execution/caching.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 17217b1c6..3a763718d 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -235,9 +235,14 @@ class BasicCache: self.cache = {} self.subcaches = {} + self.node_sig_cache = {} + self.ancestry_cache = {} + async def set_prompt(self, dynprompt, node_ids, is_changed): self.dynprompt = dynprompt self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed) + self.cache_key_set.node_sig_cache = self.node_sig_cache + self.cache_key_set.ancestry_cache = self.ancestry_cache await self.cache_key_set.add_keys(node_ids) self.is_changed = is_changed self.initialized = True @@ -270,6 +275,8 @@ class BasicCache: def clean_unused(self): assert self.initialized + self.node_sig_cache.clear() + self.ancestry_cache.clear() self._clean_cache() self._clean_subcaches() @@ -302,6 +309,8 @@ class BasicCache: subcache = self.subcaches.get(subcache_key, None) if subcache is None: subcache = BasicCache(self.key_class) + subcache.node_sig_cache = self.node_sig_cache + subcache.ancestry_cache = self.ancestry_cache self.subcaches[subcache_key] = subcache await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed) return subcache @@ -414,6 +423,8 @@ class LRUCache(BasicCache): self._mark_used(node_id) def clean_unused(self): + self.node_sig_cache.clear() + self.ancestry_cache.clear() while len(self.cache) > self.max_size and self.min_generation < self.generation: self.min_generation += 1 to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation] @@ -480,6 +491,8 @@ class RAMPressureCache(LRUCache): self.timestamps = {} def clean_unused(self): + self.node_sig_cache.clear() + self.ancestry_cache.clear() self._clean_subcaches() def set(self, node_id, value): From 5bad474118aaad5106f292c79d586a404dd320d9 Mon Sep 17 00:00:00 2001 From: Deluxe233 Date: Fri, 30 Jan 2026 22:25:02 -0500 Subject: [PATCH 12/13] fix signature inconsistency --- comfy_execution/caching.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 3a763718d..cfd40a928 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -221,7 +221,7 @@ class CacheKeySetInputSignature(CacheKeySet): logging.warning(f"Node {node_id} cannot be cached due to whatever this thing is: {inputs[key]}") node_inputs[key] = Unhashable() else: - node_inputs[key] = inputs[key] + node_inputs[key] = [inputs[key]] self.ancestry_cache[node_id] = ancestors return self.ancestry_cache[node_id], node_inputs @@ -427,7 +427,7 @@ class LRUCache(BasicCache): self.ancestry_cache.clear() while len(self.cache) > self.max_size and self.min_generation < self.generation: self.min_generation += 1 - to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation] + to_remove = [key for key in self.cache if key not in self.used_generation or self.used_generation[key] < self.min_generation] for key in to_remove: del self.cache[key] del self.used_generation[key] @@ -462,8 +462,8 @@ class LRUCache(BasicCache): return self async def update_cache_key(self, node_id): - self._mark_used(node_id) await self._update_cache_key_immediate(node_id) + self._mark_used(node_id) def is_key_updated(self, node_id): self._mark_used(node_id) From 90f57e6a8dba9decb27d389e6bc60057146bb1c8 Mon Sep 17 00:00:00 2001 From: Deluxe233 Date: Sat, 31 Jan 2026 07:06:41 -0500 Subject: [PATCH 13/13] Fix not cleaning subcaches --- comfy_execution/caching.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index cfd40a928..06e489ac3 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -270,6 +270,8 @@ class BasicCache: for key in self.subcaches: if key not in preserve_subcaches: to_remove.append(key) + else: + self.subcaches[key].clean_unused() for key in to_remove: del self.subcaches[key]