This commit is contained in:
Deluxe233 2026-02-03 04:21:38 -05:00 committed by GitHub
commit 062183cba9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 266 additions and 120 deletions

View File

@ -23,7 +23,7 @@ def include_unique_id_in_input(class_type: str) -> bool:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class CacheKeySet(ABC): class CacheKeySet(ABC):
def __init__(self, dynprompt, node_ids, is_changed_cache): def __init__(self, dynprompt, node_ids, is_changed):
self.keys = {} self.keys = {}
self.subcache_keys = {} self.subcache_keys = {}
@ -45,6 +45,12 @@ class CacheKeySet(ABC):
def get_subcache_key(self, node_id): def get_subcache_key(self, node_id):
return self.subcache_keys.get(node_id, None) return self.subcache_keys.get(node_id, None)
async def update_cache_key(self, node_id) -> None:
pass
def is_key_updated(self, node_id) -> bool:
return True
class Unhashable: class Unhashable:
def __init__(self): def __init__(self):
@ -62,10 +68,21 @@ def to_hashable(obj):
else: else:
# TODO - Support other objects like tensors? # TODO - Support other objects like tensors?
return Unhashable() return Unhashable()
def throw_on_unhashable(obj):
# Same as to_hashable except throwing for unhashables instead.
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
return obj
elif isinstance(obj, Mapping):
return frozenset([(throw_on_unhashable(k), throw_on_unhashable(v)) for k, v in sorted(obj.items())])
elif isinstance(obj, Sequence):
return frozenset(zip(itertools.count(), [throw_on_unhashable(i) for i in obj]))
else:
raise Exception("Object unhashable.")
class CacheKeySetID(CacheKeySet): class CacheKeySetID(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache): def __init__(self, dynprompt, node_ids, is_changed):
super().__init__(dynprompt, node_ids, is_changed_cache) super().__init__(dynprompt, node_ids, is_changed)
self.dynprompt = dynprompt self.dynprompt = dynprompt
async def add_keys(self, node_ids): async def add_keys(self, node_ids):
@ -79,13 +96,25 @@ class CacheKeySetID(CacheKeySet):
self.subcache_keys[node_id] = (node_id, node["class_type"]) self.subcache_keys[node_id] = (node_id, node["class_type"])
class CacheKeySetInputSignature(CacheKeySet): class CacheKeySetInputSignature(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache): def __init__(self, dynprompt, node_ids, is_changed):
super().__init__(dynprompt, node_ids, is_changed_cache) super().__init__(dynprompt, node_ids, is_changed)
self.dynprompt = dynprompt self.dynprompt = dynprompt
self.is_changed_cache = is_changed_cache self.is_changed = is_changed
self.updated_node_ids = set()
def include_node_id_in_input(self) -> bool: def include_node_id_in_input(self) -> bool:
return False return False
async def update_cache_key(self, node_id):
if node_id in self.updated_node_ids:
return
if node_id not in self.keys:
return
self.updated_node_ids.add(node_id)
self.keys[node_id] = await self.get_node_signature(node_id)
def is_key_updated(self, node_id):
return node_id in self.updated_node_ids
async def add_keys(self, node_ids): async def add_keys(self, node_ids):
for node_id in node_ids: for node_id in node_ids:
@ -94,28 +123,30 @@ class CacheKeySetInputSignature(CacheKeySet):
if not self.dynprompt.has_node(node_id): if not self.dynprompt.has_node(node_id):
continue continue
node = self.dynprompt.get_node(node_id) node = self.dynprompt.get_node(node_id)
self.keys[node_id] = await self.get_node_signature(self.dynprompt, node_id) self.keys[node_id] = None
self.subcache_keys[node_id] = (node_id, node["class_type"]) self.subcache_keys[node_id] = (node_id, node["class_type"])
async def get_node_signature(self, dynprompt, node_id): async def get_node_signature(self, node_id):
signature = [] signatures = []
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) ancestors, order_mapping, node_inputs = self.get_ordered_ancestry(node_id)
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) node = self.dynprompt.get_node(node_id)
node["signature"] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping, node_inputs))
signatures.append(node["signature"])
for ancestor_id in ancestors: for ancestor_id in ancestors:
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) ancestor_node = self.dynprompt.get_node(ancestor_id)
return to_hashable(signature) assert "signature" in ancestor_node
signatures.append(ancestor_node["signature"])
signatures = frozenset(zip(itertools.count(), signatures))
return signatures
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): async def get_immediate_node_signature(self, node_id, ancestor_order_mapping, inputs):
if not dynprompt.has_node(node_id): if not self.dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it. # This node doesn't exist -- we can't cache it.
return [float("NaN")] return [float("NaN")]
node = dynprompt.get_node(node_id) node = self.dynprompt.get_node(node_id)
class_type = node["class_type"] class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, await self.is_changed_cache.get(node_id)] signature = [class_type, await self.is_changed.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
inputs = node["inputs"]
for key in sorted(inputs.keys()): for key in sorted(inputs.keys()):
if is_link(inputs[key]): if is_link(inputs[key]):
(ancestor_id, ancestor_socket) = inputs[key] (ancestor_id, ancestor_socket) = inputs[key]
@ -123,28 +154,69 @@ class CacheKeySetInputSignature(CacheKeySet):
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
else: else:
signature.append((key, inputs[key])) signature.append((key, inputs[key]))
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
return signature return signature
def get_ordered_ancestry(self, node_id):
def get_ancestors(ancestors, ret: list=[]):
for ancestor_id in ancestors:
if ancestor_id not in ret:
ret.append(ancestor_id)
ancestor_node = self.dynprompt.get_node(ancestor_id)
get_ancestors(ancestor_node["ancestors"], ret)
return ret
ancestors, node_inputs = self.get_ordered_ancestry_internal(node_id)
ancestors = get_ancestors(ancestors)
# This function returns a list of all ancestors of the given node. The order of the list is
# deterministic based on which specific inputs the ancestor is connected by.
def get_ordered_ancestry(self, dynprompt, node_id):
ancestors = []
order_mapping = {} order_mapping = {}
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping) for i, ancestor_id in enumerate(ancestors):
return ancestors, order_mapping order_mapping[ancestor_id] = i
return ancestors, order_mapping, node_inputs
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping): def get_ordered_ancestry_internal(self, node_id):
if not dynprompt.has_node(node_id): def get_hashable(obj):
return try:
inputs = dynprompt.get_node(node_id)["inputs"] return throw_on_unhashable(obj)
input_keys = sorted(inputs.keys()) except:
for key in input_keys: return Unhashable
if is_link(inputs[key]):
ancestor_id = inputs[key][0] ancestors = []
if ancestor_id not in order_mapping: node_inputs = {}
ancestors.append(ancestor_id)
order_mapping[ancestor_id] = len(ancestors) - 1 if not self.dynprompt.has_node(node_id):
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) return ancestors, node_inputs
node = self.dynprompt.get_node(node_id)
if "ancestors" in node:
return node["ancestors"], node_inputs
input_data_all, _, _ = self.is_changed.get_input_data(node_id)
inputs = self.dynprompt.get_node(node_id)["inputs"]
for key in sorted(inputs.keys()):
if key in input_data_all:
if is_link(inputs[key]):
ancestor_id = inputs[key][0]
hashable = get_hashable(input_data_all[key])
if hashable is Unhashable or is_link(input_data_all[key][0]):
# Link still needed
node_inputs[key] = inputs[key]
if ancestor_id not in ancestors:
ancestors.append(ancestor_id)
else:
# Replace link
node_inputs[key] = input_data_all[key]
else:
hashable = get_hashable(inputs[key])
if hashable is Unhashable:
node_inputs[key] = Unhashable()
else:
node_inputs[key] = [inputs[key]]
node["ancestors"] = ancestors
return node["ancestors"], node_inputs
class BasicCache: class BasicCache:
def __init__(self, key_class): def __init__(self, key_class):
@ -155,11 +227,11 @@ class BasicCache:
self.cache = {} self.cache = {}
self.subcaches = {} self.subcaches = {}
async def set_prompt(self, dynprompt, node_ids, is_changed_cache): async def set_prompt(self, dynprompt, node_ids, is_changed):
self.dynprompt = dynprompt self.dynprompt = dynprompt
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed)
await self.cache_key_set.add_keys(node_ids) await self.cache_key_set.add_keys(node_ids)
self.is_changed_cache = is_changed_cache self.is_changed = is_changed
self.initialized = True self.initialized = True
def all_node_ids(self): def all_node_ids(self):
@ -185,6 +257,8 @@ class BasicCache:
for key in self.subcaches: for key in self.subcaches:
if key not in preserve_subcaches: if key not in preserve_subcaches:
to_remove.append(key) to_remove.append(key)
else:
self.subcaches[key].clean_unused()
for key in to_remove: for key in to_remove:
del self.subcaches[key] del self.subcaches[key]
@ -196,16 +270,23 @@ class BasicCache:
def poll(self, **kwargs): def poll(self, **kwargs):
pass pass
async def _update_cache_key_immediate(self, node_id):
await self.cache_key_set.update_cache_key(node_id)
def _is_key_updated_immediate(self, node_id):
return self.cache_key_set.is_key_updated(node_id)
def _set_immediate(self, node_id, value): def _set_immediate(self, node_id, value):
assert self.initialized assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value if cache_key is not None:
self.cache[cache_key] = value
def _get_immediate(self, node_id): def _get_immediate(self, node_id):
if not self.initialized: if not self.initialized:
return None return None
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache: if cache_key is not None and cache_key in self.cache:
return self.cache[cache_key] return self.cache[cache_key]
else: else:
return None return None
@ -216,7 +297,7 @@ class BasicCache:
if subcache is None: if subcache is None:
subcache = BasicCache(self.key_class) subcache = BasicCache(self.key_class)
self.subcaches[subcache_key] = subcache self.subcaches[subcache_key] = subcache
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed)
return subcache return subcache
def _get_subcache(self, node_id): def _get_subcache(self, node_id):
@ -272,10 +353,20 @@ class HierarchicalCache(BasicCache):
cache = self._get_cache_for(node_id) cache = self._get_cache_for(node_id)
assert cache is not None assert cache is not None
return await cache._ensure_subcache(node_id, children_ids) return await cache._ensure_subcache(node_id, children_ids)
async def update_cache_key(self, node_id):
cache = self._get_cache_for(node_id)
assert cache is not None
await cache._update_cache_key_immediate(node_id)
def is_key_updated(self, node_id):
cache = self._get_cache_for(node_id)
assert cache is not None
return cache._is_key_updated_immediate(node_id)
class NullCache: class NullCache:
async def set_prompt(self, dynprompt, node_ids, is_changed_cache): async def set_prompt(self, dynprompt, node_ids, is_changed):
pass pass
def all_node_ids(self): def all_node_ids(self):
@ -295,6 +386,12 @@ class NullCache:
async def ensure_subcache_for(self, node_id, children_ids): async def ensure_subcache_for(self, node_id, children_ids):
return self return self
async def update_cache_key(self, node_id):
pass
def is_key_updated(self, node_id):
return True
class LRUCache(BasicCache): class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100): def __init__(self, key_class, max_size=100):
@ -305,8 +402,8 @@ class LRUCache(BasicCache):
self.used_generation = {} self.used_generation = {}
self.children = {} self.children = {}
async def set_prompt(self, dynprompt, node_ids, is_changed_cache): async def set_prompt(self, dynprompt, node_ids, is_changed):
await super().set_prompt(dynprompt, node_ids, is_changed_cache) await super().set_prompt(dynprompt, node_ids, is_changed)
self.generation += 1 self.generation += 1
for node_id in node_ids: for node_id in node_ids:
self._mark_used(node_id) self._mark_used(node_id)
@ -314,7 +411,7 @@ class LRUCache(BasicCache):
def clean_unused(self): def clean_unused(self):
while len(self.cache) > self.max_size and self.min_generation < self.generation: while len(self.cache) > self.max_size and self.min_generation < self.generation:
self.min_generation += 1 self.min_generation += 1
to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation] to_remove = [key for key in self.cache if key not in self.used_generation or self.used_generation[key] < self.min_generation]
for key in to_remove: for key in to_remove:
del self.cache[key] del self.cache[key]
del self.used_generation[key] del self.used_generation[key]
@ -347,6 +444,14 @@ class LRUCache(BasicCache):
self._mark_used(child_id) self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self return self
async def update_cache_key(self, node_id):
await self._update_cache_key_immediate(node_id)
self._mark_used(node_id)
def is_key_updated(self, node_id):
self._mark_used(node_id)
return self._is_key_updated_immediate(node_id)
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure #Iterating the cache for usage analysis might be expensive, so if we trigger make sure

