Rework Caching (#2)

This commit fixes a number of bugs and adds some caching-related
functionality. Specifically:

1. Caching is now input-based. Completely identical nodes now share a
   single cached output (for example, multiple LoadCheckpoint nodes
   loading the same checkpoint). If a node doesn't want this behavior
   (e.g. a `RandomInteger` node), it should set `NOT_IDEMPOTENT = True`
   (see the sketch after this list).
2. This means that nodes within a component will now be cached and will
   only re-execute if their inputs actually change. Note that types that
   can't be hashed by default will always count as changed (though the
   component itself will only expand if one of its inputs changes).
3. A new LRU caching strategy is now available by starting with
   `--cache-lru 100`. With this strategy, in addition to the latest
   workflow being cached, up to N (100 in the example) node outputs will
   be retained. This allows users to work on multiple workflows or
   experiment with different inputs without losing the benefits of
   caching (at the cost of more RAM and VRAM). I intentionally left some
   additional debug print statements in for this strategy for the
   moment.
4. A new endpoint `/debugcache` has been temporarily added to assist
   with tracking down issues people encounter. It allows you to browse
   the contents of the cache.
5. Outputs from ephemeral nodes will now be communicated to the
   front-end with the ephemeral node id, the 'parent' node id, and the
   'display' node id. The front-end has been updated to handle this
   (see the example payload after this list).
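
To illustrate item 1, here is a minimal sketch of a node that opts out of
output reuse. The `RandomInteger` class below is hypothetical (registration
in NODE_CLASS_MAPPINGS omitted); `NOT_IDEMPOTENT` is the attribute the new
cache-key code checks, and the rest follows the usual ComfyUI node interface:

    import random

    class RandomInteger:
        # Opts out of input-based caching: two RandomInteger nodes with
        # identical inputs must NOT share one cached output.
        NOT_IDEMPOTENT = True

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {
                "minimum": ("INT", {"default": 0}),
                "maximum": ("INT", {"default": 100}),
            }}

        RETURN_TYPES = ("INT",)
        FUNCTION = "generate"

        def generate(self, minimum, maximum):
            return (random.randint(minimum, maximum),)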
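For item 5, the updated "executed" message carries both the executing
(possibly ephemeral) node id and the display node id, while the parent id
travels in the cached UI metadata used for history. The values here are
illustrative:

    executed_message = {
        "node": "42.0.1",         # ephemeral node id
        "display_node": "42",     # id the front-end attaches the output to
        "output": {"images": []}, # the node's UI output
        "prompt_id": "some-uuid",
    }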
guill, 2023-09-11 19:53:41 -07:00 (committed by GitHub)
commit 7d4530f6f5, parent f15bd84351
9 changed files with 639 additions and 282 deletions

comfy/caching.py (new file)

@@ -0,0 +1,316 @@
import itertools
from typing import Sequence, Mapping
import nodes

from comfy.graph_utils import is_link

class CacheKeySet:
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        self.keys = {}
        self.subcache_keys = {}

    def add_keys(self, node_ids):
        raise NotImplementedError()

    def all_node_ids(self):
        return set(self.keys.keys())

    def get_used_keys(self):
        return self.keys.values()

    def get_used_subcache_keys(self):
        return self.subcache_keys.values()

    def get_data_key(self, node_id):
        return self.keys.get(node_id, None)

    def get_subcache_key(self, node_id):
        return self.subcache_keys.get(node_id, None)

class Unhashable:
    def __init__(self):
        self.value = float("NaN")

def to_hashable(obj):
    # So that we don't infinitely recurse since frozenset and tuples
    # are Sequences.
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    elif isinstance(obj, Mapping):
        return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
    elif isinstance(obj, Sequence):
        return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
    else:
        # TODO - Support other objects like tensors?
        return Unhashable()
class CacheKeySetID(CacheKeySet):
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt
        self.add_keys(node_ids)

    def add_keys(self, node_ids):
        for node_id in node_ids:
            if node_id in self.keys:
                continue
            node = self.dynprompt.get_node(node_id)
            self.keys[node_id] = (node_id, node["class_type"])
            self.subcache_keys[node_id] = (node_id, node["class_type"])

class CacheKeySetInputSignature(CacheKeySet):
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt
        self.is_changed_cache = is_changed_cache
        self.add_keys(node_ids)

    def include_node_id_in_input(self):
        return False

    def add_keys(self, node_ids):
        for node_id in node_ids:
            if node_id in self.keys:
                continue
            node = self.dynprompt.get_node(node_id)
            self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
            self.subcache_keys[node_id] = (node_id, node["class_type"])

    def get_node_signature(self, dynprompt, node_id):
        signature = []
        ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
        signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
        for ancestor_id in ancestors:
            signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
        return to_hashable(signature)

    def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
        node = dynprompt.get_node(node_id)
        class_type = node["class_type"]
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
        signature = [class_type, self.is_changed_cache.get(node_id)]
        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
            signature.append(node_id)
        inputs = node["inputs"]
        for key in sorted(inputs.keys()):
            if is_link(inputs[key]):
                (ancestor_id, ancestor_socket) = inputs[key]
                ancestor_index = ancestor_order_mapping[ancestor_id]
                signature.append((key, ("ANCESTOR", ancestor_index, ancestor_socket)))
            else:
                signature.append((key, inputs[key]))
        return signature

    # This function returns a list of all ancestors of the given node. The order of the list is
    # deterministic based on which specific inputs the ancestor is connected by.
    def get_ordered_ancestry(self, dynprompt, node_id):
        ancestors = []
        order_mapping = {}
        self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
        return ancestors, order_mapping

    def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
        inputs = dynprompt.get_node(node_id)["inputs"]
        input_keys = sorted(inputs.keys())
        for key in input_keys:
            if is_link(inputs[key]):
                ancestor_id = inputs[key][0]
                if ancestor_id not in order_mapping:
                    ancestors.append(ancestor_id)
                    order_mapping[ancestor_id] = len(ancestors) - 1
                    self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)

class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature):
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)

    def include_node_id_in_input(self):
        return True
class BasicCache:
    def __init__(self, key_class):
        self.key_class = key_class
        self.dynprompt = None
        self.cache_key_set = None
        self.cache = {}
        self.subcaches = {}

    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        self.dynprompt = dynprompt
        self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
        self.is_changed_cache = is_changed_cache

    def all_node_ids(self):
        assert self.cache_key_set is not None
        node_ids = self.cache_key_set.all_node_ids()
        for subcache in self.subcaches.values():
            node_ids = node_ids.union(subcache.all_node_ids())
        return node_ids

    def clean_unused(self):
        assert self.cache_key_set is not None
        preserve_keys = set(self.cache_key_set.get_used_keys())
        preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
        to_remove = []
        for key in self.cache:
            if key not in preserve_keys:
                to_remove.append(key)
        for key in to_remove:
            del self.cache[key]
        to_remove = []
        for key in self.subcaches:
            if key not in preserve_subcaches:
                to_remove.append(key)
        for key in to_remove:
            del self.subcaches[key]

    def _set_immediate(self, node_id, value):
        assert self.cache_key_set is not None
        cache_key = self.cache_key_set.get_data_key(node_id)
        self.cache[cache_key] = value

    def _get_immediate(self, node_id):
        assert self.cache_key_set is not None
        cache_key = self.cache_key_set.get_data_key(node_id)
        if cache_key in self.cache:
            return self.cache[cache_key]
        else:
            return None

    def _ensure_subcache(self, node_id, children_ids):
        assert self.cache_key_set is not None
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        subcache = self.subcaches.get(subcache_key, None)
        if subcache is None:
            subcache = BasicCache(self.key_class)
            self.subcaches[subcache_key] = subcache
        subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
        return subcache

    def _get_subcache(self, node_id):
        assert self.cache_key_set is not None
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        if subcache_key in self.subcaches:
            return self.subcaches[subcache_key]
        else:
            return None

    def recursive_debug_dump(self):
        result = []
        for key in self.cache:
            result.append({"key": key, "value": self.cache[key]})
        for key in self.subcaches:
            result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
        return result
class HierarchicalCache(BasicCache):
    def __init__(self, key_class):
        super().__init__(key_class)

    def _get_cache_for(self, node_id):
        parent_id = self.dynprompt.get_parent_node_id(node_id)
        if parent_id is None:
            return self

        hierarchy = []
        while parent_id is not None:
            hierarchy.append(parent_id)
            parent_id = self.dynprompt.get_parent_node_id(parent_id)

        cache = self
        for parent_id in reversed(hierarchy):
            cache = cache._get_subcache(parent_id)
            if cache is None:
                return None
        return cache

    def get(self, node_id):
        cache = self._get_cache_for(node_id)
        if cache is None:
            return None
        return cache._get_immediate(node_id)

    def set(self, node_id, value):
        cache = self._get_cache_for(node_id)
        assert cache is not None
        cache._set_immediate(node_id, value)

    def ensure_subcache_for(self, node_id, children_ids):
        cache = self._get_cache_for(node_id)
        assert cache is not None
        return cache._ensure_subcache(node_id, children_ids)

    def all_active_values(self):
        active_nodes = self.all_node_ids()
        result = []
        for node_id in active_nodes:
            value = self.get(node_id)
            if value is not None:
                result.append(value)
        return result
class LRUCache(BasicCache):
    def __init__(self, key_class, max_size=100):
        super().__init__(key_class)
        self.max_size = max_size
        self.min_generation = 0
        self.generation = 0
        self.used_generation = {}
        self.children = {}

    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        super().set_prompt(dynprompt, node_ids, is_changed_cache)
        self.generation += 1
        for node_id in node_ids:
            self._mark_used(node_id)
        print("LRUCache: Now at generation %d" % self.generation)

    def clean_unused(self):
        print("LRUCache: Cleaning unused. Current size: %d/%d" % (len(self.cache), self.max_size))
        while len(self.cache) > self.max_size and self.min_generation < self.generation:
            print("LRUCache: Evicting generation %d" % self.min_generation)
            self.min_generation += 1
            to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
            for key in to_remove:
                del self.cache[key]
                del self.used_generation[key]
                if key in self.children:
                    del self.children[key]

    def get(self, node_id):
        self._mark_used(node_id)
        return self._get_immediate(node_id)

    def _mark_used(self, node_id):
        cache_key = self.cache_key_set.get_data_key(node_id)
        if cache_key is not None:
            self.used_generation[cache_key] = self.generation

    def set(self, node_id, value):
        self._mark_used(node_id)
        return self._set_immediate(node_id, value)

    def ensure_subcache_for(self, node_id, children_ids):
        self.cache_key_set.add_keys(children_ids)
        self._mark_used(node_id)
        cache_key = self.cache_key_set.get_data_key(node_id)
        self.children[cache_key] = []
        for child_id in children_ids:
            self._mark_used(child_id)
            self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
        return self

    def all_active_values(self):
        explored = set()
        to_explore = set(self.cache_key_set.get_used_keys())
        while len(to_explore) > 0:
            cache_key = to_explore.pop()
            if cache_key not in explored:
                self.used_generation[cache_key] = self.generation
                explored.add(cache_key)
                if cache_key in self.children:
                    to_explore.update(self.children[cache_key])
        return [self.cache[key] for key in explored if key in self.cache]
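
As a standalone illustration of the hashing above (the function body is
trimmed from `to_hashable`; the sample node inputs are hypothetical):
identical input structures map to equal, hashable keys, which is what lets
two identical nodes share a single cache entry.

    import itertools
    from typing import Sequence, Mapping

    def to_hashable(obj):  # trimmed copy; the Unhashable fallback is omitted
        if isinstance(obj, (int, float, str, bool, type(None))):
            return obj
        elif isinstance(obj, Mapping):
            return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
        elif isinstance(obj, Sequence):
            return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))

    a = to_hashable({"ckpt_name": "sd15.ckpt", "clip": ["4", 0]})
    b = to_hashable({"clip": ["4", 0], "ckpt_name": "sd15.ckpt"})
    assert a == b and hash(a) == hash(b)  # key order is irrelevant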

comfy/cli_args.py

@@ -69,6 +69,10 @@ class LatentPreviewMethod(enum.Enum):
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")

comfy/graph.py (new file)

@@ -0,0 +1,166 @@
import nodes

from comfy.graph_utils import is_link

class DynamicPrompt:
    def __init__(self, original_prompt):
        # The original prompt provided by the user
        self.original_prompt = original_prompt
        # Any extra pieces of the graph created during execution
        self.ephemeral_prompt = {}
        self.ephemeral_parents = {}
        self.ephemeral_display = {}

    def get_node(self, node_id):
        if node_id in self.ephemeral_prompt:
            return self.ephemeral_prompt[node_id]
        if node_id in self.original_prompt:
            return self.original_prompt[node_id]
        return None

    def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
        self.ephemeral_prompt[node_id] = node_info
        self.ephemeral_parents[node_id] = parent_id
        self.ephemeral_display[node_id] = display_id

    def get_real_node_id(self, node_id):
        while node_id in self.ephemeral_parents:
            node_id = self.ephemeral_parents[node_id]
        return node_id

    def get_parent_node_id(self, node_id):
        return self.ephemeral_parents.get(node_id, None)

    def get_display_node_id(self, node_id):
        while node_id in self.ephemeral_display:
            node_id = self.ephemeral_display[node_id]
        return node_id

    def all_node_ids(self):
        return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
def get_input_info(class_def, input_name):
    valid_inputs = class_def.INPUT_TYPES()
    input_info = None
    input_category = None
    if "required" in valid_inputs and input_name in valid_inputs["required"]:
        input_category = "required"
        input_info = valid_inputs["required"][input_name]
    elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
        input_category = "optional"
        input_info = valid_inputs["optional"][input_name]
    elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
        input_category = "hidden"
        input_info = valid_inputs["hidden"][input_name]
    if input_info is None:
        return None, None, None
    input_type = input_info[0]
    if len(input_info) > 1:
        extra_info = input_info[1]
    else:
        extra_info = {}
    return input_type, input_category, extra_info
class TopologicalSort:
    def __init__(self, dynprompt):
        self.dynprompt = dynprompt
        self.pendingNodes = {}
        self.blockCount = {} # Number of nodes this node is directly blocked by
        self.blocking = {} # Which nodes are blocked by this node

    def get_input_info(self, unique_id, input_name):
        class_type = self.dynprompt.get_node(unique_id)["class_type"]
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
        return get_input_info(class_def, input_name)

    def make_input_strong_link(self, to_node_id, to_input):
        inputs = self.dynprompt.get_node(to_node_id)["inputs"]
        if to_input not in inputs:
            raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input))
        value = inputs[to_input]
        if not is_link(value):
            raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input))
        from_node_id, from_socket = value
        self.add_strong_link(from_node_id, from_socket, to_node_id)

    def add_strong_link(self, from_node_id, from_socket, to_node_id):
        self.add_node(from_node_id)
        if to_node_id not in self.blocking[from_node_id]:
            self.blocking[from_node_id][to_node_id] = {}
            self.blockCount[to_node_id] += 1
        self.blocking[from_node_id][to_node_id][from_socket] = True

    def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
        if unique_id in self.pendingNodes:
            return
        self.pendingNodes[unique_id] = True
        self.blockCount[unique_id] = 0
        self.blocking[unique_id] = {}
        inputs = self.dynprompt.get_node(unique_id)["inputs"]
        for input_name in inputs:
            value = inputs[input_name]
            if is_link(value):
                from_node_id, from_socket = value
                if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
                    continue
                input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
                is_lazy = "lazy" in input_info and input_info["lazy"]
                if include_lazy or not is_lazy:
                    self.add_strong_link(from_node_id, from_socket, unique_id)

    def get_ready_nodes(self):
        return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]

    def pop_node(self, unique_id):
        del self.pendingNodes[unique_id]
        for blocked_node_id in self.blocking[unique_id]:
            self.blockCount[blocked_node_id] -= 1
        del self.blocking[unique_id]

    def is_empty(self):
        return len(self.pendingNodes) == 0
# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
# it can still be returned to the graph after having further dependencies added.
class ExecutionList(TopologicalSort):
    def __init__(self, dynprompt, output_cache):
        super().__init__(dynprompt)
        self.output_cache = output_cache
        self.staged_node_id = None

    def add_strong_link(self, from_node_id, from_socket, to_node_id):
        if self.output_cache.get(from_node_id) is not None:
            # Nothing to do
            return
        super().add_strong_link(from_node_id, from_socket, to_node_id)

    def stage_node_execution(self):
        assert self.staged_node_id is None
        if self.is_empty():
            return None
        available = self.get_ready_nodes()
        if len(available) == 0:
            raise Exception("Dependency cycle detected")
        next_node = available[0]
        # If an output node is available, do that first.
        # Technically this has no effect on the overall length of execution, but it feels better as a user
        # for a PreviewImage to display a result as soon as it can
        # Some other heuristics could probably be used here to improve the UX further.
        for node_id in available:
            class_type = self.dynprompt.get_node(node_id)["class_type"]
            class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
            if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
                next_node = node_id
                break
        self.staged_node_id = next_node
        return self.staged_node_id

    def unstage_node_execution(self):
        assert self.staged_node_id is not None
        self.staged_node_id = None

    def complete_node_execution(self):
        node_id = self.staged_node_id
        self.pop_node(node_id)
        self.staged_node_id = None
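
For intuition, the pendingNodes/blockCount/blocking bookkeeping above is a
Kahn-style topological sort. A self-contained sketch on a hypothetical
three-node chain a -> b -> c:

    pending = {"a": True, "b": True, "c": True}
    block_count = {"a": 0, "b": 1, "c": 1}                   # b waits on a, c waits on b
    blocking = {"a": {"b": True}, "b": {"c": True}, "c": {}}

    order = []
    while pending:
        ready = [n for n in pending if block_count[n] == 0]
        assert ready, "Dependency cycle detected"
        node = ready[0]
        order.append(node)
        for blocked in blocking[node]:
            block_count[blocked] -= 1
        del pending[node]
        del blocking[node]

    print(order)  # ['a', 'b', 'c']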

execution.py

@@ -6,7 +6,6 @@ import threading
import heapq
import traceback
import gc
import time
from enum import Enum
import torch
@@ -14,172 +13,71 @@ import nodes
import comfy.model_management
import comfy.graph_utils
from comfy.graph import get_input_info, ExecutionList, DynamicPrompt
from comfy.graph_utils import is_link, ExecutionBlocker, GraphBuilder
from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetInputSignatureWithID, CacheKeySetID
class ExecutionResult(Enum):
    SUCCESS = 0
    FAILURE = 1
    SLEEPING = 2
def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES()
input_info = None
input_category = None
if "required" in valid_inputs and input_name in valid_inputs["required"]:
input_category = "required"
input_info = valid_inputs["required"][input_name]
elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
input_category = "optional"
input_info = valid_inputs["optional"][input_name]
elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
input_category = "hidden"
input_info = valid_inputs["hidden"][input_name]
if input_info is None:
return None, None, None
input_type = input_info[0]
if len(input_info) > 1:
extra_info = input_info[1]
else:
extra_info = {}
return input_type, input_category, extra_info
# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
# it can still be returned to the graph after having further dependencies added.
class TopologicalSort:
def __init__(self, dynprompt):
class IsChangedCache:
def __init__(self, dynprompt, outputs_cache):
self.dynprompt = dynprompt
self.pendingNodes = {}
self.blockCount = {} # Number of nodes this node is directly blocked by
self.blocking = {} # Which nodes are blocked by this node
self.outputs_cache = outputs_cache
self.is_changed = {}
def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return get_input_info(class_def, input_name)
def make_input_strong_link(self, to_node_id, to_input):
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
if to_input not in inputs:
raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input))
value = inputs[to_input]
if not is_link(value):
raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input))
from_node_id, from_socket = value
self.add_strong_link(from_node_id, from_socket, to_node_id)
def add_strong_link(self, from_node_id, from_socket, to_node_id):
self.add_node(from_node_id)
if to_node_id not in self.blocking[from_node_id]:
self.blocking[from_node_id][to_node_id] = {}
self.blockCount[to_node_id] += 1
self.blocking[from_node_id][to_node_id][from_socket] = True
def add_node(self, unique_id):
if unique_id in self.pendingNodes:
return
self.pendingNodes[unique_id] = True
self.blockCount[unique_id] = 0
self.blocking[unique_id] = {}
inputs = self.dynprompt.get_node(unique_id)["inputs"]
for input_name in inputs:
value = inputs[input_name]
if is_link(value):
from_node_id, from_socket = value
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
if "lazy" not in input_info or not input_info["lazy"]:
self.add_strong_link(from_node_id, from_socket, unique_id)
def get_ready_nodes(self):
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
def pop_node(self, unique_id):
del self.pendingNodes[unique_id]
for blocked_node_id in self.blocking[unique_id]:
self.blockCount[blocked_node_id] -= 1
del self.blocking[unique_id]
def is_empty(self):
return len(self.pendingNodes) == 0
class ExecutionList(TopologicalSort):
def __init__(self, dynprompt, outputs):
super().__init__(dynprompt)
self.outputs = outputs
self.staged_node_id = None
def add_strong_link(self, from_node_id, from_socket, to_node_id):
if from_node_id in self.outputs:
# Nothing to do
return
super().add_strong_link(from_node_id, from_socket, to_node_id)
def stage_node_execution(self):
assert self.staged_node_id is None
if self.is_empty():
return None
available = self.get_ready_nodes()
if len(available) == 0:
raise Exception("Dependency cycle detected")
next_node = available[0]
# If an output node is available, do that first.
# Technically this has no effect on the overall length of execution, but it feels better as a user
# for a PreviewImage to display a result as soon as it can
# Some other heuristics could probably be used here to improve the UX further.
for node_id in available:
class_type = self.dynprompt.get_node(node_id)["class_type"]
def get(self, node_id):
if node_id not in self.is_changed:
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
next_node = node_id
break
self.staged_node_id = next_node
return self.staged_node_id
if hasattr(class_def, "IS_CHANGED"):
if "is_changed" in node:
self.is_changed[node_id] = node["is_changed"]
else:
input_data_all = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
try:
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
self.is_changed[node_id] = node["is_changed"]
except:
node["is_changed"] = float("NaN")
self.is_changed[node_id] = node["is_changed"]
else:
self.is_changed[node_id] = False
return self.is_changed[node_id]
def unstage_node_execution(self):
assert self.staged_node_id is not None
self.staged_node_id = None
class CacheSet:
def __init__(self, lru_size=None):
if lru_size is None or lru_size == 0:
self.init_classic_cache()
else:
self.init_lru_cache(lru_size)
self.all = [self.outputs, self.ui, self.objects]
def complete_node_execution(self):
node_id = self.staged_node_id
self.pop_node(node_id)
self.staged_node_id = None
# Useful for those with ample RAM/VRAM -- allows experimenting without
# blowing away the cache every time
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignatureWithID, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignatureWithID)
self.objects = HierarchicalCache(CacheKeySetID)
class DynamicPrompt:
def __init__(self, original_prompt):
# The original prompt provided by the user
self.original_prompt = original_prompt
# Any extra pieces of the graph created during execution
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
self.ephemeral_display = {}
def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
}
return result
def get_node(self, node_id):
if node_id in self.ephemeral_prompt:
return self.ephemeral_prompt[node_id]
if node_id in self.original_prompt:
return self.original_prompt[node_id]
return None
def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
self.ephemeral_prompt[node_id] = node_info
self.ephemeral_parents[node_id] = parent_id
self.ephemeral_display[node_id] = display_id
def get_real_node_id(self, node_id):
while node_id in self.ephemeral_parents:
node_id = self.ephemeral_parents[node_id]
return node_id
def get_parent_node_id(self, node_id):
return self.ephemeral_parents.get(node_id, None)
def get_display_node_id(self, node_id):
while node_id in self.ephemeral_display:
node_id = self.ephemeral_display[node_id]
return node_id
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynprompt=None, extra_data={}):
def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynprompt=None, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
for x in inputs:
@@ -188,9 +86,12 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynprompt=None, extra_data={}):
if is_link(input_data) and not input_info.get("rawLink", False):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
if outputs is None:
continue # This might be a lazily-evaluated input
obj = outputs[input_unique_id][output_index]
cached_output = outputs.get(input_unique_id)
if cached_output is None:
continue
obj = cached_output[output_index]
input_data_all[x] = obj
elif input_category is not None:
input_data_all[x] = [input_data]
@@ -331,14 +232,18 @@ def format_value(x):
else:
return str(x)
def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage, execution_list, pending_subgraph_results):
def non_recursive_execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
parent_node_id = dynprompt.get_parent_node_id(unique_id)
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if unique_id in outputs:
if caches.outputs.get(unique_id) is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
return (ExecutionResult.SUCCESS, None, None)
input_data_all = None
@@ -354,7 +259,7 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
for r in result:
if is_link(r):
source_node, source_output = r[0], r[1]
node_output = outputs[source_node][source_output]
node_output = caches.outputs.get(source_node)[source_output]
for o in node_output:
resolved_output.append(o)
@@ -365,15 +270,15 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
output_ui = []
has_subgraph = False
else:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, dynprompt.original_prompt, dynprompt, extra_data)
input_data_all = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt.original_prompt, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": display_node_id, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
obj = object_storage.get((unique_id, class_type), None)
obj = caches.objects.get(unique_id)
if obj is None:
obj = class_def()
object_storage[(unique_id, class_type)] = obj
caches.objects.set(unique_id, obj)
if hasattr(obj, "check_lazy_status"):
required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
@@ -406,11 +311,22 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
caches.ui.set(unique_id, {
    "meta": {
        "node_id": unique_id,
        "display_node": display_node_id,
        "parent_node": parent_node_id,
        "real_node_id": real_node_id,
    },
    "output": output_ui
})
if server.client_id is not None:
server.send_sync("executed", { "node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph:
cached_outputs = []
new_node_ids = []
new_output_ids = []
new_output_links = []
for i in range(len(output_data)):
new_graph, node_outputs = output_data[i]
if new_graph is None:
@@ -421,8 +337,8 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
if dynprompt.get_node(node_id) is not None:
raise Exception("Attempt to add duplicate node %s" % node_id)
break
new_output_ids = []
for node_id, node_info in new_graph.items():
new_node_ids.append(node_id)
display_id = node_info.get("override_display_id", unique_id)
dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
# Figure out if the newly created node is an output node
@@ -430,16 +346,21 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
new_output_ids.append(node_id)
for node_id in new_output_ids:
execution_list.add_node(node_id)
for i in range(len(node_outputs)):
if is_link(node_outputs[i]):
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
execution_list.add_strong_link(from_node_id, from_socket, unique_id)
new_output_links.append((from_node_id, from_socket))
cached_outputs.append((True, node_outputs))
new_node_ids = set(new_node_ids)
for cache in caches.all:
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
for node_id in new_output_ids:
execution_list.add_node(node_id)
for link in new_output_links:
execution_list.add_strong_link(link[0], link[1], unique_id)
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.SLEEPING, None, None)
outputs[unique_id] = output_data
caches.outputs.set(unique_id, output_data)
except comfy.model_management.InterruptProcessingException as iex:
print("Processing interrupted")
@@ -459,8 +380,9 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
input_data_formatted[name] = [format_value(x) for x in inputs]
output_data_formatted = {}
for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
# TODO - Implement me
# for node_id, node_outputs in outputs.items():
# output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
print("!!! Exception during processing !!!")
print(traceback.format_exc())
@@ -479,65 +401,15 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
return (ExecutionResult.SUCCESS, None, None)
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
is_changed_old = ''
is_changed = ''
to_delete = False
if hasattr(class_def, 'IS_CHANGED'):
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
is_changed_old = old_prompt[unique_id]['is_changed']
if 'is_changed' not in prompt[unique_id]:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
if input_data_all is not None:
try:
#is_changed = class_def.IS_CHANGED(**input_data_all)
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
prompt[unique_id]['is_changed'] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except:
to_delete = True
else:
is_changed = prompt[unique_id]['is_changed']
if unique_id not in outputs:
return True
if not to_delete:
if is_changed != is_changed_old:
to_delete = True
elif unique_id not in old_prompt:
to_delete = True
elif inputs == old_prompt[unique_id]['inputs']:
for x in inputs:
input_data = inputs[x]
if is_link(input_data):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id in outputs:
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
else:
to_delete = True
if to_delete:
break
else:
to_delete = True
if to_delete:
d = outputs.pop(unique_id)
del d
return to_delete
CACHE_FOR_DEBUG_DUMP = None
def dump_cache_for_debug():
    return CACHE_FOR_DEBUG_DUMP.recursive_debug_dump()
class PromptExecutor:
def __init__(self, server):
self.outputs = {}
self.object_storage = {}
self.outputs_ui = {}
self.old_prompt = {}
def __init__(self, server, lru_size=None):
    self.caches = CacheSet(lru_size)
    global CACHE_FOR_DEBUG_DUMP
    CACHE_FOR_DEBUG_DUMP = self.caches
self.server = server
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
@@ -570,18 +442,6 @@ class PromptExecutor:
}
self.server.send_sync("execution_error", mes, self.server.client_id)
# Next, remove the subsequent outputs since they will not be executed
to_delete = []
for o in self.outputs:
if (o not in current_outputs) and (o not in executed):
to_delete += [o]
if o in self.old_prompt:
d = self.old_prompt.pop(o)
del d
for o in to_delete:
d = self.outputs.pop(o)
del d
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False)
@@ -594,59 +454,45 @@ class PromptExecutor:
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
with torch.inference_mode():
#delete cached outputs if nodes don't exist for them
to_delete = []
for o in self.outputs:
if o not in prompt:
to_delete += [o]
for o in to_delete:
d = self.outputs.pop(o)
del d
to_delete = []
for o in self.object_storage:
if o[0] not in prompt:
to_delete += [o]
else:
p = prompt[o[0]]
if o[1] != p['class_type']:
to_delete += [o]
for o in to_delete:
d = self.object_storage.pop(o)
del d
dynamic_prompt = DynamicPrompt(prompt)
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
    cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
    cache.clean_unused()
for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
current_outputs = set(self.outputs.keys())
for x in list(self.outputs_ui.keys()):
if x not in current_outputs:
d = self.outputs_ui.pop(x)
del d
current_outputs = self.caches.outputs.all_node_ids()
comfy.model_management.cleanup_models()
if self.server.client_id is not None:
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
pending_subgraph_results = {}
dynamic_prompt = DynamicPrompt(prompt)
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.outputs)
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
for node_id in list(execute_outputs):
execution_list.add_node(node_id)
while not execution_list.is_empty():
node_id = execution_list.stage_node_execution()
result, error, ex = non_recursive_execute(self.server, dynamic_prompt, self.outputs, node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage, execution_list, pending_subgraph_results)
result, error, ex = non_recursive_execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
elif result == ExecutionResult.SLEEPING:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
for x in executed:
if x in prompt:
self.old_prompt[x] = copy.deepcopy(prompt[x])
ui_outputs = {}
meta_outputs = {}
for ui_info in self.caches.ui.all_active_values():
    node_id = ui_info["meta"]["node_id"]
    ui_outputs[node_id] = ui_info["output"]
    meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
    "outputs": ui_outputs,
    "meta": meta_outputs,
}
self.server.last_node_id = None
@@ -979,12 +825,11 @@ class PromptQueue:
self.server.queue_updated()
return (item, i)
def task_done(self, item_id, outputs):
def task_done(self, item_id, history_result):
with self.mutex:
prompt = self.currently_running.pop(item_id)
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
for o in outputs:
self.history[prompt[1]]["outputs"][o] = outputs[o]
self.history[prompt[1]].update(history_result)
self.server.queue_updated()
def get_current_queue(self):

main.py

@@ -84,13 +84,13 @@ def cuda_malloc_warning():
print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server):
e = execution.PromptExecutor(server)
e = execution.PromptExecutor(server, lru_size=args.cache_lru)
while True:
item, item_id = q.get()
execution_start_time = time.perf_counter()
prompt_id = item[1]
e.execute(item[2], prompt_id, item[3], item[4])
q.task_done(item_id, e.outputs_ui)
q.task_done(item_id, e.history_result)
if server.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)

server.py

@@ -393,6 +393,7 @@ class PromptServer():
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
@@ -441,6 +442,22 @@ class PromptServer():
queue_info['queue_pending'] = current_queue[1]
return web.json_response(queue_info)
@routes.get("/debugcache")
async def get_debugcache(request):
def custom_serialize(obj):
from comfy.caching import Unhashable
if isinstance(obj, frozenset):
try:
return dict(obj)
except:
return list(obj)
elif isinstance(obj, Unhashable):
return "NaN"
return str(obj)
def custom_dump(obj):
return json.dumps(obj, default=custom_serialize)
return web.json_response(execution.dump_cache_for_debug(), dumps=custom_dump)
@routes.post("/prompt")
async def post_prompt(request):
print("got prompt")
@@ -610,6 +627,8 @@ class PromptServer():
if address == '':
address = '0.0.0.0'
self.address = address
self.port = port
if verbose:
print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port))

web/scripts/api.js

@@ -119,7 +119,7 @@ class ComfyApi extends EventTarget {
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
break;
case "executing":
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.display_node }));
break;
case "executed":
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));

web/scripts/app.js

@@ -983,8 +983,8 @@ export class ComfyApp {
});
api.addEventListener("executed", ({ detail }) => {
this.nodeOutputs[detail.node] = detail.output;
const node = this.graph.getNodeById(detail.node);
this.nodeOutputs[detail.display_node] = detail.output;
const node = this.graph.getNodeById(detail.display_node);
if (node) {
if (node.onExecuted)
node.onExecuted(detail.output);

web/scripts/ui.js

@@ -465,7 +465,14 @@ class ComfyList {
onclick: () => {
app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
if (item.outputs) {
app.nodeOutputs = item.outputs;
app.nodeOutputs = {};
for (const [key, value] of Object.entries(item.outputs)) {
if (item.meta && item.meta[key] && item.meta[key].display_node) {
app.nodeOutputs[item.meta[key].display_node] = value;
} else {
app.nodeOutputs[key] = value;
}
}
}
},
}),