diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 06e489ac3..b0591c08e 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -3,7 +3,6 @@ import gc import itertools import psutil import time -import logging import torch from typing import Sequence, Mapping, Dict from comfy_execution.graph import DynamicPrompt @@ -15,6 +14,7 @@ from comfy_execution.graph_utils import is_link NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {} + def include_unique_id_in_input(class_type: str) -> bool: if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID: return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] @@ -98,9 +98,8 @@ class CacheKeySetID(CacheKeySet): class CacheKeySetInputSignature(CacheKeySet): def __init__(self, dynprompt, node_ids, is_changed): super().__init__(dynprompt, node_ids, is_changed) - self.dynprompt: DynamicPrompt = dynprompt + self.dynprompt = dynprompt self.is_changed = is_changed - self.updated_node_ids = set() self.node_sig_cache = {} self.ancestry_cache = {} @@ -134,37 +133,29 @@ class CacheKeySetInputSignature(CacheKeySet): ancestors, order_mapping, node_inputs = self.get_ordered_ancestry(node_id) self.node_sig_cache[node_id] = to_hashable(await self.get_immediate_node_signature(node_id, order_mapping, node_inputs)) signatures.append(self.node_sig_cache[node_id]) - for ancestor_id in ancestors: assert ancestor_id in self.node_sig_cache signatures.append(self.node_sig_cache[ancestor_id]) - signatures = frozenset(zip(itertools.count(), signatures)) - logging.debug(f"signature for node {node_id}: {signatures}") return signatures - async def get_immediate_node_signature(self, node_id, ancestor_order_mapping: dict, inputs: dict): + async def get_immediate_node_signature(self, node_id, ancestor_order_mapping, inputs): if not self.dynprompt.has_node(node_id): # This node doesn't exist -- we can't cache it. return [float("NaN")] node = self.dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - signature = [class_type, await self.is_changed.get(node_id)] - for key in sorted(inputs.keys()): - input = inputs[key] - if is_link(input): - (ancestor_id, ancestor_socket) = input + if is_link(inputs[key]): + (ancestor_id, ancestor_socket) = inputs[key] ancestor_index = ancestor_order_mapping[ancestor_id] signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) else: - signature.append((key, input)) - + signature.append((key, inputs[key])) if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): signature.append(node_id) - return signature def get_ordered_ancestry(self, node_id): @@ -218,7 +209,6 @@ class CacheKeySetInputSignature(CacheKeySet): else: hashable = get_hashable(inputs[key]) if hashable is Unhashable: - logging.warning(f"Node {node_id} cannot be cached due to whatever this thing is: {inputs[key]}") node_inputs[key] = Unhashable() else: node_inputs[key] = [inputs[key]] @@ -234,7 +224,6 @@ class BasicCache: self.cache_key_set: CacheKeySet self.cache = {} self.subcaches = {} - self.node_sig_cache = {} self.ancestry_cache = {} @@ -382,6 +371,7 @@ class HierarchicalCache(BasicCache): return cache._is_key_updated_immediate(node_id) class NullCache: + async def set_prompt(self, dynprompt, node_ids, is_changed): pass @@ -488,6 +478,7 @@ RAM_CACHE_DEFAULT_RAM_USAGE = 0.1 RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3 class RAMPressureCache(LRUCache): + def __init__(self, key_class): super().__init__(key_class, 0) self.timestamps = {} diff --git a/execution.py b/execution.py index 1c4f851e5..88b6d32cf 100644 --- a/execution.py +++ b/execution.py @@ -36,7 +36,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io -from server import PromptServer + class ExecutionResult(Enum): SUCCESS = 0 @@ -109,7 +109,7 @@ class CacheSet: else: self.init_classic_cache() - self.all: list[BasicCache, BasicCache] = [self.outputs, self.objects] + self.all = [self.outputs, self.objects] # Performs like the old cache -- dump data ASAP def init_classic_cache(self): @@ -397,10 +397,7 @@ def format_value(x): else: return str(x) -async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheSet, - current_item: str, extra_data: dict, executed: set, prompt_id: str, - execution_list: ExecutionList, pending_subgraph_results: dict, - pending_async_nodes: dict, ui_outputs: dict): +async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -408,11 +405,9 @@ async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheS inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if caches.outputs.is_key_updated(unique_id): # Key is updated, the cache can be checked. cached = caches.outputs.get(unique_id) - logging.debug(f"execute: {unique_id} cached: {cached is not None}") if cached is not None: if server.client_id is not None: cached_ui = cached.ui or {} @@ -481,7 +476,6 @@ async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheS lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None else: lazy_status_present = getattr(obj, "check_lazy_status", None) is not None - if lazy_status_present: # for check_lazy_status, the returned data should include the original key of the input v3_data_lazy = v3_data.copy()