View File

@ -48,49 +48,40 @@ class ExecutionResult(Enum):
class DuplicateNodeError(Exception): class DuplicateNodeError(Exception):
pass pass
class IsChangedCache: class IsChanged:
def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache): def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, execution_list: ExecutionList|None=None, extra_data: dict={}):
self.prompt_id = prompt_id self.prompt_id = prompt_id
self.dynprompt = dynprompt self.dynprompt = dynprompt
self.outputs_cache = outputs_cache self.execution_list = execution_list
self.is_changed = {} self.extra_data = extra_data
async def get(self, node_id):
if node_id in self.is_changed:
return self.is_changed[node_id]
def get_input_data(self, node_id):
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return get_input_data(node["inputs"], class_def, node_id, self.execution_list, self.dynprompt, self.extra_data)
async def get(self, node_id):
node = self.dynprompt.get_node(node_id) node = self.dynprompt.get_node(node_id)
class_type = node["class_type"] class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
has_is_changed = False
is_changed_name = None is_changed_name = None
if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None: if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None:
has_is_changed = True
is_changed_name = "fingerprint_inputs" is_changed_name = "fingerprint_inputs"
elif hasattr(class_def, "IS_CHANGED"): elif hasattr(class_def, "IS_CHANGED"):
has_is_changed = True
is_changed_name = "IS_CHANGED" is_changed_name = "IS_CHANGED"
if not has_is_changed: if is_changed_name is None:
self.is_changed[node_id] = False return False
return self.is_changed[node_id]
if "is_changed" in node: input_data_all, _, v3_data = self.get_input_data(node_id)
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
try: try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data) is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
is_changed = await resolve_map_node_over_list_results(is_changed) is_changed = await resolve_map_node_over_list_results(is_changed)
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] is_changed = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e: except Exception as e:
logging.warning("WARNING: {}".format(e)) logging.warning("WARNING: {}".format(e))
node["is_changed"] = float("NaN") is_changed = float("NaN")
finally: return is_changed
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
class CacheEntry(NamedTuple): class CacheEntry(NamedTuple):
ui: dict ui: dict
@ -416,16 +407,19 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs'] inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type'] class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
cached = caches.outputs.get(unique_id) if caches.outputs.is_key_updated(unique_id):
if cached is not None: # Key is updated, the cache can be checked.
if server.client_id is not None: cached = caches.outputs.get(unique_id)
cached_ui = cached.ui or {} if cached is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id) if server.client_id is not None:
if cached.ui is not None: cached_ui = cached.ui or {}
ui_outputs[unique_id] = cached.ui server.send_sync("execution_cached", { "nodes": [unique_id], "prompt_id": prompt_id}, server.client_id)
get_progress_state().finish_progress(unique_id) server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
execution_list.cache_update(unique_id, cached) if cached.ui is not None:
return (ExecutionResult.SUCCESS, None, None) ui_outputs[unique_id] = cached.ui
get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, cached)
return (ExecutionResult.SUCCESS, None, None)
input_data_all = None input_data_all = None
try: try:
@ -466,11 +460,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
del pending_subgraph_results[unique_id] del pending_subgraph_results[unique_id]
has_subgraph = False has_subgraph = False
else: else:
get_progress_state().start_progress(unique_id) if caches.outputs.is_key_updated(unique_id):
# The key is updated, the node is executing.
get_progress_state().start_progress(unique_id)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
obj = caches.objects.get(unique_id) obj = caches.objects.get(unique_id)
if obj is None: if obj is None:
@ -496,6 +493,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
execution_list.make_input_strong_link(unique_id, i) execution_list.make_input_strong_link(unique_id, i)
return (ExecutionResult.PENDING, None, None) return (ExecutionResult.PENDING, None, None)
if not caches.outputs.is_key_updated(unique_id):
# Update the cache key after any lazy inputs are evaluated.
async def update_cache_key(node_id, unblock):
await caches.outputs.update_cache_key(node_id)
unblock()
asyncio.create_task(update_cache_key(unique_id, execution_list.add_external_block(unique_id)))
return (ExecutionResult.PENDING, None, None)
def execution_block_cb(block): def execution_block_cb(block):
if block.message is not None: if block.message is not None:
mes = { mes = {
@ -577,8 +582,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
cached_outputs.append((True, node_outputs)) cached_outputs.append((True, node_outputs))
new_node_ids = set(new_node_ids) new_node_ids = set(new_node_ids)
for cache in caches.all: for cache in caches.all:
subcache = await cache.ensure_subcache_for(unique_id, new_node_ids) await cache.ensure_subcache_for(unique_id, new_node_ids)
subcache.clean_unused()
for node_id in new_output_ids: for node_id in new_output_ids:
execution_list.add_node(node_id) execution_list.add_node(node_id)
execution_list.cache_link(node_id, unique_id) execution_list.cache_link(node_id, unique_id)
@ -703,25 +707,16 @@ class PromptExecutor:
dynamic_prompt = DynamicPrompt(prompt) dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt) reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server)) add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
is_changed = IsChanged(prompt_id, dynamic_prompt, execution_list, extra_data)
for cache in self.caches.all: for cache in self.caches.all:
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed)
cache.clean_unused()
cached_nodes = []
for node_id in prompt:
if self.caches.outputs.get(node_id) is not None:
cached_nodes.append(node_id)
comfy.model_management.cleanup_models_gc() comfy.model_management.cleanup_models_gc()
self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {} pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {} ui_node_outputs = {}
executed = set() executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids() current_outputs = self.caches.outputs.all_node_ids()
for node_id in list(execute_outputs): for node_id in list(execute_outputs):
execution_list.add_node(node_id) execution_list.add_node(node_id)
@ -759,7 +754,9 @@ class PromptExecutor:
self.server.last_node_id = None self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY: if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models() comfy.model_management.unload_all_models()
for cache in self.caches.all:
cache.clean_unused()
async def validate_inputs(prompt_id, prompt, item, validated): async def validate_inputs(prompt_id, prompt, item, validated):
unique_id = item unique_id = item

