mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-11 14:02:37 +08:00
Rework Caching (#2)
This commit solves a number of bugs and adds some caching related functionality. Specifically: 1. Caching is now input-based. In cases of completely identical nodes, the output will be reused (for example, if you have multiple LoadCheckpoint nodes loading the same checkpoint). If a node doesn't want this behavior (e.g. a `RandomInteger` node, it should set `NOT_IDEMPOTENT = True`. 2. This means that nodes within a component will now be cached and will only change if the input actually changes. Note that types that can't be hashed by default will always count as changed (though the component itself will only expand if one of its inputs changes). 3. A new LRU caching strategy is now available by starting with `--cache-lru 100`. With this strategy, in addition to the latest workflow being cached, up to N (100 in the example) node outputs will be retained. This allows users to work on multiple workflows or experiment with different inputs without losing the benefits of caching (at the cost of more RAM and VRAM). I intentionally left some additional debug print statements in for this strategy for the moment. 4. A new endpoint `/debugcache` has been temporarily added to assist with tracking down issues people encounter. It allows you to browse the contents of the cache. 5. Outputs from ephemeral nodes will now be communicated to the front-end with both the ephemeral node id, the 'parent' node id, and the 'display' node id. The front-end has been updated to deal with this.
This commit is contained in:
parent
f15bd84351
commit
7d4530f6f5
316
comfy/caching.py
Normal file
316
comfy/caching.py
Normal file
@ -0,0 +1,316 @@
|
||||
import itertools
|
||||
from typing import Sequence, Mapping
|
||||
|
||||
import nodes
|
||||
|
||||
from comfy.graph_utils import is_link
|
||||
|
||||
class CacheKeySet:
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.keys = {}
|
||||
self.subcache_keys = {}
|
||||
|
||||
def add_keys(node_ids):
|
||||
raise NotImplementedError()
|
||||
|
||||
def all_node_ids(self):
|
||||
return set(self.keys.keys())
|
||||
|
||||
def get_used_keys(self):
|
||||
return self.keys.values()
|
||||
|
||||
def get_used_subcache_keys(self):
|
||||
return self.subcache_keys.values()
|
||||
|
||||
def get_data_key(self, node_id):
|
||||
return self.keys.get(node_id, None)
|
||||
|
||||
def get_subcache_key(self, node_id):
|
||||
return self.subcache_keys.get(node_id, None)
|
||||
|
||||
class Unhashable:
|
||||
def __init__(self):
|
||||
self.value = float("NaN")
|
||||
|
||||
def to_hashable(obj):
|
||||
# So that we don't infinitely recurse since frozenset and tuples
|
||||
# are Sequences.
|
||||
if isinstance(obj, (int, float, str, bool, type(None))):
|
||||
return obj
|
||||
elif isinstance(obj, Mapping):
|
||||
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
||||
elif isinstance(obj, Sequence):
|
||||
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
|
||||
else:
|
||||
# TODO - Support other objects like tensors?
|
||||
return Unhashable()
|
||||
|
||||
class CacheKeySetID(CacheKeySet):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.add_keys(node_ids)
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = (node_id, node["class_type"])
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
class CacheKeySetInputSignature(CacheKeySet):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.is_changed_cache = is_changed_cache
|
||||
self.add_keys(node_ids)
|
||||
|
||||
def include_node_id_in_input(self):
|
||||
return False
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
def get_node_signature(self, dynprompt, node_id):
|
||||
signature = []
|
||||
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
||||
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
for ancestor_id in ancestors:
|
||||
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
return to_hashable(signature)
|
||||
|
||||
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
node = dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
signature = [class_type, self.is_changed_cache.get(node_id)]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
|
||||
signature.append(node_id)
|
||||
inputs = node["inputs"]
|
||||
for key in sorted(inputs.keys()):
|
||||
if is_link(inputs[key]):
|
||||
(ancestor_id, ancestor_socket) = inputs[key]
|
||||
ancestor_index = ancestor_order_mapping[ancestor_id]
|
||||
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
|
||||
else:
|
||||
signature.append((key, inputs[key]))
|
||||
return signature
|
||||
|
||||
# This function returns a list of all ancestors of the given node. The order of the list is
|
||||
# deterministic based on which specific inputs the ancestor is connected by.
|
||||
def get_ordered_ancestry(self, dynprompt, node_id):
|
||||
ancestors = []
|
||||
order_mapping = {}
|
||||
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
|
||||
return ancestors, order_mapping
|
||||
|
||||
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||
input_keys = sorted(inputs.keys())
|
||||
for key in input_keys:
|
||||
if is_link(inputs[key]):
|
||||
ancestor_id = inputs[key][0]
|
||||
if ancestor_id not in order_mapping:
|
||||
ancestors.append(ancestor_id)
|
||||
order_mapping[ancestor_id] = len(ancestors) - 1
|
||||
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
|
||||
|
||||
class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
|
||||
def include_node_id_in_input(self):
|
||||
return True
|
||||
|
||||
class BasicCache:
|
||||
def __init__(self, key_class):
|
||||
self.key_class = key_class
|
||||
self.dynprompt = None
|
||||
self.cache_key_set = None
|
||||
self.cache = {}
|
||||
self.subcaches = {}
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.dynprompt = dynprompt
|
||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||
self.is_changed_cache = is_changed_cache
|
||||
|
||||
def all_node_ids(self):
|
||||
assert self.cache_key_set is not None
|
||||
node_ids = self.cache_key_set.all_node_ids()
|
||||
for subcache in self.subcaches.values():
|
||||
node_ids = node_ids.union(subcache.all_node_ids())
|
||||
return node_ids
|
||||
|
||||
def clean_unused(self):
|
||||
assert self.cache_key_set is not None
|
||||
preserve_keys = set(self.cache_key_set.get_used_keys())
|
||||
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
|
||||
to_remove = []
|
||||
for key in self.cache:
|
||||
if key not in preserve_keys:
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
del self.cache[key]
|
||||
|
||||
to_remove = []
|
||||
for key in self.subcaches:
|
||||
if key not in preserve_subcaches:
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
del self.subcaches[key]
|
||||
|
||||
def _set_immediate(self, node_id, value):
|
||||
assert self.cache_key_set is not None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
|
||||
def _get_immediate(self, node_id):
|
||||
assert self.cache_key_set is not None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _ensure_subcache(self, node_id, children_ids):
|
||||
assert self.cache_key_set is not None
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
subcache = self.subcaches.get(subcache_key, None)
|
||||
if subcache is None:
|
||||
subcache = BasicCache(self.key_class)
|
||||
self.subcaches[subcache_key] = subcache
|
||||
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||
return subcache
|
||||
|
||||
def _get_subcache(self, node_id):
|
||||
assert self.cache_key_set is not None
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
if subcache_key in self.subcaches:
|
||||
return self.subcaches[subcache_key]
|
||||
else:
|
||||
return None
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
result = []
|
||||
for key in self.cache:
|
||||
result.append({"key": key, "value": self.cache[key]})
|
||||
for key in self.subcaches:
|
||||
result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
|
||||
return result
|
||||
|
||||
class HierarchicalCache(BasicCache):
|
||||
def __init__(self, key_class):
|
||||
super().__init__(key_class)
|
||||
|
||||
def _get_cache_for(self, node_id):
|
||||
parent_id = self.dynprompt.get_parent_node_id(node_id)
|
||||
if parent_id is None:
|
||||
return self
|
||||
|
||||
hierarchy = []
|
||||
while parent_id is not None:
|
||||
hierarchy.append(parent_id)
|
||||
parent_id = self.dynprompt.get_parent_node_id(parent_id)
|
||||
|
||||
cache = self
|
||||
for parent_id in reversed(hierarchy):
|
||||
cache = cache._get_subcache(parent_id)
|
||||
if cache is None:
|
||||
return None
|
||||
return cache
|
||||
|
||||
def get(self, node_id):
|
||||
cache = self._get_cache_for(node_id)
|
||||
if cache is None:
|
||||
return None
|
||||
return cache._get_immediate(node_id)
|
||||
|
||||
def set(self, node_id, value):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
cache._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
return cache._ensure_subcache(node_id, children_ids)
|
||||
|
||||
def all_active_values(self):
|
||||
active_nodes = self.all_node_ids()
|
||||
result = []
|
||||
for node_id in active_nodes:
|
||||
value = self.get(node_id)
|
||||
if value is not None:
|
||||
result.append(value)
|
||||
return result
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100):
|
||||
super().__init__(key_class)
|
||||
self.max_size = max_size
|
||||
self.min_generation = 0
|
||||
self.generation = 0
|
||||
self.used_generation = {}
|
||||
self.children = {}
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
self.generation += 1
|
||||
for node_id in node_ids:
|
||||
self._mark_used(node_id)
|
||||
print("LRUCache: Now at generation %d" % self.generation)
|
||||
|
||||
def clean_unused(self):
|
||||
print("LRUCache: Cleaning unused. Current size: %d/%d" % (len(self.cache), self.max_size))
|
||||
while len(self.cache) > self.max_size and self.min_generation < self.generation:
|
||||
print("LRUCache: Evicting generation %d" % self.min_generation)
|
||||
self.min_generation += 1
|
||||
to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
|
||||
for key in to_remove:
|
||||
del self.cache[key]
|
||||
del self.used_generation[key]
|
||||
if key in self.children:
|
||||
del self.children[key]
|
||||
|
||||
def get(self, node_id):
|
||||
self._mark_used(node_id)
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
def _mark_used(self, node_id):
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key is not None:
|
||||
self.used_generation[cache_key] = self.generation
|
||||
|
||||
def set(self, node_id, value):
|
||||
self._mark_used(node_id)
|
||||
return self._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
self.cache_key_set.add_keys(children_ids)
|
||||
self._mark_used(node_id)
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.children[cache_key] = []
|
||||
for child_id in children_ids:
|
||||
self._mark_used(child_id)
|
||||
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
||||
return self
|
||||
|
||||
def all_active_values(self):
|
||||
explored = set()
|
||||
to_explore = set(self.cache_key_set.get_used_keys())
|
||||
while len(to_explore) > 0:
|
||||
cache_key = to_explore.pop()
|
||||
if cache_key not in explored:
|
||||
self.used_generation[cache_key] = self.generation
|
||||
explored.add(cache_key)
|
||||
if cache_key in self.children:
|
||||
to_explore.update(self.children[cache_key])
|
||||
return [self.cache[key] for key in explored if key in self.cache]
|
||||
|
||||
@ -69,6 +69,10 @@ class LatentPreviewMethod(enum.Enum):
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||
|
||||
166
comfy/graph.py
Normal file
166
comfy/graph.py
Normal file
@ -0,0 +1,166 @@
|
||||
import nodes
|
||||
|
||||
from comfy.graph_utils import is_link
|
||||
|
||||
class DynamicPrompt:
|
||||
def __init__(self, original_prompt):
|
||||
# The original prompt provided by the user
|
||||
self.original_prompt = original_prompt
|
||||
# Any extra pieces of the graph created during execution
|
||||
self.ephemeral_prompt = {}
|
||||
self.ephemeral_parents = {}
|
||||
self.ephemeral_display = {}
|
||||
|
||||
def get_node(self, node_id):
|
||||
if node_id in self.ephemeral_prompt:
|
||||
return self.ephemeral_prompt[node_id]
|
||||
if node_id in self.original_prompt:
|
||||
return self.original_prompt[node_id]
|
||||
return None
|
||||
|
||||
def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
|
||||
self.ephemeral_prompt[node_id] = node_info
|
||||
self.ephemeral_parents[node_id] = parent_id
|
||||
self.ephemeral_display[node_id] = display_id
|
||||
|
||||
def get_real_node_id(self, node_id):
|
||||
while node_id in self.ephemeral_parents:
|
||||
node_id = self.ephemeral_parents[node_id]
|
||||
return node_id
|
||||
|
||||
def get_parent_node_id(self, node_id):
|
||||
return self.ephemeral_parents.get(node_id, None)
|
||||
|
||||
def get_display_node_id(self, node_id):
|
||||
while node_id in self.ephemeral_display:
|
||||
node_id = self.ephemeral_display[node_id]
|
||||
return node_id
|
||||
|
||||
def all_node_ids(self):
|
||||
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
|
||||
|
||||
def get_input_info(class_def, input_name):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_info = None
|
||||
input_category = None
|
||||
if "required" in valid_inputs and input_name in valid_inputs["required"]:
|
||||
input_category = "required"
|
||||
input_info = valid_inputs["required"][input_name]
|
||||
elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
|
||||
input_category = "optional"
|
||||
input_info = valid_inputs["optional"][input_name]
|
||||
elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
|
||||
input_category = "hidden"
|
||||
input_info = valid_inputs["hidden"][input_name]
|
||||
if input_info is None:
|
||||
return None, None, None
|
||||
input_type = input_info[0]
|
||||
if len(input_info) > 1:
|
||||
extra_info = input_info[1]
|
||||
else:
|
||||
extra_info = {}
|
||||
return input_type, input_category, extra_info
|
||||
|
||||
class TopologicalSort:
|
||||
def __init__(self, dynprompt):
|
||||
self.dynprompt = dynprompt
|
||||
self.pendingNodes = {}
|
||||
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||
self.blocking = {} # Which nodes are blocked by this node
|
||||
|
||||
def get_input_info(self, unique_id, input_name):
|
||||
class_type = self.dynprompt.get_node(unique_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return get_input_info(class_def, input_name)
|
||||
|
||||
def make_input_strong_link(self, to_node_id, to_input):
|
||||
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
|
||||
if to_input not in inputs:
|
||||
raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input))
|
||||
value = inputs[to_input]
|
||||
if not is_link(value):
|
||||
raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input))
|
||||
from_node_id, from_socket = value
|
||||
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
self.add_node(from_node_id)
|
||||
if to_node_id not in self.blocking[from_node_id]:
|
||||
self.blocking[from_node_id][to_node_id] = {}
|
||||
self.blockCount[to_node_id] += 1
|
||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||
|
||||
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
|
||||
if unique_id in self.pendingNodes:
|
||||
return
|
||||
self.pendingNodes[unique_id] = True
|
||||
self.blockCount[unique_id] = 0
|
||||
self.blocking[unique_id] = {}
|
||||
|
||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||
for input_name in inputs:
|
||||
value = inputs[input_name]
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||
continue
|
||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||
is_lazy = "lazy" in input_info and input_info["lazy"]
|
||||
if include_lazy or not is_lazy:
|
||||
self.add_strong_link(from_node_id, from_socket, unique_id)
|
||||
|
||||
def get_ready_nodes(self):
|
||||
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
||||
|
||||
def pop_node(self, unique_id):
|
||||
del self.pendingNodes[unique_id]
|
||||
for blocked_node_id in self.blocking[unique_id]:
|
||||
self.blockCount[blocked_node_id] -= 1
|
||||
del self.blocking[unique_id]
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.pendingNodes) == 0
|
||||
|
||||
# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
|
||||
# it can still be returned to the graph after having further dependencies added.
|
||||
class ExecutionList(TopologicalSort):
|
||||
def __init__(self, dynprompt, output_cache):
|
||||
super().__init__(dynprompt)
|
||||
self.output_cache = output_cache
|
||||
self.staged_node_id = None
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
if self.output_cache.get(from_node_id) is not None:
|
||||
# Nothing to do
|
||||
return
|
||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
|
||||
def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
if self.is_empty():
|
||||
return None
|
||||
available = self.get_ready_nodes()
|
||||
if len(available) == 0:
|
||||
raise Exception("Dependency cycle detected")
|
||||
next_node = available[0]
|
||||
# If an output node is available, do that first.
|
||||
# Technically this has no effect on the overall length of execution, but it feels better as a user
|
||||
# for a PreviewImage to display a result as soon as it can
|
||||
# Some other heuristics could probably be used here to improve the UX further.
|
||||
for node_id in available:
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||
next_node = node_id
|
||||
break
|
||||
self.staged_node_id = next_node
|
||||
return self.staged_node_id
|
||||
|
||||
def unstage_node_execution(self):
|
||||
assert self.staged_node_id is not None
|
||||
self.staged_node_id = None
|
||||
|
||||
def complete_node_execution(self):
|
||||
node_id = self.staged_node_id
|
||||
self.pop_node(node_id)
|
||||
self.staged_node_id = None
|
||||
397
execution.py
397
execution.py
@ -6,7 +6,6 @@ import threading
|
||||
import heapq
|
||||
import traceback
|
||||
import gc
|
||||
import time
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
@ -14,172 +13,71 @@ import nodes
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.graph_utils
|
||||
from comfy.graph import get_input_info, ExecutionList, DynamicPrompt
|
||||
from comfy.graph_utils import is_link, ExecutionBlocker, GraphBuilder
|
||||
from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetInputSignatureWithID, CacheKeySetID
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
SUCCESS = 0
|
||||
FAILURE = 1
|
||||
SLEEPING = 2
|
||||
|
||||
def get_input_info(class_def, input_name):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_info = None
|
||||
input_category = None
|
||||
if "required" in valid_inputs and input_name in valid_inputs["required"]:
|
||||
input_category = "required"
|
||||
input_info = valid_inputs["required"][input_name]
|
||||
elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
|
||||
input_category = "optional"
|
||||
input_info = valid_inputs["optional"][input_name]
|
||||
elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
|
||||
input_category = "hidden"
|
||||
input_info = valid_inputs["hidden"][input_name]
|
||||
if input_info is None:
|
||||
return None, None, None
|
||||
input_type = input_info[0]
|
||||
if len(input_info) > 1:
|
||||
extra_info = input_info[1]
|
||||
else:
|
||||
extra_info = {}
|
||||
return input_type, input_category, extra_info
|
||||
|
||||
# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
|
||||
# it can still be returned to the graph after having further dependencies added.
|
||||
class TopologicalSort:
|
||||
def __init__(self, dynprompt):
|
||||
class IsChangedCache:
|
||||
def __init__(self, dynprompt, outputs_cache):
|
||||
self.dynprompt = dynprompt
|
||||
self.pendingNodes = {}
|
||||
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||
self.blocking = {} # Which nodes are blocked by this node
|
||||
self.outputs_cache = outputs_cache
|
||||
self.is_changed = {}
|
||||
|
||||
def get_input_info(self, unique_id, input_name):
|
||||
class_type = self.dynprompt.get_node(unique_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return get_input_info(class_def, input_name)
|
||||
|
||||
def make_input_strong_link(self, to_node_id, to_input):
|
||||
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
|
||||
if to_input not in inputs:
|
||||
raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input))
|
||||
value = inputs[to_input]
|
||||
if not is_link(value):
|
||||
raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input))
|
||||
from_node_id, from_socket = value
|
||||
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
self.add_node(from_node_id)
|
||||
if to_node_id not in self.blocking[from_node_id]:
|
||||
self.blocking[from_node_id][to_node_id] = {}
|
||||
self.blockCount[to_node_id] += 1
|
||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||
|
||||
def add_node(self, unique_id):
|
||||
if unique_id in self.pendingNodes:
|
||||
return
|
||||
self.pendingNodes[unique_id] = True
|
||||
self.blockCount[unique_id] = 0
|
||||
self.blocking[unique_id] = {}
|
||||
|
||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||
for input_name in inputs:
|
||||
value = inputs[input_name]
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||
if "lazy" not in input_info or not input_info["lazy"]:
|
||||
self.add_strong_link(from_node_id, from_socket, unique_id)
|
||||
|
||||
def get_ready_nodes(self):
|
||||
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
||||
|
||||
def pop_node(self, unique_id):
|
||||
del self.pendingNodes[unique_id]
|
||||
for blocked_node_id in self.blocking[unique_id]:
|
||||
self.blockCount[blocked_node_id] -= 1
|
||||
del self.blocking[unique_id]
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.pendingNodes) == 0
|
||||
|
||||
class ExecutionList(TopologicalSort):
|
||||
def __init__(self, dynprompt, outputs):
|
||||
super().__init__(dynprompt)
|
||||
self.outputs = outputs
|
||||
self.staged_node_id = None
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
if from_node_id in self.outputs:
|
||||
# Nothing to do
|
||||
return
|
||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
|
||||
def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
if self.is_empty():
|
||||
return None
|
||||
available = self.get_ready_nodes()
|
||||
if len(available) == 0:
|
||||
raise Exception("Dependency cycle detected")
|
||||
next_node = available[0]
|
||||
# If an output node is available, do that first.
|
||||
# Technically this has no effect on the overall length of execution, but it feels better as a user
|
||||
# for a PreviewImage to display a result as soon as it can
|
||||
# Some other heuristics could probably be used here to improve the UX further.
|
||||
for node_id in available:
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
def get(self, node_id):
|
||||
if node_id not in self.is_changed:
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||
next_node = node_id
|
||||
break
|
||||
self.staged_node_id = next_node
|
||||
return self.staged_node_id
|
||||
if hasattr(class_def, "IS_CHANGED"):
|
||||
if "is_changed" in node:
|
||||
self.is_changed[node_id] = node["is_changed"]
|
||||
else:
|
||||
input_data_all = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
|
||||
try:
|
||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
self.is_changed[node_id] = node["is_changed"]
|
||||
except:
|
||||
node["is_changed"] = float("NaN")
|
||||
self.is_changed[node_id] = node["is_changed"]
|
||||
else:
|
||||
self.is_changed[node_id] = False
|
||||
return self.is_changed[node_id]
|
||||
|
||||
def unstage_node_execution(self):
|
||||
assert self.staged_node_id is not None
|
||||
self.staged_node_id = None
|
||||
class CacheSet:
|
||||
def __init__(self, lru_size=None):
|
||||
if lru_size is None or lru_size == 0:
|
||||
self.init_classic_cache()
|
||||
else:
|
||||
self.init_lru_cache(lru_size)
|
||||
self.all = [self.outputs, self.ui, self.objects]
|
||||
|
||||
def complete_node_execution(self):
|
||||
node_id = self.staged_node_id
|
||||
self.pop_node(node_id)
|
||||
self.staged_node_id = None
|
||||
# Useful for those with ample RAM/VRAM -- allows experimenting without
|
||||
# blowing away the cache every time
|
||||
def init_lru_cache(self, cache_size):
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.ui = LRUCache(CacheKeySetInputSignatureWithID, max_size=cache_size)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
# Performs like the old cache -- dump data ASAP
|
||||
def init_classic_cache(self):
|
||||
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.ui = HierarchicalCache(CacheKeySetInputSignatureWithID)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
class DynamicPrompt:
|
||||
def __init__(self, original_prompt):
|
||||
# The original prompt provided by the user
|
||||
self.original_prompt = original_prompt
|
||||
# Any extra pieces of the graph created during execution
|
||||
self.ephemeral_prompt = {}
|
||||
self.ephemeral_parents = {}
|
||||
self.ephemeral_display = {}
|
||||
def recursive_debug_dump(self):
|
||||
result = {
|
||||
"outputs": self.outputs.recursive_debug_dump(),
|
||||
"ui": self.ui.recursive_debug_dump(),
|
||||
}
|
||||
return result
|
||||
|
||||
def get_node(self, node_id):
|
||||
if node_id in self.ephemeral_prompt:
|
||||
return self.ephemeral_prompt[node_id]
|
||||
if node_id in self.original_prompt:
|
||||
return self.original_prompt[node_id]
|
||||
return None
|
||||
|
||||
def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
|
||||
self.ephemeral_prompt[node_id] = node_info
|
||||
self.ephemeral_parents[node_id] = parent_id
|
||||
self.ephemeral_display[node_id] = display_id
|
||||
|
||||
def get_real_node_id(self, node_id):
|
||||
while node_id in self.ephemeral_parents:
|
||||
node_id = self.ephemeral_parents[node_id]
|
||||
return node_id
|
||||
|
||||
def get_parent_node_id(self, node_id):
|
||||
return self.ephemeral_parents.get(node_id, None)
|
||||
|
||||
def get_display_node_id(self, node_id):
|
||||
while node_id in self.ephemeral_display:
|
||||
node_id = self.ephemeral_display[node_id]
|
||||
return node_id
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynprompt=None, extra_data={}):
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynprompt=None, extra_data={}):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_data_all = {}
|
||||
for x in inputs:
|
||||
@ -188,9 +86,12 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp
|
||||
if is_link(input_data) and not input_info.get("rawLink", False):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
if outputs is None:
|
||||
continue # This might be a lazily-evaluated input
|
||||
obj = outputs[input_unique_id][output_index]
|
||||
cached_output = outputs.get(input_unique_id)
|
||||
if cached_output is None:
|
||||
continue
|
||||
obj = cached_output[output_index]
|
||||
input_data_all[x] = obj
|
||||
elif input_category is not None:
|
||||
input_data_all[x] = [input_data]
|
||||
@ -331,14 +232,18 @@ def format_value(x):
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage, execution_list, pending_subgraph_results):
|
||||
def non_recursive_execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
|
||||
unique_id = current_item
|
||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||
parent_node_id = dynprompt.get_parent_node_id(unique_id)
|
||||
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if unique_id in outputs:
|
||||
if caches.outputs.get(unique_id) is not None:
|
||||
if server.client_id is not None:
|
||||
cached_output = caches.ui.get(unique_id) or {}
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
input_data_all = None
|
||||
@ -354,7 +259,7 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
|
||||
for r in result:
|
||||
if is_link(r):
|
||||
source_node, source_output = r[0], r[1]
|
||||
node_output = outputs[source_node][source_output]
|
||||
node_output = caches.outputs.get(source_node)[source_output]
|
||||
for o in node_output:
|
||||
resolved_output.append(o)
|
||||
|
||||
@ -365,15 +270,15 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
|
||||
output_ui = []
|
||||
has_subgraph = False
|
||||
else:
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, dynprompt.original_prompt, dynprompt, extra_data)
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt.original_prompt, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
|
||||
obj = object_storage.get((unique_id, class_type), None)
|
||||
obj = caches.objects.get(unique_id)
|
||||
if obj is None:
|
||||
obj = class_def()
|
||||
object_storage[(unique_id, class_type)] = obj
|
||||
caches.objects.set(unique_id, obj)
|
||||
|
||||
if hasattr(obj, "check_lazy_status"):
|
||||
required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
@ -406,11 +311,22 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
|
||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
if len(output_ui) > 0:
|
||||
outputs_ui[unique_id] = output_ui
|
||||
caches.ui.set(unique_id, {
|
||||
"meta": {
|
||||
"node_id": unique_id,
|
||||
"display_node": display_node_id,
|
||||
"parent_node": parent_node_id,
|
||||
"real_node_id": real_node_id,
|
||||
},
|
||||
"output": output_ui
|
||||
})
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
if has_subgraph:
|
||||
cached_outputs = []
|
||||
new_node_ids = []
|
||||
new_output_ids = []
|
||||
new_output_links = []
|
||||
for i in range(len(output_data)):
|
||||
new_graph, node_outputs = output_data[i]
|
||||
if new_graph is None:
|
||||
@ -421,8 +337,8 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
|
||||
if dynprompt.get_node(node_id) is not None:
|
||||
raise Exception("Attempt to add duplicate node %s" % node_id)
|
||||
break
|
||||
new_output_ids = []
|
||||
for node_id, node_info in new_graph.items():
|
||||
new_node_ids.append(node_id)
|
||||
display_id = node_info.get("override_display_id", unique_id)
|
||||
dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
|
||||
# Figure out if the newly created node is an output node
|
||||
@ -430,16 +346,21 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||
new_output_ids.append(node_id)
|
||||
for node_id in new_output_ids:
|
||||
execution_list.add_node(node_id)
|
||||
for i in range(len(node_outputs)):
|
||||
if is_link(node_outputs[i]):
|
||||
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
|
||||
execution_list.add_strong_link(from_node_id, from_socket, unique_id)
|
||||
new_output_links.append((from_node_id, from_socket))
|
||||
cached_outputs.append((True, node_outputs))
|
||||
new_node_ids = set(new_node_ids)
|
||||
for cache in caches.all:
|
||||
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
|
||||
for node_id in new_output_ids:
|
||||
execution_list.add_node(node_id)
|
||||
for link in new_output_links:
|
||||
execution_list.add_strong_link(link[0], link[1], unique_id)
|
||||
pending_subgraph_results[unique_id] = cached_outputs
|
||||
return (ExecutionResult.SLEEPING, None, None)
|
||||
outputs[unique_id] = output_data
|
||||
caches.outputs.set(unique_id, output_data)
|
||||
except comfy.model_management.InterruptProcessingException as iex:
|
||||
print("Processing interrupted")
|
||||
|
||||
@ -459,8 +380,9 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
|
||||
input_data_formatted[name] = [format_value(x) for x in inputs]
|
||||
|
||||
output_data_formatted = {}
|
||||
for node_id, node_outputs in outputs.items():
|
||||
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
|
||||
# TODO - Implement me
|
||||
# for node_id, node_outputs in outputs.items():
|
||||
# output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
|
||||
|
||||
print("!!! Exception during processing !!!")
|
||||
print(traceback.format_exc())
|
||||
@ -479,65 +401,15 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
|
||||
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
|
||||
is_changed_old = ''
|
||||
is_changed = ''
|
||||
to_delete = False
|
||||
if hasattr(class_def, 'IS_CHANGED'):
|
||||
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
|
||||
is_changed_old = old_prompt[unique_id]['is_changed']
|
||||
if 'is_changed' not in prompt[unique_id]:
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
||||
if input_data_all is not None:
|
||||
try:
|
||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
prompt[unique_id]['is_changed'] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
except:
|
||||
to_delete = True
|
||||
else:
|
||||
is_changed = prompt[unique_id]['is_changed']
|
||||
|
||||
if unique_id not in outputs:
|
||||
return True
|
||||
|
||||
if not to_delete:
|
||||
if is_changed != is_changed_old:
|
||||
to_delete = True
|
||||
elif unique_id not in old_prompt:
|
||||
to_delete = True
|
||||
elif inputs == old_prompt[unique_id]['inputs']:
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
|
||||
if is_link(input_data):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id in outputs:
|
||||
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
|
||||
else:
|
||||
to_delete = True
|
||||
if to_delete:
|
||||
break
|
||||
else:
|
||||
to_delete = True
|
||||
|
||||
if to_delete:
|
||||
d = outputs.pop(unique_id)
|
||||
del d
|
||||
return to_delete
|
||||
CACHE_FOR_DEBUG_DUMP = None
|
||||
def dump_cache_for_debug():
|
||||
return CACHE_FOR_DEBUG_DUMP.recursive_debug_dump()
|
||||
|
||||
class PromptExecutor:
|
||||
def __init__(self, server):
|
||||
self.outputs = {}
|
||||
self.object_storage = {}
|
||||
self.outputs_ui = {}
|
||||
self.old_prompt = {}
|
||||
def __init__(self, server, lru_size=None):
|
||||
self.caches = CacheSet(lru_size)
|
||||
global CACHE_FOR_DEBUG_DUMP
|
||||
CACHE_FOR_DEBUG_DUMP = self.caches
|
||||
self.server = server
|
||||
|
||||
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
|
||||
@ -570,18 +442,6 @@ class PromptExecutor:
|
||||
}
|
||||
self.server.send_sync("execution_error", mes, self.server.client_id)
|
||||
|
||||
# Next, remove the subsequent outputs since they will not be executed
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if (o not in current_outputs) and (o not in executed):
|
||||
to_delete += [o]
|
||||
if o in self.old_prompt:
|
||||
d = self.old_prompt.pop(o)
|
||||
del d
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
del d
|
||||
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
nodes.interrupt_processing(False)
|
||||
|
||||
@ -594,59 +454,45 @@ class PromptExecutor:
|
||||
self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
|
||||
|
||||
with torch.inference_mode():
|
||||
#delete cached outputs if nodes don't exist for them
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if o not in prompt:
|
||||
to_delete += [o]
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
del d
|
||||
to_delete = []
|
||||
for o in self.object_storage:
|
||||
if o[0] not in prompt:
|
||||
to_delete += [o]
|
||||
else:
|
||||
p = prompt[o[0]]
|
||||
if o[1] != p['class_type']:
|
||||
to_delete += [o]
|
||||
for o in to_delete:
|
||||
d = self.object_storage.pop(o)
|
||||
del d
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
cache.clean_unused()
|
||||
|
||||
for x in prompt:
|
||||
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
|
||||
|
||||
current_outputs = set(self.outputs.keys())
|
||||
for x in list(self.outputs_ui.keys()):
|
||||
if x not in current_outputs:
|
||||
d = self.outputs_ui.pop(x)
|
||||
del d
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
|
||||
comfy.model_management.cleanup_models()
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
|
||||
|
||||
pending_subgraph_results = {}
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
executed = set()
|
||||
execution_list = ExecutionList(dynamic_prompt, self.outputs)
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
for node_id in list(execute_outputs):
|
||||
execution_list.add_node(node_id)
|
||||
|
||||
while not execution_list.is_empty():
|
||||
node_id = execution_list.stage_node_execution()
|
||||
result, error, ex = non_recursive_execute(self.server, dynamic_prompt, self.outputs, node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage, execution_list, pending_subgraph_results)
|
||||
result, error, ex = non_recursive_execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
elif result == ExecutionResult.SLEEPING:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS
|
||||
else: # result == ExecutionResult.SUCCESS:
|
||||
execution_list.complete_node_execution()
|
||||
|
||||
for x in executed:
|
||||
if x in prompt:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
ui_outputs = {}
|
||||
meta_outputs = {}
|
||||
for ui_info in self.caches.ui.all_active_values():
|
||||
node_id = ui_info["meta"]["node_id"]
|
||||
ui_outputs[node_id] = ui_info["output"]
|
||||
meta_outputs[node_id] = ui_info["meta"]
|
||||
self.history_result = {
|
||||
"outputs": ui_outputs,
|
||||
"meta": meta_outputs,
|
||||
}
|
||||
self.server.last_node_id = None
|
||||
|
||||
|
||||
@ -979,12 +825,11 @@ class PromptQueue:
|
||||
self.server.queue_updated()
|
||||
return (item, i)
|
||||
|
||||
def task_done(self, item_id, outputs):
|
||||
def task_done(self, item_id, history_result):
|
||||
with self.mutex:
|
||||
prompt = self.currently_running.pop(item_id)
|
||||
self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
|
||||
for o in outputs:
|
||||
self.history[prompt[1]]["outputs"][o] = outputs[o]
|
||||
self.history[prompt[1]].update(history_result)
|
||||
self.server.queue_updated()
|
||||
|
||||
def get_current_queue(self):
|
||||
|
||||
4
main.py
4
main.py
@ -84,13 +84,13 @@ def cuda_malloc_warning():
|
||||
print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||
|
||||
def prompt_worker(q, server):
|
||||
e = execution.PromptExecutor(server)
|
||||
e = execution.PromptExecutor(server, lru_size=args.cache_lru)
|
||||
while True:
|
||||
item, item_id = q.get()
|
||||
execution_start_time = time.perf_counter()
|
||||
prompt_id = item[1]
|
||||
e.execute(item[2], prompt_id, item[3], item[4])
|
||||
q.task_done(item_id, e.outputs_ui)
|
||||
q.task_done(item_id, e.history_result)
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
|
||||
|
||||
|
||||
19
server.py
19
server.py
@ -393,6 +393,7 @@ class PromptServer():
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
||||
info['output'] = obj_class.RETURN_TYPES
|
||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||
@ -441,6 +442,22 @@ class PromptServer():
|
||||
queue_info['queue_pending'] = current_queue[1]
|
||||
return web.json_response(queue_info)
|
||||
|
||||
@routes.get("/debugcache")
|
||||
async def get_debugcache(request):
|
||||
def custom_serialize(obj):
|
||||
from comfy.caching import Unhashable
|
||||
if isinstance(obj, frozenset):
|
||||
try:
|
||||
return dict(obj)
|
||||
except:
|
||||
return list(obj)
|
||||
elif isinstance(obj, Unhashable):
|
||||
return "NaN"
|
||||
return str(obj)
|
||||
def custom_dump(obj):
|
||||
return json.dumps(obj, default=custom_serialize)
|
||||
return web.json_response(execution.dump_cache_for_debug(), dumps=custom_dump)
|
||||
|
||||
@routes.post("/prompt")
|
||||
async def post_prompt(request):
|
||||
print("got prompt")
|
||||
@ -610,6 +627,8 @@ class PromptServer():
|
||||
|
||||
if address == '':
|
||||
address = '0.0.0.0'
|
||||
self.address = address
|
||||
self.port = port
|
||||
if verbose:
|
||||
print("Starting server\n")
|
||||
print("To see the GUI go to: http://{}:{}".format(address, port))
|
||||
|
||||
@ -119,7 +119,7 @@ class ComfyApi extends EventTarget {
|
||||
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
|
||||
break;
|
||||
case "executing":
|
||||
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
|
||||
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.display_node }));
|
||||
break;
|
||||
case "executed":
|
||||
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
|
||||
|
||||
@ -983,8 +983,8 @@ export class ComfyApp {
|
||||
});
|
||||
|
||||
api.addEventListener("executed", ({ detail }) => {
|
||||
this.nodeOutputs[detail.node] = detail.output;
|
||||
const node = this.graph.getNodeById(detail.node);
|
||||
this.nodeOutputs[detail.display_node] = detail.output;
|
||||
const node = this.graph.getNodeById(detail.display_node);
|
||||
if (node) {
|
||||
if (node.onExecuted)
|
||||
node.onExecuted(detail.output);
|
||||
|
||||
@ -465,7 +465,14 @@ class ComfyList {
|
||||
onclick: () => {
|
||||
app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
|
||||
if (item.outputs) {
|
||||
app.nodeOutputs = item.outputs;
|
||||
app.nodeOutputs = {};
|
||||
for (const [key, value] of Object.entries(item.outputs)) {
|
||||
if (item.meta && item.meta[key] && item.meta[key].display_node) {
|
||||
app.nodeOutputs[item.meta[key].display_node] = value;
|
||||
} else {
|
||||
app.nodeOutputs[key] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user