Compare commits

...

2 Commits

Author SHA1 Message Date
Deluxe233
2db3ca609d Simplified storing signatures/ancestors 2026-02-03 03:01:52 -05:00
Deluxe233
982092f79a Removed unnecessary changes 2026-02-03 02:11:48 -05:00
2 changed files with 25 additions and 50 deletions

View File

@ -3,7 +3,6 @@ import gc
import itertools
import psutil
import time
import logging
import torch
from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt
@ -15,6 +14,7 @@ from comfy_execution.graph_utils import is_link
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
def include_unique_id_in_input(class_type: str) -> bool:
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
@ -98,12 +98,9 @@ class CacheKeySetID(CacheKeySet):
class CacheKeySetInputSignature(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed):
super().__init__(dynprompt, node_ids, is_changed)
self.dynprompt: DynamicPrompt = dynprompt
self.dynprompt = dynprompt
self.is_changed = is_changed
self.updated_node_ids = set()
self.node_sig_cache = {}
self.ancestry_cache = {}
def include_node_id_in_input(self) -> bool:
return False
@ -132,39 +129,33 @@ class CacheKeySetInputSignature(CacheKeySet):
async def get_node_signature(self, node_id):
signatures = []
ancestors, order_mapping, node_inputs = self.get_ordered_ancestry(node_id)
self.node_sig_cache[node_id] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping, node_inputs))
signatures.append(self.node_sig_cache[node_id])
node = self.dynprompt.get_node(node_id)
node["signature"] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping, node_inputs))
signatures.append(node["signature"])
for ancestor_id in ancestors:
assert ancestor_id in self.node_sig_cache
signatures.append(self.node_sig_cache[ancestor_id])
ancestor_node = self.dynprompt.get_node(ancestor_id)
assert "signature" in ancestor_node
signatures.append(ancestor_node["signature"])
signatures = frozenset(zip(itertools.count(), signatures))
logging.debug(f"signature for node {node_id}: {signatures}")
return signatures
async def get_immediate_node_signature(self, node_id, ancestor_order_mapping: dict, inputs: dict):
async def get_immediate_node_signature(self, node_id, ancestor_order_mapping, inputs):
if not self.dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it.
return [float("NaN")]
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, await self.is_changed.get(node_id)]
for key in sorted(inputs.keys()):
input = inputs[key]
if is_link(input):
(ancestor_id, ancestor_socket) = input
if is_link(inputs[key]):
(ancestor_id, ancestor_socket) = inputs[key]
ancestor_index = ancestor_order_mapping[ancestor_id]
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
else:
signature.append((key, input))
signature.append((key, inputs[key]))
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
return signature
def get_ordered_ancestry(self, node_id):
@ -172,7 +163,8 @@ class CacheKeySetInputSignature(CacheKeySet):
for ancestor_id in ancestors:
if ancestor_id not in ret:
ret.append(ancestor_id)
get_ancestors(self.ancestry_cache[ancestor_id], ret)
ancestor_node = self.dynprompt.get_node(ancestor_id)
get_ancestors(ancestor_node["ancestors"], ret)
return ret
ancestors, node_inputs = self.get_ordered_ancestry_internal(node_id)
@ -194,12 +186,13 @@ class CacheKeySetInputSignature(CacheKeySet):
ancestors = []
node_inputs = {}
if node_id in self.ancestry_cache:
return self.ancestry_cache[node_id], node_inputs
if not self.dynprompt.has_node(node_id):
return ancestors, node_inputs
node = self.dynprompt.get_node(node_id)
if "ancestors" in node:
return node["ancestors"], node_inputs
input_data_all, _, _ = self.is_changed.get_input_data(node_id)
inputs = self.dynprompt.get_node(node_id)["inputs"]
for key in sorted(inputs.keys()):
@ -218,13 +211,12 @@ class CacheKeySetInputSignature(CacheKeySet):
else:
hashable = get_hashable(inputs[key])
if hashable is Unhashable:
logging.warning(f"Node {node_id} cannot be cached due to whatever this thing is: {inputs[key]}")
node_inputs[key] = Unhashable()
else:
node_inputs[key] = [inputs[key]]
self.ancestry_cache[node_id] = ancestors
return self.ancestry_cache[node_id], node_inputs
node["ancestors"] = ancestors
return node["ancestors"], node_inputs
class BasicCache:
def __init__(self, key_class):
@ -235,14 +227,9 @@ class BasicCache:
self.cache = {}
self.subcaches = {}
self.node_sig_cache = {}
self.ancestry_cache = {}
async def set_prompt(self, dynprompt, node_ids, is_changed):
self.dynprompt = dynprompt
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed)
self.cache_key_set.node_sig_cache = self.node_sig_cache
self.cache_key_set.ancestry_cache = self.ancestry_cache
await self.cache_key_set.add_keys(node_ids)
self.is_changed = is_changed
self.initialized = True
@ -277,8 +264,6 @@ class BasicCache:
def clean_unused(self):
assert self.initialized
self.node_sig_cache.clear()
self.ancestry_cache.clear()
self._clean_cache()
self._clean_subcaches()
@ -311,8 +296,6 @@ class BasicCache:
subcache = self.subcaches.get(subcache_key, None)
if subcache is None:
subcache = BasicCache(self.key_class)
subcache.node_sig_cache = self.node_sig_cache
subcache.ancestry_cache = self.ancestry_cache
self.subcaches[subcache_key] = subcache
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed)
return subcache
@ -382,6 +365,7 @@ class HierarchicalCache(BasicCache):
return cache._is_key_updated_immediate(node_id)
class NullCache:
async def set_prompt(self, dynprompt, node_ids, is_changed):
pass
@ -425,8 +409,6 @@ class LRUCache(BasicCache):
self._mark_used(node_id)
def clean_unused(self):
self.node_sig_cache.clear()
self.ancestry_cache.clear()
while len(self.cache) > self.max_size and self.min_generation < self.generation:
self.min_generation += 1
to_remove = [key for key in self.cache if key not in self.used_generation or self.used_generation[key] < self.min_generation]
@ -488,13 +470,12 @@ RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache):
def __init__(self, key_class):
super().__init__(key_class, 0)
self.timestamps = {}
def clean_unused(self):
self.node_sig_cache.clear()
self.ancestry_cache.clear()
self._clean_subcaches()
def set(self, node_id, value):

View File

@ -36,7 +36,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io
from server import PromptServer
class ExecutionResult(Enum):
SUCCESS = 0
@ -109,7 +109,7 @@ class CacheSet:
else:
self.init_classic_cache()
self.all: list[BasicCache, BasicCache] = [self.outputs, self.objects]
self.all = [self.outputs, self.objects]
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
@ -397,10 +397,7 @@ def format_value(x):
else:
return str(x)
async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheSet,
current_item: str, extra_data: dict, executed: set, prompt_id: str,
execution_list: ExecutionList, pending_subgraph_results: dict,
pending_async_nodes: dict, ui_outputs: dict):
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
@ -408,11 +405,9 @@ async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheS
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if caches.outputs.is_key_updated(unique_id):
# Key is updated, the cache can be checked.
cached = caches.outputs.get(unique_id)
logging.debug(f"execute: {unique_id} cached: {cached is not None}")
if cached is not None:
if server.client_id is not None:
cached_ui = cached.ui or {}
@ -481,7 +476,6 @@ async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheS
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present:
# for check_lazy_status, the returned data should include the original key of the input
v3_data_lazy = v3_data.copy()