View File

@ -552,27 +552,50 @@ class TestExecution:
assert len(images1) == 1, "Should have 1 image" assert len(images1) == 1, "Should have 1 image"
assert len(images2) == 1, "Should have 1 image" assert len(images2) == 1, "Should have 1 image"
# This tests that only constant outputs are used in the call to `IS_CHANGED` def test_is_changed_passed_cached_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
g = builder g = builder
input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1) input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5) test_node = g.node("TestIsChangedWithAllInputs", image=input1.out(0), value=0.5)
output = g.node("PreviewImage", images=test_node.out(0)) output = g.node("PreviewImage", images=test_node.out(0))
result = client.run(g) result1 = client.run(g)
images = result.get_images(output) images = result1.get_images(output)
assert len(images) == 1, "Should have 1 image" assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
result = client.run(g) result2 = client.run(g)
images = result.get_images(output) images = result2.get_images(output)
assert len(images) == 1, "Should have 1 image" assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
if server["should_cache_results"]: if server["should_cache_results"]:
assert not result.did_run(test_node), "The execution should have been cached" assert not result2.did_run(test_node), "Test node should not have run again"
else: else:
assert result.did_run(test_node), "The execution should have been re-run" assert result2.did_run(test_node), "Test node should always run here"
def test_dont_always_run_downstream(self, client: ComfyClient, builder: GraphBuilder, server):
g = builder
float1 = g.node("TestDontAlwaysRunDownstream", float=0.5) # IS_CHANGED returns float("NaN")
image1 = g.node("StubConstantImage", value=float1.out(0), height=512, width=512, batch_size=1)
output = g.node("PreviewImage", images=image1.out(0))
result1 = client.run(g)
images = result1.get_images(output)
assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 127 and numpy.array(images[0]).max() == 127, "Image should have value 0.50"
result2 = client.run(g)
images = result2.get_images(output)
assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 127 and numpy.array(images[0]).max() == 127, "Image should have value 0.50"
assert result2.did_run(float1), "Float node should always run"
if server["should_cache_results"]:
assert not result2.did_run(image1), "Image node should not have run again"
assert not result2.did_run(output), "Output node should not have run again"
else:
assert result2.did_run(image1), "Image node should have run again"
assert result2.did_run(output), "Output node should have run again"
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):

