commit a7ea4dc7b8 by Deluxe233, 2026-01-31 00:37:25 +02:00 (committed by GitHub)
4 changed files with 299 additions and 130 deletions

View File

@@ -3,6 +3,7 @@ import gc
 import itertools
 import psutil
 import time
+import logging
 import torch
 from typing import Sequence, Mapping, Dict
 from comfy_execution.graph import DynamicPrompt
@@ -14,7 +15,6 @@ from comfy_execution.graph_utils import is_link
 NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}

 def include_unique_id_in_input(class_type: str) -> bool:
     if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
         return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
@@ -23,7 +23,7 @@ def include_unique_id_in_input(class_type: str) -> bool:
     return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]

 class CacheKeySet(ABC):
-    def __init__(self, dynprompt, node_ids, is_changed_cache):
+    def __init__(self, dynprompt, node_ids, is_changed):
         self.keys = {}
         self.subcache_keys = {}
@@ -46,6 +46,12 @@ class CacheKeySet(ABC):
     def get_subcache_key(self, node_id):
         return self.subcache_keys.get(node_id, None)

+    async def update_cache_key(self, node_id) -> None:
+        pass
+
+    def is_key_updated(self, node_id) -> bool:
+        return True
+
 class Unhashable:
     def __init__(self):
         self.value = float("NaN")
@@ -63,9 +69,20 @@ def to_hashable(obj):
         # TODO - Support other objects like tensors?
         return Unhashable()

+def throw_on_unhashable(obj):
+    # Same as to_hashable except throwing for unhashables instead.
+    if isinstance(obj, (int, float, str, bool, bytes, type(None))):
+        return obj
+    elif isinstance(obj, Mapping):
+        return frozenset([(throw_on_unhashable(k), throw_on_unhashable(v)) for k, v in sorted(obj.items())])
+    elif isinstance(obj, Sequence):
+        return frozenset(zip(itertools.count(), [throw_on_unhashable(i) for i in obj]))
+    else:
+        raise Exception("Object unhashable.")
+
 class CacheKeySetID(CacheKeySet):
-    def __init__(self, dynprompt, node_ids, is_changed_cache):
-        super().__init__(dynprompt, node_ids, is_changed_cache)
+    def __init__(self, dynprompt, node_ids, is_changed):
+        super().__init__(dynprompt, node_ids, is_changed)
         self.dynprompt = dynprompt

     async def add_keys(self, node_ids):
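Editor's note: the new throw_on_unhashable helper canonicalizes an input value into something usable as part of a dict key. A short standalone illustration (not part of the diff; the helper body is reproduced verbatim, the demo values are made up) of its behavior, including the index-pairing trick that keeps sequence order significant:

# Illustration only: how throw_on_unhashable canonicalizes values.
import itertools
from typing import Mapping, Sequence

def throw_on_unhashable(obj):
    if isinstance(obj, (int, float, str, bool, bytes, type(None))):
        return obj
    elif isinstance(obj, Mapping):
        # Sorted key/value pairs give a deterministic, hashable form.
        return frozenset([(throw_on_unhashable(k), throw_on_unhashable(v)) for k, v in sorted(obj.items())])
    elif isinstance(obj, Sequence):
        # Pairing each element with its index keeps ordering significant:
        # [1, 2] and [2, 1] hash differently, unlike a plain frozenset.
        return frozenset(zip(itertools.count(), [throw_on_unhashable(i) for i in obj]))
    else:
        raise Exception("Object unhashable.")

assert throw_on_unhashable([1, 2]) != throw_on_unhashable([2, 1])
assert throw_on_unhashable({"a": 1}) == throw_on_unhashable({"a": 1})

class TensorLike:  # hypothetical stand-in for e.g. a torch.Tensor
    __hash__ = None

try:
    throw_on_unhashable([TensorLike()])
except Exception:
    print("caller falls back to Unhashable(), i.e. no caching for this value")

Unlike to_hashable, which silently substitutes Unhashable() (a NaN carrier that never compares equal), this variant raises, letting the caller decide between the two behaviors.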
@@ -79,14 +96,29 @@ class CacheKeySetID(CacheKeySet):
             self.subcache_keys[node_id] = (node_id, node["class_type"])

 class CacheKeySetInputSignature(CacheKeySet):
-    def __init__(self, dynprompt, node_ids, is_changed_cache):
-        super().__init__(dynprompt, node_ids, is_changed_cache)
-        self.dynprompt = dynprompt
-        self.is_changed_cache = is_changed_cache
+    def __init__(self, dynprompt, node_ids, is_changed):
+        super().__init__(dynprompt, node_ids, is_changed)
+        self.dynprompt: DynamicPrompt = dynprompt
+        self.is_changed = is_changed
+        self.updated_node_ids = set()
+        self.node_sig_cache = {}
+        self.ancestry_cache = {}

     def include_node_id_in_input(self) -> bool:
         return False

+    async def update_cache_key(self, node_id):
+        if node_id in self.updated_node_ids:
+            return
+        if node_id not in self.keys:
+            return
+        self.updated_node_ids.add(node_id)
+        self.keys[node_id] = await self.get_node_signature(node_id)
+
+    def is_key_updated(self, node_id):
+        return node_id in self.updated_node_ids
+
     async def add_keys(self, node_ids):
         for node_id in node_ids:
             if node_id in self.keys:
@@ -94,57 +126,105 @@ class CacheKeySetInputSignature(CacheKeySet):
             if not self.dynprompt.has_node(node_id):
                 continue
             node = self.dynprompt.get_node(node_id)
-            self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id)
+            self.keys[node_id] = None
             self.subcache_keys[node_id] = (node_id, node["class_type"])

-    async def get_node_signature(self, dynprompt, node_id):
-        signature = []
-        ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
-        signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
-        for ancestor_id in ancestors:
-            signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
-        return to_hashable(signature)
+    async def get_node_signature(self, node_id):
+        signatures = []
+        ancestors, order_mapping, node_inputs = self.get_ordered_ancestry(node_id)
+        self.node_sig_cache[node_id] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping, node_inputs))
+        signatures.append(self.node_sig_cache[node_id])
+        for ancestor_id in ancestors:
+            assert ancestor_id in self.node_sig_cache
+            signatures.append(self.node_sig_cache[ancestor_id])
+        signatures = frozenset(zip(itertools.count(), signatures))
+        logging.debug(f"signature for node {node_id}: {signatures}")
+        return signatures

-    async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
-        if not dynprompt.has_node(node_id):
+    async def get_immediate_node_signature(self, node_id, ancestor_order_mapping: dict, inputs: dict):
+        if not self.dynprompt.has_node(node_id):
             # This node doesn't exist -- we can't cache it.
             return [float("NaN")]
-        node = dynprompt.get_node(node_id)
+        node = self.dynprompt.get_node(node_id)
         class_type = node["class_type"]
         class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-        signature = [class_type, await self.is_changed_cache.get(node_id)]
-        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
-            signature.append(node_id)
-        inputs = node["inputs"]
+        signature = [class_type, await self.is_changed.get(node_id)]
         for key in sorted(inputs.keys()):
-            if is_link(inputs[key]):
-                (ancestor_id, ancestor_socket) = inputs[key]
+            input = inputs[key]
+            if is_link(input):
+                (ancestor_id, ancestor_socket) = input
                 ancestor_index = ancestor_order_mapping[ancestor_id]
                 signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
             else:
-                signature.append((key, inputs[key]))
+                signature.append((key, input))
+        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
+            signature.append(node_id)
         return signature

-    # This function returns a list of all ancestors of the given node. The order of the list is
-    # deterministic based on which specific inputs the ancestor is connected by.
-    def get_ordered_ancestry(self, dynprompt, node_id):
-        ancestors = []
-        order_mapping = {}
-        self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
-        return ancestors, order_mapping
+    def get_ordered_ancestry(self, node_id):
+        def get_ancestors(ancestors, ret: list=[]):
+            for ancestor_id in ancestors:
+                if ancestor_id not in ret:
+                    ret.append(ancestor_id)
+                    get_ancestors(self.ancestry_cache[ancestor_id], ret)
+            return ret
+
+        ancestors, node_inputs = self.get_ordered_ancestry_internal(node_id)
+        ancestors = get_ancestors(ancestors)
+
+        order_mapping = {}
+        for i, ancestor_id in enumerate(ancestors):
+            order_mapping[ancestor_id] = i
+
+        return ancestors, order_mapping, node_inputs

-    def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
-        if not dynprompt.has_node(node_id):
-            return
-        inputs = dynprompt.get_node(node_id)["inputs"]
-        input_keys = sorted(inputs.keys())
-        for key in input_keys:
-            if is_link(inputs[key]):
-                ancestor_id = inputs[key][0]
-                if ancestor_id not in order_mapping:
-                    ancestors.append(ancestor_id)
-                    order_mapping[ancestor_id] = len(ancestors) - 1
-                    self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
+    def get_ordered_ancestry_internal(self, node_id):
+        def get_hashable(obj):
+            try:
+                return throw_on_unhashable(obj)
+            except:
+                return Unhashable
+
+        ancestors = []
+        node_inputs = {}
+        if node_id in self.ancestry_cache:
+            return self.ancestry_cache[node_id], node_inputs
+        if not self.dynprompt.has_node(node_id):
+            return ancestors, node_inputs
+        input_data_all, _, _ = self.is_changed.get_input_data(node_id)
+        inputs = self.dynprompt.get_node(node_id)["inputs"]
+        for key in sorted(inputs.keys()):
+            if key in input_data_all:
+                if is_link(inputs[key]):
+                    ancestor_id = inputs[key][0]
+                    hashable = get_hashable(input_data_all[key])
+                    if hashable is Unhashable or is_link(input_data_all[key][0]):
+                        # Link still needed
+                        node_inputs[key] = inputs[key]
+                        if ancestor_id not in ancestors:
+                            ancestors.append(ancestor_id)
+                    else:
+                        # Replace link
+                        node_inputs[key] = input_data_all[key]
+                else:
+                    hashable = get_hashable(inputs[key])
+                    if hashable is Unhashable:
+                        logging.warning(f"Node {node_id} cannot be cached due to whatever this thing is: {inputs[key]}")
+                        node_inputs[key] = Unhashable()
+                    else:
+                        node_inputs[key] = inputs[key]
+        self.ancestry_cache[node_id] = ancestors
+        return self.ancestry_cache[node_id], node_inputs

 class BasicCache:
     def __init__(self, key_class):
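Editor's note: the structural change above is that add_keys now records a None placeholder and the real input signature is computed on demand by update_cache_key, once a node's ancestors have executed. Per-node immediate signatures are memoized in node_sig_cache so shared ancestors are hashed once per prompt, and the full key is a frozenset of (index, signature) pairs over the node and its ordered ancestry. A minimal sketch of that arrangement, under the assumption that callers supply the immediate signature and ancestor list (class and parameter names here are illustrative, not the real API):

# Sketch of the deferred-key idea: keys start as None and are filled
# in on demand, with per-node signatures memoized across nodes.
import itertools

class LazyKeySet:
    def __init__(self):
        self.keys = {}            # node_id -> full cache key (or None)
        self.node_sig_cache = {}  # node_id -> immediate signature, memoized
        self.updated = set()

    def add_keys(self, node_ids):
        for node_id in node_ids:
            self.keys[node_id] = None  # placeholder until ancestors are known

    def update_cache_key(self, node_id, immediate_sig, ancestor_ids):
        if node_id in self.updated:
            return
        self.updated.add(node_id)
        self.node_sig_cache[node_id] = immediate_sig
        sigs = [immediate_sig] + [self.node_sig_cache[a] for a in ancestor_ids]
        # frozenset of (index, sig) pairs: hashable, yet order-sensitive
        self.keys[node_id] = frozenset(zip(itertools.count(), sigs))

ks = LazyKeySet()
ks.add_keys(["1", "2"])
ks.update_cache_key("1", ("LoadImage", False), [])
ks.update_cache_key("2", ("Blur", False), ["1"])
print(ks.keys["2"])

The payoff over the old eager scheme is that resolved link values can be folded directly into the signature (the "Replace link" branch), so a downstream node whose inputs come out identical can hit the cache even when an upstream node re-ran.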
@@ -155,11 +235,16 @@ class BasicCache:
         self.cache = {}
         self.subcaches = {}
+        self.node_sig_cache = {}
+        self.ancestry_cache = {}

-    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+    async def set_prompt(self, dynprompt, node_ids, is_changed):
         self.dynprompt = dynprompt
-        self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
+        self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed)
+        self.cache_key_set.node_sig_cache = self.node_sig_cache
+        self.cache_key_set.ancestry_cache = self.ancestry_cache
         await self.cache_key_set.add_keys(node_ids)
-        self.is_changed_cache = is_changed_cache
+        self.is_changed = is_changed
         self.initialized = True

     def all_node_ids(self):
@@ -190,22 +275,31 @@ class BasicCache:
     def clean_unused(self):
         assert self.initialized
+        self.node_sig_cache.clear()
+        self.ancestry_cache.clear()
         self._clean_cache()
         self._clean_subcaches()

     def poll(self, **kwargs):
         pass

+    async def _update_cache_key_immediate(self, node_id):
+        await self.cache_key_set.update_cache_key(node_id)
+
+    def _is_key_updated_immediate(self, node_id):
+        return self.cache_key_set.is_key_updated(node_id)
+
     def _set_immediate(self, node_id, value):
         assert self.initialized
         cache_key = self.cache_key_set.get_data_key(node_id)
-        self.cache[cache_key] = value
+        if cache_key is not None:
+            self.cache[cache_key] = value

     def _get_immediate(self, node_id):
         if not self.initialized:
             return None
         cache_key = self.cache_key_set.get_data_key(node_id)
-        if cache_key in self.cache:
+        if cache_key is not None and cache_key in self.cache:
             return self.cache[cache_key]
         else:
             return None
@@ -215,8 +309,10 @@ class BasicCache:
         subcache = self.subcaches.get(subcache_key, None)
         if subcache is None:
             subcache = BasicCache(self.key_class)
+            subcache.node_sig_cache = self.node_sig_cache
+            subcache.ancestry_cache = self.ancestry_cache
             self.subcaches[subcache_key] = subcache
-        await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
+        await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed)
         return subcache

     def _get_subcache(self, node_id):
@@ -273,9 +369,18 @@ class HierarchicalCache(BasicCache):
         assert cache is not None
         return await cache._ensure_subcache(node_id, children_ids)

+    async def update_cache_key(self, node_id):
+        cache = self._get_cache_for(node_id)
+        assert cache is not None
+        await cache._update_cache_key_immediate(node_id)
+
+    def is_key_updated(self, node_id):
+        cache = self._get_cache_for(node_id)
+        assert cache is not None
+        return cache._is_key_updated_immediate(node_id)
+
 class NullCache:
-    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+    async def set_prompt(self, dynprompt, node_ids, is_changed):
         pass

     def all_node_ids(self):
@@ -296,6 +401,12 @@ class NullCache:
     async def ensure_subcache_for(self, node_id, children_ids):
         return self

+    async def update_cache_key(self, node_id):
+        pass
+
+    def is_key_updated(self, node_id):
+        return True
+
 class LRUCache(BasicCache):
     def __init__(self, key_class, max_size=100):
         super().__init__(key_class)
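Editor's note: every cache variant now answers the same two calls, so the executor can gate on key readiness without knowing which backend is active. HierarchicalCache routes both calls to the subcache that owns the node, LRUCache additionally refreshes recency (see the next hunk), and NullCache reports keys as always updated so the gate becomes a no-op when caching is disabled. A sketch of that shared shape (names are stand-ins, not the real API):

# Illustrative duck-typed protocol shared by all cache front-ends.
class AlwaysReadyCache:
    async def update_cache_key(self, node_id):
        pass  # NullCache-style: nothing to compute

    def is_key_updated(self, node_id):
        return True  # always "ready": execution proceeds, nothing is stored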
@@ -305,13 +416,15 @@ class LRUCache(BasicCache):
         self.used_generation = {}
         self.children = {}

-    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
-        await super().set_prompt(dynprompt, node_ids, is_changed_cache)
+    async def set_prompt(self, dynprompt, node_ids, is_changed):
+        await super().set_prompt(dynprompt, node_ids, is_changed)
         self.generation += 1
         for node_id in node_ids:
             self._mark_used(node_id)

     def clean_unused(self):
+        self.node_sig_cache.clear()
+        self.ancestry_cache.clear()
         while len(self.cache) > self.max_size and self.min_generation < self.generation:
             self.min_generation += 1
             to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
@@ -348,6 +461,14 @@ class LRUCache(BasicCache):
             self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
         return self

+    async def update_cache_key(self, node_id):
+        self._mark_used(node_id)
+        await self._update_cache_key_immediate(node_id)
+
+    def is_key_updated(self, node_id):
+        self._mark_used(node_id)
+        return self._is_key_updated_immediate(node_id)
+
 #Iterating the cache for usage analysis might be expensive, so if we trigger make sure
 #to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
@@ -365,12 +486,13 @@ RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
 RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

 class RAMPressureCache(LRUCache):
     def __init__(self, key_class):
         super().__init__(key_class, 0)
         self.timestamps = {}

     def clean_unused(self):
+        self.node_sig_cache.clear()
+        self.ancestry_cache.clear()
         self._clean_subcaches()

     def set(self, node_id, value):

View File

@@ -36,7 +36,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
 from comfy_execution.utils import CurrentNodeContext
 from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
 from comfy_api.latest import io, _io
+from server import PromptServer

 class ExecutionResult(Enum):
     SUCCESS = 0
@@ -46,49 +46,40 @@ class ExecutionResult(Enum):
 class DuplicateNodeError(Exception):
     pass

-class IsChangedCache:
-    def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache):
+class IsChanged:
+    def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, execution_list: ExecutionList|None=None, extra_data: dict={}):
         self.prompt_id = prompt_id
         self.dynprompt = dynprompt
-        self.outputs_cache = outputs_cache
-        self.is_changed = {}
+        self.execution_list = execution_list
+        self.extra_data = extra_data

-    async def get(self, node_id):
-        if node_id in self.is_changed:
-            return self.is_changed[node_id]
+    def get_input_data(self, node_id):
+        node = self.dynprompt.get_node(node_id)
+        class_type = node["class_type"]
+        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+        return get_input_data(node["inputs"], class_def, node_id, self.execution_list, self.dynprompt, self.extra_data)

+    async def get(self, node_id):
         node = self.dynprompt.get_node(node_id)
         class_type = node["class_type"]
         class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-        has_is_changed = False
         is_changed_name = None
         if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None:
-            has_is_changed = True
             is_changed_name = "fingerprint_inputs"
         elif hasattr(class_def, "IS_CHANGED"):
-            has_is_changed = True
             is_changed_name = "IS_CHANGED"
-        if not has_is_changed:
-            self.is_changed[node_id] = False
-            return self.is_changed[node_id]
+        if is_changed_name is None:
+            return False

-        if "is_changed" in node:
-            self.is_changed[node_id] = node["is_changed"]
-            return self.is_changed[node_id]
-
-        # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
-        input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
+        input_data_all, _, v3_data = self.get_input_data(node_id)
         try:
             is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
             is_changed = await resolve_map_node_over_list_results(is_changed)
-            node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
+            is_changed = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
         except Exception as e:
             logging.warning("WARNING: {}".format(e))
-            is_changed = float("NaN")
+            is_changed = float("NaN")
+        return is_changed

 class CacheEntry(NamedTuple):
     ui: dict
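Editor's note: two behaviors changed here. IS_CHANGED/fingerprint_inputs now receives fully resolved inputs (via the execution list) rather than only constants, and the result is no longer memoized on the node dict, so it is recomputed each prompt. The NaN convention this relies on: when IS_CHANGED raises, or deliberately returns float("NaN"), that NaN lands in the node's signature, and since NaN never compares equal to itself the signature can never match the stored key, forcing a re-run. A tiny demonstration of that mechanic (node names are made up):

# Why a fresh NaN forces re-execution: NaN != NaN, so a signature
# containing it never matches the previously stored cache key.
nan_sig = ("MyNode", float("nan"))
assert nan_sig != ("MyNode", float("nan"))  # distinct NaN objects: always a miss

false_sig = ("MyNode", False)  # nodes without IS_CHANGED report False
assert false_sig == ("MyNode", False)  # stable, so the cache can hit

Note the small wrinkle this design avoids: Python's container comparison short-circuits on identity, so reusing the *same* NaN object would compare equal. Recomputing is_changed each prompt (instead of stashing it in node["is_changed"] as the old code did) guarantees a fresh NaN and therefore a genuine cache miss.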
@@ -118,7 +109,7 @@ class CacheSet:
         else:
             self.init_classic_cache()

-        self.all = [self.outputs, self.objects]
+        self.all: list[BasicCache, BasicCache] = [self.outputs, self.objects]

     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):
@@ -406,7 +397,10 @@ def format_value(x):
     else:
         return str(x)

-async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
+async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheSet,
+                  current_item: str, extra_data: dict, executed: set, prompt_id: str,
+                  execution_list: ExecutionList, pending_subgraph_results: dict,
+                  pending_async_nodes: dict, ui_outputs: dict):
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
     display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -414,16 +408,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-    cached = caches.outputs.get(unique_id)
-    if cached is not None:
-        if server.client_id is not None:
-            cached_ui = cached.ui or {}
-            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
-        if cached.ui is not None:
-            ui_outputs[unique_id] = cached.ui
-        get_progress_state().finish_progress(unique_id)
-        execution_list.cache_update(unique_id, cached)
-        return (ExecutionResult.SUCCESS, None, None)
+    if caches.outputs.is_key_updated(unique_id):
+        # Key is updated, the cache can be checked.
+        cached = caches.outputs.get(unique_id)
+        logging.debug(f"execute: {unique_id} cached: {cached is not None}")
+        if cached is not None:
+            if server.client_id is not None:
+                cached_ui = cached.ui or {}
+                server.send_sync("execution_cached", { "nodes": [unique_id], "prompt_id": prompt_id}, server.client_id)
+                server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
+            if cached.ui is not None:
+                ui_outputs[unique_id] = cached.ui
+            get_progress_state().finish_progress(unique_id)
+            execution_list.cache_update(unique_id, cached)
+            return (ExecutionResult.SUCCESS, None, None)

     input_data_all = None
     try:
@@ -464,11 +463,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             del pending_subgraph_results[unique_id]
             has_subgraph = False
         else:
-            get_progress_state().start_progress(unique_id)
+            if caches.outputs.is_key_updated(unique_id):
+                # The key is updated, the node is executing.
+                get_progress_state().start_progress(unique_id)
+                if server.client_id is not None:
+                    server.last_node_id = display_node_id
+                    server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
             input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
-            if server.client_id is not None:
-                server.last_node_id = display_node_id
-                server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
             obj = caches.objects.get(unique_id)
             if obj is None:
@@ -479,6 +481,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
             else:
                 lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
+
             if lazy_status_present:
                 # for check_lazy_status, the returned data should include the original key of the input
                 v3_data_lazy = v3_data.copy()
@@ -494,6 +497,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                         execution_list.make_input_strong_link(unique_id, i)
                     return (ExecutionResult.PENDING, None, None)

+            if not caches.outputs.is_key_updated(unique_id):
+                # Update the cache key after any lazy inputs are evaluated.
+                async def update_cache_key(node_id, unblock):
+                    await caches.outputs.update_cache_key(node_id)
+                    unblock()
+                asyncio.create_task(update_cache_key(unique_id, execution_list.add_external_block(unique_id)))
+                return (ExecutionResult.PENDING, None, None)
+
         def execution_block_cb(block):
             if block.message is not None:
                 mes = {
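Editor's note: the hunk above is the heart of the deferred-key flow. A node that reaches execution with its key still pending is parked behind an external block, the key is computed in a background task, and the unblock callable returned by add_external_block re-schedules the node, which then re-enters execute() and takes the cache-check branch. A self-contained sketch of that handshake, using asyncio.Event as a stand-in for the real ExecutionList machinery (MiniExecutionList and its methods are illustrative, not the actual API):

# Minimal "block, update key, unblock" handshake.
import asyncio

class MiniExecutionList:
    def __init__(self):
        self.blocks = {}

    def add_external_block(self, node_id):
        event = asyncio.Event()
        self.blocks[node_id] = event
        return event.set  # the "unblock" callable handed to the task

    async def wait_ready(self, node_id):
        await self.blocks[node_id].wait()

async def main():
    el = MiniExecutionList()

    async def update_cache_key(node_id, unblock):
        await asyncio.sleep(0)  # stand-in for computing the input signature
        print(f"key for node {node_id} updated")
        unblock()

    asyncio.create_task(update_cache_key("5", el.add_external_block("5")))
    await el.wait_ready("5")  # the scheduler re-offers node "5" once unblocked
    print("node 5 can now check the cache and execute")

asyncio.run(main())

Returning ExecutionResult.PENDING here (rather than computing the key inline) keeps the executor's event loop free to run other ready nodes while signatures are being hashed.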
@@ -563,8 +574,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                     cached_outputs.append((True, node_outputs))
             new_node_ids = set(new_node_ids)
             for cache in caches.all:
-                subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
-                subcache.clean_unused()
+                await cache.ensure_subcache_for(unique_id, new_node_ids)
             for node_id in new_output_ids:
                 execution_list.add_node(node_id)
                 execution_list.cache_link(node_id, unique_id)
@@ -689,25 +699,16 @@ class PromptExecutor:
         dynamic_prompt = DynamicPrompt(prompt)
         reset_progress_state(prompt_id, dynamic_prompt)
         add_progress_handler(WebUIProgressHandler(self.server))
-        is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
+        execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
+        is_changed = IsChanged(prompt_id, dynamic_prompt, execution_list, extra_data)
         for cache in self.caches.all:
-            await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
-            cache.clean_unused()
-
-        cached_nodes = []
-        for node_id in prompt:
-            if self.caches.outputs.get(node_id) is not None:
-                cached_nodes.append(node_id)
+            await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed)

         comfy.model_management.cleanup_models_gc()
-        self.add_message("execution_cached",
-                         { "nodes": cached_nodes, "prompt_id": prompt_id},
-                         broadcast=False)
         pending_subgraph_results = {}
         pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
         ui_node_outputs = {}
         executed = set()
-        execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
         current_outputs = self.caches.outputs.all_node_ids()
         for node_id in list(execute_outputs):
             execution_list.add_node(node_id)
@@ -746,6 +747,8 @@ class PromptExecutor:
             if comfy.model_management.DISABLE_SMART_MEMORY:
                 comfy.model_management.unload_all_models()

+        for cache in self.caches.all:
+            cache.clean_unused()

 async def validate_inputs(prompt_id, prompt, item, validated):
     unique_id = item

View File

@@ -552,27 +552,50 @@ class TestExecution:
         assert len(images1) == 1, "Should have 1 image"
         assert len(images2) == 1, "Should have 1 image"

-    # This tests that only constant outputs are used in the call to `IS_CHANGED`
-    def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
+    def test_is_changed_passed_cached_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
         g = builder
         input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
-        test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
+        test_node = g.node("TestIsChangedWithAllInputs", image=input1.out(0), value=0.5)
         output = g.node("PreviewImage", images=test_node.out(0))

-        result = client.run(g)
-        images = result.get_images(output)
+        result1 = client.run(g)
+        images = result1.get_images(output)
         assert len(images) == 1, "Should have 1 image"
         assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"

-        result = client.run(g)
-        images = result.get_images(output)
+        result2 = client.run(g)
+        images = result2.get_images(output)
         assert len(images) == 1, "Should have 1 image"
         assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
         if server["should_cache_results"]:
-            assert not result.did_run(test_node), "The execution should have been cached"
+            assert not result2.did_run(test_node), "Test node should not have run again"
         else:
-            assert result.did_run(test_node), "The execution should have been re-run"
+            assert result2.did_run(test_node), "Test node should always run here"

+    def test_dont_always_run_downstream(self, client: ComfyClient, builder: GraphBuilder, server):
+        g = builder
+        float1 = g.node("TestDontAlwaysRunDownstream", float=0.5) # IS_CHANGED returns float("NaN")
+        image1 = g.node("StubConstantImage", value=float1.out(0), height=512, width=512, batch_size=1)
+        output = g.node("PreviewImage", images=image1.out(0))
+
+        result1 = client.run(g)
+        images = result1.get_images(output)
+        assert len(images) == 1, "Should have 1 image"
+        assert numpy.array(images[0]).min() == 127 and numpy.array(images[0]).max() == 127, "Image should have value 0.50"
+
+        result2 = client.run(g)
+        images = result2.get_images(output)
+        assert len(images) == 1, "Should have 1 image"
+        assert numpy.array(images[0]).min() == 127 and numpy.array(images[0]).max() == 127, "Image should have value 0.50"
+        assert result2.did_run(float1), "Float node should always run"
+        if server["should_cache_results"]:
+            assert not result2.did_run(image1), "Image node should not have run again"
+            assert not result2.did_run(output), "Output node should not have run again"
+        else:
+            assert result2.did_run(image1), "Image node should have run again"
+            assert result2.did_run(output), "Output node should have run again"
+
     def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):

View File

@@ -100,7 +100,7 @@ class TestCustomIsChanged:
         else:
             return False

-class TestIsChangedWithConstants:
+class TestIsChangedWithAllInputs:
     @classmethod
     def INPUT_TYPES(cls):
         return {
@@ -120,10 +120,29 @@ class TestIsChangedWithConstants:
     @classmethod
     def IS_CHANGED(cls, image, value):
-        if image is None:
-            return value
-        else:
-            return image.mean().item() * value
+        # if image is None then an exception is thrown and is_changed becomes float("NaN")
+        return image.mean().item() * value
+
+class TestDontAlwaysRunDownstream:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "float": ("FLOAT",),
+            },
+        }
+
+    RETURN_TYPES = ("FLOAT",)
+    FUNCTION = "always_run"
+    CATEGORY = "Testing/Nodes"
+
+    def always_run(self, float):
+        return (float,)
+
+    @classmethod
+    def IS_CHANGED(cls, *args, **kwargs):
+        return float("NaN")

 class TestCustomValidation1:
     @classmethod
@@ -486,7 +505,8 @@ TEST_NODE_CLASS_MAPPINGS = {
     "TestLazyMixImages": TestLazyMixImages,
     "TestVariadicAverage": TestVariadicAverage,
     "TestCustomIsChanged": TestCustomIsChanged,
-    "TestIsChangedWithConstants": TestIsChangedWithConstants,
+    "TestIsChangedWithAllInputs": TestIsChangedWithAllInputs,
+    "TestDontAlwaysRunDownstream": TestDontAlwaysRunDownstream,
     "TestCustomValidation1": TestCustomValidation1,
     "TestCustomValidation2": TestCustomValidation2,
     "TestCustomValidation3": TestCustomValidation3,
@@ -504,7 +524,8 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestLazyMixImages": "Lazy Mix Images",
    "TestVariadicAverage": "Variadic Average",
    "TestCustomIsChanged": "Custom IsChanged",
-    "TestIsChangedWithConstants": "IsChanged With Constants",
+    "TestIsChangedWithAllInputs": "IsChanged With All Inputs",
+    "TestDontAlwaysRunDownstream": "Dont Always Run Downstream",
    "TestCustomValidation1": "Custom Validation 1",
    "TestCustomValidation2": "Custom Validation 2",
    "TestCustomValidation3": "Custom Validation 3",