diff --git a/comfy/caching.py b/comfy/caching.py
new file mode 100644
index 000000000..ef047dcc5
--- /dev/null
+++ b/comfy/caching.py
@@ -0,0 +1,316 @@
+import itertools
+from typing import Sequence, Mapping
+
+import nodes
+
+from comfy.graph_utils import is_link
+
+class CacheKeySet:
+    def __init__(self, dynprompt, node_ids, is_changed_cache):
+        self.keys = {}
+        self.subcache_keys = {}
+
+    def add_keys(self, node_ids):
+        raise NotImplementedError()
+
+    def all_node_ids(self):
+        return set(self.keys.keys())
+
+    def get_used_keys(self):
+        return self.keys.values()
+
+    def get_used_subcache_keys(self):
+        return self.subcache_keys.values()
+
+    def get_data_key(self, node_id):
+        return self.keys.get(node_id, None)
+
+    def get_subcache_key(self, node_id):
+        return self.subcache_keys.get(node_id, None)
+
+class Unhashable:
+    def __init__(self):
+        self.value = float("NaN")
+
+def to_hashable(obj):
+    # Handle primitives first so that we don't infinitely recurse,
+    # since str and tuple are themselves Sequences.
+    if isinstance(obj, (int, float, str, bool, type(None))):
+        return obj
+    elif isinstance(obj, Mapping):
+        return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
+    elif isinstance(obj, Sequence):
+        return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
+    else:
+        # TODO - Support other objects like tensors?
+        return Unhashable()
+
+class CacheKeySetID(CacheKeySet):
+    def __init__(self, dynprompt, node_ids, is_changed_cache):
+        super().__init__(dynprompt, node_ids, is_changed_cache)
+        self.dynprompt = dynprompt
+        self.add_keys(node_ids)
+
+    def add_keys(self, node_ids):
+        for node_id in node_ids:
+            if node_id in self.keys:
+                continue
+            node = self.dynprompt.get_node(node_id)
+            self.keys[node_id] = (node_id, node["class_type"])
+            self.subcache_keys[node_id] = (node_id, node["class_type"])
+
+class CacheKeySetInputSignature(CacheKeySet):
+    def __init__(self, dynprompt, node_ids, is_changed_cache):
+        super().__init__(dynprompt, node_ids, is_changed_cache)
+        self.dynprompt = dynprompt
+        self.is_changed_cache = is_changed_cache
+        self.add_keys(node_ids)
+
+    def include_node_id_in_input(self):
+        return False
+
+    def add_keys(self, node_ids):
+        for node_id in node_ids:
+            if node_id in self.keys:
+                continue
+            node = self.dynprompt.get_node(node_id)
+            self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
+            self.subcache_keys[node_id] = (node_id, node["class_type"])
+
+    def get_node_signature(self, dynprompt, node_id):
+        signature = []
+        ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
+        signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
+        for ancestor_id in ancestors:
+            signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
+        return to_hashable(signature)
+
+    def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
+        node = dynprompt.get_node(node_id)
+        class_type = node["class_type"]
+        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+        signature = [class_type, self.is_changed_cache.get(node_id)]
+        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
+            signature.append(node_id)
+        inputs = node["inputs"]
+        for key in sorted(inputs.keys()):
+            if is_link(inputs[key]):
+                (ancestor_id, ancestor_socket) = inputs[key]
+                ancestor_index = ancestor_order_mapping[ancestor_id]
+                signature.append((key, ("ANCESTOR", ancestor_index, ancestor_socket)))
+            else:
+                signature.append((key, inputs[key]))
+        return signature
+
+    # This function returns a list of all ancestors of the given node. The order of the list is
+    # deterministic based on which specific inputs the ancestor is connected by.
+    def get_ordered_ancestry(self, dynprompt, node_id):
+        ancestors = []
+        order_mapping = {}
+        self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
+        return ancestors, order_mapping
+
+    def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
+        inputs = dynprompt.get_node(node_id)["inputs"]
+        input_keys = sorted(inputs.keys())
+        for key in input_keys:
+            if is_link(inputs[key]):
+                ancestor_id = inputs[key][0]
+                if ancestor_id not in order_mapping:
+                    ancestors.append(ancestor_id)
+                    order_mapping[ancestor_id] = len(ancestors) - 1
+                    self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
+
+class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature):
+    def __init__(self, dynprompt, node_ids, is_changed_cache):
+        super().__init__(dynprompt, node_ids, is_changed_cache)
+
+    def include_node_id_in_input(self):
+        return True
+
+class BasicCache:
+    def __init__(self, key_class):
+        self.key_class = key_class
+        self.dynprompt = None
+        self.cache_key_set = None
+        self.cache = {}
+        self.subcaches = {}
+
+    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+        self.dynprompt = dynprompt
+        self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
+        self.is_changed_cache = is_changed_cache
+
+    def all_node_ids(self):
+        assert self.cache_key_set is not None
+        node_ids = self.cache_key_set.all_node_ids()
+        for subcache in self.subcaches.values():
+            node_ids = node_ids.union(subcache.all_node_ids())
+        return node_ids
+
+    def clean_unused(self):
+        assert self.cache_key_set is not None
+        preserve_keys = set(self.cache_key_set.get_used_keys())
+        preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
+        to_remove = []
+        for key in self.cache:
+            if key not in preserve_keys:
+                to_remove.append(key)
+        for key in to_remove:
+            del self.cache[key]
+
+        to_remove = []
+        for key in self.subcaches:
+            if key not in preserve_subcaches:
+                to_remove.append(key)
+        for key in to_remove:
+            del self.subcaches[key]
+
+    def _set_immediate(self, node_id, value):
+        assert self.cache_key_set is not None
+        cache_key = self.cache_key_set.get_data_key(node_id)
+        self.cache[cache_key] = value
+
+    def _get_immediate(self, node_id):
+        assert self.cache_key_set is not None
+        cache_key = self.cache_key_set.get_data_key(node_id)
+        if cache_key in self.cache:
+            return self.cache[cache_key]
+        else:
+            return None
+
+    def _ensure_subcache(self, node_id, children_ids):
+        assert self.cache_key_set is not None
+        subcache_key = self.cache_key_set.get_subcache_key(node_id)
+        subcache = self.subcaches.get(subcache_key, None)
+        if subcache is None:
+            subcache = BasicCache(self.key_class)
+            self.subcaches[subcache_key] = subcache
+        subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
+        return subcache
+
+    def _get_subcache(self, node_id):
+        assert self.cache_key_set is not None
+        subcache_key = self.cache_key_set.get_subcache_key(node_id)
+        if subcache_key in self.subcaches:
+            return self.subcaches[subcache_key]
+        else:
+            return None
+
+    def recursive_debug_dump(self):
+        result = []
+        for key in self.cache:
+            result.append({"key": key, "value": self.cache[key]})
+        for key in self.subcaches:
+            result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
+        return result
+
+class HierarchicalCache(BasicCache):
+    def __init__(self, key_class):
+        super().__init__(key_class)
+
+    def _get_cache_for(self, node_id):
+        parent_id = self.dynprompt.get_parent_node_id(node_id)
+        if parent_id is None:
+            return self
+
+        hierarchy = []
+        while parent_id is not None:
+            hierarchy.append(parent_id)
+            parent_id = self.dynprompt.get_parent_node_id(parent_id)
+
+        cache = self
+        for parent_id in reversed(hierarchy):
+            cache = cache._get_subcache(parent_id)
+            if cache is None:
+                return None
+        return cache
+
+    def get(self, node_id):
+        cache = self._get_cache_for(node_id)
+        if cache is None:
+            return None
+        return cache._get_immediate(node_id)
+
+    def set(self, node_id, value):
+        cache = self._get_cache_for(node_id)
+        assert cache is not None
+        cache._set_immediate(node_id, value)
+
+    def ensure_subcache_for(self, node_id, children_ids):
+        cache = self._get_cache_for(node_id)
+        assert cache is not None
+        return cache._ensure_subcache(node_id, children_ids)
+
+    def all_active_values(self):
+        active_nodes = self.all_node_ids()
+        result = []
+        for node_id in active_nodes:
+            value = self.get(node_id)
+            if value is not None:
+                result.append(value)
+        return result
+
+class LRUCache(BasicCache):
+    def __init__(self, key_class, max_size=100):
+        super().__init__(key_class)
+        self.max_size = max_size
+        self.min_generation = 0
+        self.generation = 0
+        self.used_generation = {}
+        self.children = {}
+
+    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+        super().set_prompt(dynprompt, node_ids, is_changed_cache)
+        self.generation += 1
+        for node_id in node_ids:
+            self._mark_used(node_id)
+        print("LRUCache: Now at generation %d" % self.generation)
+
+    def clean_unused(self):
+        print("LRUCache: Cleaning unused. Current size: %d/%d" % (len(self.cache), self.max_size))
+        while len(self.cache) > self.max_size and self.min_generation < self.generation:
+            print("LRUCache: Evicting generation %d" % self.min_generation)
+            self.min_generation += 1
+            to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
+            for key in to_remove:
+                del self.cache[key]
+                del self.used_generation[key]
+                if key in self.children:
+                    del self.children[key]
+
+    def get(self, node_id):
+        self._mark_used(node_id)
+        return self._get_immediate(node_id)
+
+    def _mark_used(self, node_id):
+        cache_key = self.cache_key_set.get_data_key(node_id)
+        if cache_key is not None:
+            self.used_generation[cache_key] = self.generation
+
+    def set(self, node_id, value):
+        self._mark_used(node_id)
+        return self._set_immediate(node_id, value)
+
+    def ensure_subcache_for(self, node_id, children_ids):
+        self.cache_key_set.add_keys(children_ids)
+        self._mark_used(node_id)
+        cache_key = self.cache_key_set.get_data_key(node_id)
+        self.children[cache_key] = []
+        for child_id in children_ids:
+            self._mark_used(child_id)
+            self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
+        return self
+
+    def all_active_values(self):
+        explored = set()
+        to_explore = set(self.cache_key_set.get_used_keys())
+        while len(to_explore) > 0:
+            cache_key = to_explore.pop()
+            if cache_key not in explored:
+                self.used_generation[cache_key] = self.generation
+                explored.add(cache_key)
+                if cache_key in self.children:
+                    to_explore.update(self.children[cache_key])
+        return [self.cache[key] for key in explored if key in self.cache]
+
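A note on the keys built above: CacheKeySetInputSignature hashes a node's class_type, its IS_CHANGED result, its constant inputs, and, recursively, the signatures of its ancestors, so a node is a cache hit exactly when nothing upstream of it has changed. The hashing relies on to_hashable; the following standalone sketch (not part of the patch, with a stand-in for Unhashable) demonstrates its intended semantics:

import itertools
from typing import Sequence, Mapping

def to_hashable(obj):
    # Same logic as comfy/caching.py above.
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    elif isinstance(obj, Mapping):
        return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
    elif isinstance(obj, Sequence):
        return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
    else:
        # Stand-in for Unhashable: each instance is distinct, so
        # non-hashable values never compare equal across runs.
        return object()

# Mapping key order does not affect the cache key...
assert to_hashable({"a": 1, "b": 2}) == to_hashable({"b": 2, "a": 1})
# ...but sequence order does, because each element is paired with its index.
assert to_hashable([1, 2]) != to_hashable([2, 1])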
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) +cache_group = parser.add_mutually_exclusive_group() +cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") +cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") + attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.") diff --git a/comfy/graph.py b/comfy/graph.py new file mode 100644 index 000000000..a16774309 --- /dev/null +++ b/comfy/graph.py @@ -0,0 +1,166 @@ +import nodes + +from comfy.graph_utils import is_link + +class DynamicPrompt: + def __init__(self, original_prompt): + # The original prompt provided by the user + self.original_prompt = original_prompt + # Any extra pieces of the graph created during execution + self.ephemeral_prompt = {} + self.ephemeral_parents = {} + self.ephemeral_display = {} + + def get_node(self, node_id): + if node_id in self.ephemeral_prompt: + return self.ephemeral_prompt[node_id] + if node_id in self.original_prompt: + return self.original_prompt[node_id] + return None + + def add_ephemeral_node(self, node_id, node_info, parent_id, display_id): + self.ephemeral_prompt[node_id] = node_info + self.ephemeral_parents[node_id] = parent_id + self.ephemeral_display[node_id] = display_id + + def get_real_node_id(self, node_id): + while node_id in self.ephemeral_parents: + node_id = self.ephemeral_parents[node_id] + return node_id + + def get_parent_node_id(self, node_id): + return self.ephemeral_parents.get(node_id, None) + + def get_display_node_id(self, node_id): + while node_id in self.ephemeral_display: + node_id = self.ephemeral_display[node_id] + return node_id + + def all_node_ids(self): + return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys())) + +def get_input_info(class_def, input_name): + valid_inputs = class_def.INPUT_TYPES() + input_info = None + input_category = None + if "required" in valid_inputs and input_name in valid_inputs["required"]: + input_category = "required" + input_info = valid_inputs["required"][input_name] + elif "optional" in valid_inputs and input_name in valid_inputs["optional"]: + input_category = "optional" + input_info = valid_inputs["optional"][input_name] + elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]: + input_category = "hidden" + input_info = valid_inputs["hidden"][input_name] + if input_info is None: + return None, None, None + input_type = input_info[0] + if len(input_info) > 1: + extra_info = input_info[1] + else: + extra_info = {} + return input_type, input_category, extra_info + +class TopologicalSort: + def __init__(self, dynprompt): + self.dynprompt = dynprompt + self.pendingNodes = {} + self.blockCount = {} # Number of nodes this node is directly blocked by + self.blocking = {} # Which nodes are blocked by this node + + def get_input_info(self, unique_id, input_name): + class_type = self.dynprompt.get_node(unique_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + return get_input_info(class_def, 
diff --git a/comfy/graph.py b/comfy/graph.py
new file mode 100644
index 000000000..a16774309
--- /dev/null
+++ b/comfy/graph.py
@@ -0,0 +1,166 @@
+import nodes
+
+from comfy.graph_utils import is_link
+
+class DynamicPrompt:
+    def __init__(self, original_prompt):
+        # The original prompt provided by the user
+        self.original_prompt = original_prompt
+        # Any extra pieces of the graph created during execution
+        self.ephemeral_prompt = {}
+        self.ephemeral_parents = {}
+        self.ephemeral_display = {}
+
+    def get_node(self, node_id):
+        if node_id in self.ephemeral_prompt:
+            return self.ephemeral_prompt[node_id]
+        if node_id in self.original_prompt:
+            return self.original_prompt[node_id]
+        return None
+
+    def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
+        self.ephemeral_prompt[node_id] = node_info
+        self.ephemeral_parents[node_id] = parent_id
+        self.ephemeral_display[node_id] = display_id
+
+    def get_real_node_id(self, node_id):
+        while node_id in self.ephemeral_parents:
+            node_id = self.ephemeral_parents[node_id]
+        return node_id
+
+    def get_parent_node_id(self, node_id):
+        return self.ephemeral_parents.get(node_id, None)
+
+    def get_display_node_id(self, node_id):
+        while node_id in self.ephemeral_display:
+            node_id = self.ephemeral_display[node_id]
+        return node_id
+
+    def all_node_ids(self):
+        return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
+
+def get_input_info(class_def, input_name):
+    valid_inputs = class_def.INPUT_TYPES()
+    input_info = None
+    input_category = None
+    if "required" in valid_inputs and input_name in valid_inputs["required"]:
+        input_category = "required"
+        input_info = valid_inputs["required"][input_name]
+    elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
+        input_category = "optional"
+        input_info = valid_inputs["optional"][input_name]
+    elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
+        input_category = "hidden"
+        input_info = valid_inputs["hidden"][input_name]
+    if input_info is None:
+        return None, None, None
+    input_type = input_info[0]
+    if len(input_info) > 1:
+        extra_info = input_info[1]
+    else:
+        extra_info = {}
+    return input_type, input_category, extra_info
+
+class TopologicalSort:
+    def __init__(self, dynprompt):
+        self.dynprompt = dynprompt
+        self.pendingNodes = {}
+        self.blockCount = {} # Number of nodes this node is directly blocked by
+        self.blocking = {} # Which nodes are blocked by this node
+
+    def get_input_info(self, unique_id, input_name):
+        class_type = self.dynprompt.get_node(unique_id)["class_type"]
+        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+        return get_input_info(class_def, input_name)
+
+    def make_input_strong_link(self, to_node_id, to_input):
+        inputs = self.dynprompt.get_node(to_node_id)["inputs"]
+        if to_input not in inputs:
+            raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input))
+        value = inputs[to_input]
+        if not is_link(value):
+            raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input))
+        from_node_id, from_socket = value
+        self.add_strong_link(from_node_id, from_socket, to_node_id)
+
+    def add_strong_link(self, from_node_id, from_socket, to_node_id):
+        self.add_node(from_node_id)
+        if to_node_id not in self.blocking[from_node_id]:
+            self.blocking[from_node_id][to_node_id] = {}
+            self.blockCount[to_node_id] += 1
+        self.blocking[from_node_id][to_node_id][from_socket] = True
+
+    def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
+        if unique_id in self.pendingNodes:
+            return
+        self.pendingNodes[unique_id] = True
+        self.blockCount[unique_id] = 0
+        self.blocking[unique_id] = {}
+
+        inputs = self.dynprompt.get_node(unique_id)["inputs"]
+        for input_name in inputs:
+            value = inputs[input_name]
+            if is_link(value):
+                from_node_id, from_socket = value
+                if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
+                    continue
+                input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
+                is_lazy = "lazy" in input_info and input_info["lazy"]
+                if include_lazy or not is_lazy:
+                    self.add_strong_link(from_node_id, from_socket, unique_id)
+
+    def get_ready_nodes(self):
+        return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
+
+    def pop_node(self, unique_id):
+        del self.pendingNodes[unique_id]
+        for blocked_node_id in self.blocking[unique_id]:
+            self.blockCount[blocked_node_id] -= 1
+        del self.blocking[unique_id]
+
+    def is_empty(self):
+        return len(self.pendingNodes) == 0
+
+# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
+# it can still be returned to the graph after having further dependencies added.
+class ExecutionList(TopologicalSort):
+    def __init__(self, dynprompt, output_cache):
+        super().__init__(dynprompt)
+        self.output_cache = output_cache
+        self.staged_node_id = None
+
+    def add_strong_link(self, from_node_id, from_socket, to_node_id):
+        if self.output_cache.get(from_node_id) is not None:
+            # Nothing to do
+            return
+        super().add_strong_link(from_node_id, from_socket, to_node_id)
+
+    def stage_node_execution(self):
+        assert self.staged_node_id is None
+        if self.is_empty():
+            return None
+        available = self.get_ready_nodes()
+        if len(available) == 0:
+            raise Exception("Dependency cycle detected")
+        next_node = available[0]
+        # If an output node is available, do that first.
+        # Technically this has no effect on the overall length of execution, but it feels better as a user
+        # for a PreviewImage to display a result as soon as it can.
+        # Some other heuristics could probably be used here to improve the UX further.
+        for node_id in available:
+            class_type = self.dynprompt.get_node(node_id)["class_type"]
+            class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+            if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
+                next_node = node_id
+                break
+        self.staged_node_id = next_node
+        return self.staged_node_id
+
+    def unstage_node_execution(self):
+        assert self.staged_node_id is not None
+        self.staged_node_id = None
+
+    def complete_node_execution(self):
+        node_id = self.staged_node_id
+        self.pop_node(node_id)
+        self.staged_node_id = None
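TopologicalSort above keeps a per-node count of unresolved dependencies (blockCount) plus the reverse edges (blocking); ExecutionList then repeatedly stages whichever node has a count of zero, preferring output nodes. A toy sketch of the same bookkeeping (hypothetical three-node graph, not ComfyUI types):

# Each node lists the nodes it depends on.
deps = {"save": ["decode"], "decode": ["sample"], "sample": []}

block_count = {n: len(d) for n, d in deps.items()}
blocking = {n: [m for m, d in deps.items() if n in d] for n in deps}

order = []
ready = [n for n, c in block_count.items() if c == 0]
while ready:
    node = ready.pop()                 # stage_node_execution picks from this set
    order.append(node)                 # complete_node_execution -> pop_node
    for blocked in blocking[node]:
        block_count[blocked] -= 1
        if block_count[blocked] == 0:
            ready.append(blocked)

assert order == ["sample", "decode", "save"]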
diff --git a/execution.py b/execution.py
index 1f7958f44..d9289d3c4 100644
--- a/execution.py
+++ b/execution.py
@@ -6,7 +6,6 @@ import threading
 import heapq
 import traceback
 import gc
-import time
 from enum import Enum
 
 import torch
@@ -14,172 +13,71 @@ import nodes
 
 import comfy.model_management
 import comfy.graph_utils
+from comfy.graph import get_input_info, ExecutionList, DynamicPrompt
 from comfy.graph_utils import is_link, ExecutionBlocker, GraphBuilder
+from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetInputSignatureWithID, CacheKeySetID
 
 class ExecutionResult(Enum):
     SUCCESS = 0
     FAILURE = 1
     SLEEPING = 2
 
-def get_input_info(class_def, input_name):
-    valid_inputs = class_def.INPUT_TYPES()
-    input_info = None
-    input_category = None
-    if "required" in valid_inputs and input_name in valid_inputs["required"]:
-        input_category = "required"
-        input_info = valid_inputs["required"][input_name]
-    elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
-        input_category = "optional"
-        input_info = valid_inputs["optional"][input_name]
-    elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
-        input_category = "hidden"
-        input_info = valid_inputs["hidden"][input_name]
-    if input_info is None:
-        return None, None, None
-    input_type = input_info[0]
-    if len(input_info) > 1:
-        extra_info = input_info[1]
-    else:
-        extra_info = {}
-    return input_type, input_category, extra_info
-
-# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
-# it can still be returned to the graph after having further dependencies added.
-class TopologicalSort:
-    def __init__(self, dynprompt):
+class IsChangedCache:
+    def __init__(self, dynprompt, outputs_cache):
         self.dynprompt = dynprompt
-        self.pendingNodes = {}
-        self.blockCount = {} # Number of nodes this node is directly blocked by
-        self.blocking = {} # Which nodes are blocked by this node
+        self.outputs_cache = outputs_cache
+        self.is_changed = {}
 
-    def get_input_info(self, unique_id, input_name):
-        class_type = self.dynprompt.get_node(unique_id)["class_type"]
-        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-        return get_input_info(class_def, input_name)
-
-    def make_input_strong_link(self, to_node_id, to_input):
-        inputs = self.dynprompt.get_node(to_node_id)["inputs"]
-        if to_input not in inputs:
-            raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input))
-        value = inputs[to_input]
-        if not is_link(value):
-            raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input))
-        from_node_id, from_socket = value
-        self.add_strong_link(from_node_id, from_socket, to_node_id)
-
-    def add_strong_link(self, from_node_id, from_socket, to_node_id):
-        self.add_node(from_node_id)
-        if to_node_id not in self.blocking[from_node_id]:
-            self.blocking[from_node_id][to_node_id] = {}
-            self.blockCount[to_node_id] += 1
-        self.blocking[from_node_id][to_node_id][from_socket] = True
-
-    def add_node(self, unique_id):
-        if unique_id in self.pendingNodes:
-            return
-        self.pendingNodes[unique_id] = True
-        self.blockCount[unique_id] = 0
-        self.blocking[unique_id] = {}
-
-        inputs = self.dynprompt.get_node(unique_id)["inputs"]
-        for input_name in inputs:
-            value = inputs[input_name]
-            if is_link(value):
-                from_node_id, from_socket = value
-                input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
-                if "lazy" not in input_info or not input_info["lazy"]:
-                    self.add_strong_link(from_node_id, from_socket, unique_id)
-
-    def get_ready_nodes(self):
-        return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
-
-    def pop_node(self, unique_id):
-        del self.pendingNodes[unique_id]
-        for blocked_node_id in self.blocking[unique_id]:
-            self.blockCount[blocked_node_id] -= 1
-        del self.blocking[unique_id]
-
-    def is_empty(self):
-        return len(self.pendingNodes) == 0
-
-class ExecutionList(TopologicalSort):
-    def __init__(self, dynprompt, outputs):
-        super().__init__(dynprompt)
-        self.outputs = outputs
-        self.staged_node_id = None
-
-    def add_strong_link(self, from_node_id, from_socket, to_node_id):
-        if from_node_id in self.outputs:
-            # Nothing to do
-            return
-        super().add_strong_link(from_node_id, from_socket, to_node_id)
-
-    def stage_node_execution(self):
-        assert self.staged_node_id is None
-        if self.is_empty():
-            return None
-        available = self.get_ready_nodes()
-        if len(available) == 0:
-            raise Exception("Dependency cycle detected")
-        next_node = available[0]
-        # If an output node is available, do that first.
-        # Technically this has no effect on the overall length of execution, but it feels better as a user
-        # for a PreviewImage to display a result as soon as it can
-        # Some other heuristics could probably be used here to improve the UX further.
-        for node_id in available:
-            class_type = self.dynprompt.get_node(node_id)["class_type"]
+    def get(self, node_id):
+        if node_id not in self.is_changed:
+            node = self.dynprompt.get_node(node_id)
+            class_type = node["class_type"]
             class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-            if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
-                next_node = node_id
-                break
-        self.staged_node_id = next_node
-        return self.staged_node_id
+            if hasattr(class_def, "IS_CHANGED"):
+                if "is_changed" in node:
+                    self.is_changed[node_id] = node["is_changed"]
+                else:
+                    input_data_all = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
+                    try:
+                        is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
+                        node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
+                        self.is_changed[node_id] = node["is_changed"]
+                    except:
+                        node["is_changed"] = float("NaN")
+                        self.is_changed[node_id] = node["is_changed"]
+            else:
+                self.is_changed[node_id] = False
+        return self.is_changed[node_id]
 
-    def unstage_node_execution(self):
-        assert self.staged_node_id is not None
-        self.staged_node_id = None
+class CacheSet:
+    def __init__(self, lru_size=None):
+        if lru_size is None or lru_size == 0:
+            self.init_classic_cache()
+        else:
+            self.init_lru_cache(lru_size)
+        self.all = [self.outputs, self.ui, self.objects]
 
-    def complete_node_execution(self):
-        node_id = self.staged_node_id
-        self.pop_node(node_id)
-        self.staged_node_id = None
+    # Useful for those with ample RAM/VRAM -- allows experimenting without
+    # blowing away the cache every time
+    def init_lru_cache(self, cache_size):
+        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.ui = LRUCache(CacheKeySetInputSignatureWithID, max_size=cache_size)
+        self.objects = HierarchicalCache(CacheKeySetID)
 
+    # Performs like the old cache -- dump data ASAP
+    def init_classic_cache(self):
+        self.outputs = HierarchicalCache(CacheKeySetInputSignature)
+        self.ui = HierarchicalCache(CacheKeySetInputSignatureWithID)
+        self.objects = HierarchicalCache(CacheKeySetID)
 
-class DynamicPrompt:
-    def __init__(self, original_prompt):
-        # The original prompt provided by the user
-        self.original_prompt = original_prompt
-        # Any extra pieces of the graph created during execution
-        self.ephemeral_prompt = {}
-        self.ephemeral_parents = {}
-        self.ephemeral_display = {}
+    def recursive_debug_dump(self):
+        result = {
+            "outputs": self.outputs.recursive_debug_dump(),
+            "ui": self.ui.recursive_debug_dump(),
+        }
+        return result
 
-    def get_node(self, node_id):
-        if node_id in self.ephemeral_prompt:
-            return self.ephemeral_prompt[node_id]
-        if node_id in self.original_prompt:
-            return self.original_prompt[node_id]
-        return None
-
-    def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
-        self.ephemeral_prompt[node_id] = node_info
-        self.ephemeral_parents[node_id] = parent_id
-        self.ephemeral_display[node_id] = display_id
-
-    def get_real_node_id(self, node_id):
-        while node_id in self.ephemeral_parents:
-            node_id = self.ephemeral_parents[node_id]
-        return node_id
-
-    def get_parent_node_id(self, node_id):
-        return self.ephemeral_parents.get(node_id, None)
-
-    def get_display_node_id(self, node_id):
-        while node_id in self.ephemeral_display:
-            node_id = self.ephemeral_display[node_id]
-        return node_id
-
-def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynprompt=None, extra_data={}):
+def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynprompt=None, extra_data={}):
     valid_inputs = class_def.INPUT_TYPES()
     input_data_all = {}
     for x in inputs:
@@ -188,9 +86,12 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp
         if is_link(input_data) and not input_info.get("rawLink", False):
             input_unique_id = input_data[0]
             output_index = input_data[1]
-            if input_unique_id not in outputs:
+            if outputs is None:
                 continue # This might be a lazily-evaluated input
-            obj = outputs[input_unique_id][output_index]
+            cached_output = outputs.get(input_unique_id)
+            if cached_output is None:
+                continue
+            obj = cached_output[output_index]
             input_data_all[x] = obj
         elif input_category is not None:
             input_data_all[x] = [input_data]
@@ -331,14 +232,18 @@ def format_value(x):
     else:
         return str(x)
 
-def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage, execution_list, pending_subgraph_results):
+def non_recursive_execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
     display_node_id = dynprompt.get_display_node_id(unique_id)
+    parent_node_id = dynprompt.get_parent_node_id(unique_id)
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-    if unique_id in outputs:
+    if caches.outputs.get(unique_id) is not None:
+        if server.client_id is not None:
+            cached_output = caches.ui.get(unique_id) or {}
+            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output", None), "prompt_id": prompt_id }, server.client_id)
         return (ExecutionResult.SUCCESS, None, None)
 
     input_data_all = None
@@ -354,7 +259,7 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
             for r in result:
                 if is_link(r):
                     source_node, source_output = r[0], r[1]
-                    node_output = outputs[source_node][source_output]
+                    node_output = caches.outputs.get(source_node)[source_output]
                     for o in node_output:
                         resolved_output.append(o)
 
@@ -365,15 +270,15 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
         output_ui = []
         has_subgraph = False
     else:
-        input_data_all = get_input_data(inputs, class_def, unique_id, outputs, dynprompt.original_prompt, dynprompt, extra_data)
+        input_data_all = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt.original_prompt, dynprompt, extra_data)
         if server.client_id is not None:
             server.last_node_id = display_node_id
-            server.send_sync("executing", { "node": display_node_id, "prompt_id": prompt_id }, server.client_id)
+            server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
 
-        obj = object_storage.get((unique_id, class_type), None)
+        obj = caches.objects.get(unique_id)
         if obj is None:
            obj = class_def()
-            object_storage[(unique_id, class_type)] = obj
+            caches.objects.set(unique_id, obj)
 
         if hasattr(obj, "check_lazy_status"):
             required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
@@ -406,11 +311,22 @@
                 GraphBuilder.set_default_prefix(unique_id, call_index, 0)
         output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
         if len(output_ui) > 0:
-            outputs_ui[unique_id] = output_ui
+            caches.ui.set(unique_id, {
+                "meta": {
+                    "node_id": unique_id,
+                    "display_node": display_node_id,
+                    "parent_node": parent_node_id,
+                    "real_node_id": real_node_id,
+                },
+                "output": output_ui
+            })
             if server.client_id is not None:
-                server.send_sync("executed", { "node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
+                server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
         if has_subgraph:
             cached_outputs = []
+            new_node_ids = []
+            new_output_ids = []
+            new_output_links = []
             for i in range(len(output_data)):
                 new_graph, node_outputs = output_data[i]
                 if new_graph is None:
@@ -421,8 +337,8 @@
                     if dynprompt.get_node(node_id) is not None:
                         raise Exception("Attempt to add duplicate node %s" % node_id)
                     break
-                new_output_ids = []
                 for node_id, node_info in new_graph.items():
+                    new_node_ids.append(node_id)
                     display_id = node_info.get("override_display_id", unique_id)
                     dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
                     # Figure out if the newly created node is an output node
@@ -430,16 +346,21 @@
                     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
                     if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
                         new_output_ids.append(node_id)
-                for node_id in new_output_ids:
-                    execution_list.add_node(node_id)
                 for i in range(len(node_outputs)):
                     if is_link(node_outputs[i]):
                         from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
-                        execution_list.add_strong_link(from_node_id, from_socket, unique_id)
+                        new_output_links.append((from_node_id, from_socket))
                 cached_outputs.append((True, node_outputs))
+            new_node_ids = set(new_node_ids)
+            for cache in caches.all:
+                cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
+            for node_id in new_output_ids:
+                execution_list.add_node(node_id)
+            for link in new_output_links:
+                execution_list.add_strong_link(link[0], link[1], unique_id)
             pending_subgraph_results[unique_id] = cached_outputs
             return (ExecutionResult.SLEEPING, None, None)
-        outputs[unique_id] = output_data
+        caches.outputs.set(unique_id, output_data)
 
     except comfy.model_management.InterruptProcessingException as iex:
         print("Processing interrupted")
@@ -459,8 +380,9 @@
            input_data_formatted[name] = [format_value(x) for x in inputs]
 
         output_data_formatted = {}
-        for node_id, node_outputs in outputs.items():
-            output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
+        # TODO - Implement me
+        # for node_id, node_outputs in outputs.items():
+        #     output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
 
         print("!!! Exception during processing !!!")
         print(traceback.format_exc())
@@ -479,65 +401,15 @@
 
     return (ExecutionResult.SUCCESS, None, None)
 
-def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
-    unique_id = current_item
-    inputs = prompt[unique_id]['inputs']
-    class_type = prompt[unique_id]['class_type']
-    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-
-    is_changed_old = ''
-    is_changed = ''
-    to_delete = False
-    if hasattr(class_def, 'IS_CHANGED'):
-        if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
-            is_changed_old = old_prompt[unique_id]['is_changed']
-        if 'is_changed' not in prompt[unique_id]:
-            input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
-            if input_data_all is not None:
-                try:
-                    #is_changed = class_def.IS_CHANGED(**input_data_all)
-                    is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
-                    prompt[unique_id]['is_changed'] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
-                except:
-                    to_delete = True
-        else:
-            is_changed = prompt[unique_id]['is_changed']
-
-    if unique_id not in outputs:
-        return True
-
-    if not to_delete:
-        if is_changed != is_changed_old:
-            to_delete = True
-        elif unique_id not in old_prompt:
-            to_delete = True
-        elif inputs == old_prompt[unique_id]['inputs']:
-            for x in inputs:
-                input_data = inputs[x]
-
-                if is_link(input_data):
-                    input_unique_id = input_data[0]
-                    output_index = input_data[1]
-                    if input_unique_id in outputs:
-                        to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
-                    else:
-                        to_delete = True
-                if to_delete:
-                    break
-        else:
-            to_delete = True
-
-    if to_delete:
-        d = outputs.pop(unique_id)
-        del d
-    return to_delete
+CACHE_FOR_DEBUG_DUMP = None
+def dump_cache_for_debug():
+    return CACHE_FOR_DEBUG_DUMP.recursive_debug_dump()
 
 class PromptExecutor:
-    def __init__(self, server):
-        self.outputs = {}
-        self.object_storage = {}
-        self.outputs_ui = {}
-        self.old_prompt = {}
+    def __init__(self, server, lru_size=None):
+        self.caches = CacheSet(lru_size)
+        global CACHE_FOR_DEBUG_DUMP
+        CACHE_FOR_DEBUG_DUMP = self.caches
         self.server = server
 
     def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
@@ -570,18 +442,6 @@
             }
             self.server.send_sync("execution_error", mes, self.server.client_id)
 
-        # Next, remove the subsequent outputs since they will not be executed
-        to_delete = []
-        for o in self.outputs:
-            if (o not in current_outputs) and (o not in executed):
-                to_delete += [o]
-                if o in self.old_prompt:
-                    d = self.old_prompt.pop(o)
-                    del d
-        for o in to_delete:
-            d = self.outputs.pop(o)
-            del d
-
     def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
         nodes.interrupt_processing(False)
 
@@ -594,59 +454,45 @@ class PromptExecutor:
             self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
 
         with torch.inference_mode():
-            #delete cached outputs if nodes don't exist for them
-            to_delete = []
-            for o in self.outputs:
-                if o not in prompt:
-                    to_delete += [o]
-            for o in to_delete:
-                d = self.outputs.pop(o)
-                del d
-            to_delete = []
-            for o in self.object_storage:
-                if o[0] not in prompt:
-                    to_delete += [o]
-                else:
-                    p = prompt[o[0]]
-                    if o[1] != p['class_type']:
-                        to_delete += [o]
-            for o in to_delete:
-                d = self.object_storage.pop(o)
-                del d
+            dynamic_prompt = DynamicPrompt(prompt)
+            is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
+            for cache in self.caches.all:
+                cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
+                cache.clean_unused()
 
-            for x in prompt:
-                recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
-
-            current_outputs = set(self.outputs.keys())
-            for x in list(self.outputs_ui.keys()):
-                if x not in current_outputs:
-                    d = self.outputs_ui.pop(x)
-                    del d
+            current_outputs = self.caches.outputs.all_node_ids()
 
             comfy.model_management.cleanup_models()
             if self.server.client_id is not None:
                 self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
+
             pending_subgraph_results = {}
-            dynamic_prompt = DynamicPrompt(prompt)
             executed = set()
-            execution_list = ExecutionList(dynamic_prompt, self.outputs)
+            execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
             for node_id in list(execute_outputs):
                 execution_list.add_node(node_id)
             while not execution_list.is_empty():
                 node_id = execution_list.stage_node_execution()
-                result, error, ex = non_recursive_execute(self.server, dynamic_prompt, self.outputs, node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage, execution_list, pending_subgraph_results)
+                result, error, ex = non_recursive_execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
                 if result == ExecutionResult.FAILURE:
                     self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
                     break
                 elif result == ExecutionResult.SLEEPING:
                     execution_list.unstage_node_execution()
-                else: # result == ExecutionResult.SUCCESS
+                else: # result == ExecutionResult.SUCCESS:
                     execution_list.complete_node_execution()
 
-            for x in executed:
-                if x in prompt:
-                    self.old_prompt[x] = copy.deepcopy(prompt[x])
+            ui_outputs = {}
+            meta_outputs = {}
+            for ui_info in self.caches.ui.all_active_values():
+                node_id = ui_info["meta"]["node_id"]
+                ui_outputs[node_id] = ui_info["output"]
+                meta_outputs[node_id] = ui_info["meta"]
+            self.history_result = {
+                "outputs": ui_outputs,
+                "meta": meta_outputs,
+            }
             self.server.last_node_id = None
@@ -979,12 +825,11 @@ class PromptQueue:
             self.server.queue_updated()
             return (item, i)
 
-    def task_done(self, item_id, outputs):
+    def task_done(self, item_id, history_result):
         with self.mutex:
             prompt = self.currently_running.pop(item_id)
             self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
-            for o in outputs:
-                self.history[prompt[1]]["outputs"][o] = outputs[o]
+            self.history[prompt[1]].update(history_result)
             self.server.queue_updated()
 
     def get_current_queue(self):
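With the hunks above, PromptExecutor no longer mutates plain self.outputs / self.outputs_ui dicts; UI results live in caches.ui as {"meta": ..., "output": ...} records, and execute() flattens them into self.history_result for PromptQueue.task_done. A standalone sketch of that flattening (the sample cache entry is illustrative, not real output):

ui_cache_values = [
    {"meta": {"node_id": "9", "display_node": "9", "parent_node": None, "real_node_id": "9"},
     "output": {"images": [{"filename": "ComfyUI_0001.png"}]}},
]

ui_outputs = {}
meta_outputs = {}
for ui_info in ui_cache_values:
    node_id = ui_info["meta"]["node_id"]
    ui_outputs[node_id] = ui_info["output"]
    meta_outputs[node_id] = ui_info["meta"]

history_result = {"outputs": ui_outputs, "meta": meta_outputs}
assert history_result["meta"]["9"]["display_node"] == "9"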
diff --git a/main.py b/main.py
index a4038db4b..19f75bb40 100644
--- a/main.py
+++ b/main.py
@@ -84,13 +84,13 @@ def cuda_malloc_warning():
             print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
 
 def prompt_worker(q, server):
-    e = execution.PromptExecutor(server)
+    e = execution.PromptExecutor(server, lru_size=args.cache_lru)
     while True:
         item, item_id = q.get()
 
         execution_start_time = time.perf_counter()
         prompt_id = item[1]
         e.execute(item[2], prompt_id, item[3], item[4])
-        q.task_done(item_id, e.outputs_ui)
+        q.task_done(item_id, e.history_result)
         if server.client_id is not None:
             server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
diff --git a/server.py b/server.py
index 57d5a65df..5c3c3f295 100644
--- a/server.py
+++ b/server.py
@@ -393,6 +393,7 @@ class PromptServer():
         obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
         info = {}
         info['input'] = obj_class.INPUT_TYPES()
+        info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
         info['output'] = obj_class.RETURN_TYPES
         info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
         info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
@@ -441,6 +442,22 @@ class PromptServer():
             queue_info['queue_pending'] = current_queue[1]
             return web.json_response(queue_info)
 
+        @routes.get("/debugcache")
+        async def get_debugcache(request):
+            def custom_serialize(obj):
+                from comfy.caching import Unhashable
+                if isinstance(obj, frozenset):
+                    try:
+                        return dict(obj)
+                    except:
+                        return list(obj)
+                elif isinstance(obj, Unhashable):
+                    return "NaN"
+                return str(obj)
+            def custom_dump(obj):
+                return json.dumps(obj, default=custom_serialize)
+            return web.json_response(execution.dump_cache_for_debug(), dumps=custom_dump)
+
         @routes.post("/prompt")
         async def post_prompt(request):
             print("got prompt")
@@ -610,6 +627,8 @@ class PromptServer():
         if address == '':
             address = '0.0.0.0'
+        self.address = address
+        self.port = port
 
         if verbose:
             print("Starting server\n")
             print("To see the GUI go to: http://{}:{}".format(address, port))
diff --git a/web/scripts/api.js b/web/scripts/api.js
index b1d245d73..a9912eca4 100644
--- a/web/scripts/api.js
+++ b/web/scripts/api.js
@@ -119,7 +119,7 @@ class ComfyApi extends EventTarget {
 					this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
 					break;
 				case "executing":
-					this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
+					this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.display_node }));
 					break;
 				case "executed":
 					this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
diff --git a/web/scripts/app.js b/web/scripts/app.js
index 3b7483cdf..34676d85c 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -983,8 +983,8 @@ export class ComfyApp {
 		});
 
 		api.addEventListener("executed", ({ detail }) => {
-			this.nodeOutputs[detail.node] = detail.output;
-			const node = this.graph.getNodeById(detail.node);
+			this.nodeOutputs[detail.display_node] = detail.output;
+			const node = this.graph.getNodeById(detail.display_node);
 			if (node) {
 				if (node.onExecuted)
 					node.onExecuted(detail.output);
diff --git a/web/scripts/ui.js b/web/scripts/ui.js
index f39939bf3..1f927b851 100644
--- a/web/scripts/ui.js
+++ b/web/scripts/ui.js
@@ -465,7 +465,14 @@ class ComfyList {
 					onclick: () => {
 						app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
 						if (item.outputs) {
-							app.nodeOutputs = item.outputs;
+							app.nodeOutputs = {};
+							for (const [key, value] of Object.entries(item.outputs)) {
+								if (item.meta && item.meta[key] && item.meta[key].display_node) {
+									app.nodeOutputs[item.meta[key].display_node] = value;
+								} else {
+									app.nodeOutputs[key] = value;
+								}
+							}
 						}
 					},
 				}),
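The ui.js change above re-keys restored history outputs from real (possibly ephemeral) node ids to their display nodes, so previews re-attach to the node the user actually sees. A Python model of that mapping (the ids are made up for illustration):

item_outputs = {"subgraph.7": {"images": ["a.png"]}, "9": {"images": ["b.png"]}}
item_meta = {"subgraph.7": {"display_node": "4"}}

node_outputs = {}
for key, value in item_outputs.items():
    meta = item_meta.get(key)
    if meta and meta.get("display_node"):
        node_outputs[meta["display_node"]] = value
    else:
        node_outputs[key] = value

assert node_outputs == {"4": {"images": ["a.png"]}, "9": {"images": ["b.png"]}}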