View File

@ -100,7 +100,7 @@ class TestCustomIsChanged:
else: else:
return False return False
class TestIsChangedWithConstants: class TestIsChangedWithAllInputs:
@classmethod @classmethod
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
@ -120,10 +120,29 @@ class TestIsChangedWithConstants:
@classmethod @classmethod
def IS_CHANGED(cls, image, value): def IS_CHANGED(cls, image, value):
if image is None: # if image is None then an exception is thrown and is_changed becomes float("NaN")
return value return image.mean().item() * value
else:
return image.mean().item() * value class TestDontAlwaysRunDownstream:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"float": ("FLOAT",),
},
}
RETURN_TYPES = ("FLOAT",)
FUNCTION = "always_run"
CATEGORY = "Testing/Nodes"
def always_run(self, float):
return (float,)
@classmethod
def IS_CHANGED(cls, *args, **kwargs):
return float("NaN")
class TestCustomValidation1: class TestCustomValidation1:
@classmethod @classmethod
@ -486,7 +505,8 @@ TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages, "TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage, "TestVariadicAverage": TestVariadicAverage,
"TestCustomIsChanged": TestCustomIsChanged, "TestCustomIsChanged": TestCustomIsChanged,
"TestIsChangedWithConstants": TestIsChangedWithConstants, "TestIsChangedWithAllInputs": TestIsChangedWithAllInputs,
"TestDontAlwaysRunDownstream": TestDontAlwaysRunDownstream,
"TestCustomValidation1": TestCustomValidation1, "TestCustomValidation1": TestCustomValidation1,
"TestCustomValidation2": TestCustomValidation2, "TestCustomValidation2": TestCustomValidation2,
"TestCustomValidation3": TestCustomValidation3, "TestCustomValidation3": TestCustomValidation3,
@ -504,7 +524,8 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestLazyMixImages": "Lazy Mix Images", "TestLazyMixImages": "Lazy Mix Images",
"TestVariadicAverage": "Variadic Average", "TestVariadicAverage": "Variadic Average",
"TestCustomIsChanged": "Custom IsChanged", "TestCustomIsChanged": "Custom IsChanged",
"TestIsChangedWithConstants": "IsChanged With Constants", "TestIsChangedWithAllInputs": "IsChanged With All Inputs",
"TestDontAlwaysRunDownstream": "Dont Always Run Downstream",
"TestCustomValidation1": "Custom Validation 1", "TestCustomValidation1": "Custom Validation 1",
"TestCustomValidation2": "Custom Validation 2", "TestCustomValidation2": "Custom Validation 2",
"TestCustomValidation3": "Custom Validation 3", "TestCustomValidation3": "Custom Validation 3",