mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-13 15:50:49 +08:00
Execution Model Inversion
This PR inverts the execution model -- from recursively calling nodes to
using a topological sort of the nodes. This change allows for
modification of the node graph during execution. This allows for two
major advantages:
1. The implementation of lazy evaluation in nodes. For example, if a
"Mix Images" node has a mix factor of exactly 0.0, the second image
input doesn't even need to be evaluated (and visa-versa if the mix
factor is 1.0).
2. Dynamic expansion of nodes. This allows for the creation of dynamic
"node groups". Specifically, custom nodes can return subgraphs that
replace the original node in the graph. This is an incredibly
powerful concept. Using this functionality, it was easy to
implement:
a. Components (a.k.a. node groups)
b. Flow control (i.e. while loops) via tail recursion
c. All-in-one nodes that replicate the WebUI functionality
d. and more
All of those were able to be implemented entirely via custom nodes,
so those features are *not* a part of this PR. (There are some
front-end changes that should occur before that functionality is
made widely available, particularly around variant sockets.)
The custom nodes associated with this PR can be found at:
https://github.com/BadCafeCode/execution-inversion-demo-comfyui
Note that some of them require that variant socket types ("*") be
enabled.
This commit is contained in:
parent
079dbf9198
commit
36b2214e30
316
comfy/caching.py
Normal file
316
comfy/caching.py
Normal file
@ -0,0 +1,316 @@
|
||||
import itertools
|
||||
from typing import Sequence, Mapping
|
||||
|
||||
import nodes
|
||||
|
||||
from comfy.graph_utils import is_link
|
||||
|
||||
class CacheKeySet:
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.keys = {}
|
||||
self.subcache_keys = {}
|
||||
|
||||
def add_keys(node_ids):
|
||||
raise NotImplementedError()
|
||||
|
||||
def all_node_ids(self):
|
||||
return set(self.keys.keys())
|
||||
|
||||
def get_used_keys(self):
|
||||
return self.keys.values()
|
||||
|
||||
def get_used_subcache_keys(self):
|
||||
return self.subcache_keys.values()
|
||||
|
||||
def get_data_key(self, node_id):
|
||||
return self.keys.get(node_id, None)
|
||||
|
||||
def get_subcache_key(self, node_id):
|
||||
return self.subcache_keys.get(node_id, None)
|
||||
|
||||
class Unhashable:
|
||||
def __init__(self):
|
||||
self.value = float("NaN")
|
||||
|
||||
def to_hashable(obj):
|
||||
# So that we don't infinitely recurse since frozenset and tuples
|
||||
# are Sequences.
|
||||
if isinstance(obj, (int, float, str, bool, type(None))):
|
||||
return obj
|
||||
elif isinstance(obj, Mapping):
|
||||
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
||||
elif isinstance(obj, Sequence):
|
||||
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
|
||||
else:
|
||||
# TODO - Support other objects like tensors?
|
||||
return Unhashable()
|
||||
|
||||
class CacheKeySetID(CacheKeySet):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.add_keys(node_ids)
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = (node_id, node["class_type"])
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
class CacheKeySetInputSignature(CacheKeySet):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
self.dynprompt = dynprompt
|
||||
self.is_changed_cache = is_changed_cache
|
||||
self.add_keys(node_ids)
|
||||
|
||||
def include_node_id_in_input(self):
|
||||
return False
|
||||
|
||||
def add_keys(self, node_ids):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
|
||||
def get_node_signature(self, dynprompt, node_id):
|
||||
signature = []
|
||||
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
||||
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
for ancestor_id in ancestors:
|
||||
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
return to_hashable(signature)
|
||||
|
||||
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
node = dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
signature = [class_type, self.is_changed_cache.get(node_id)]
|
||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
|
||||
signature.append(node_id)
|
||||
inputs = node["inputs"]
|
||||
for key in sorted(inputs.keys()):
|
||||
if is_link(inputs[key]):
|
||||
(ancestor_id, ancestor_socket) = inputs[key]
|
||||
ancestor_index = ancestor_order_mapping[ancestor_id]
|
||||
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
|
||||
else:
|
||||
signature.append((key, inputs[key]))
|
||||
return signature
|
||||
|
||||
# This function returns a list of all ancestors of the given node. The order of the list is
|
||||
# deterministic based on which specific inputs the ancestor is connected by.
|
||||
def get_ordered_ancestry(self, dynprompt, node_id):
|
||||
ancestors = []
|
||||
order_mapping = {}
|
||||
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
|
||||
return ancestors, order_mapping
|
||||
|
||||
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||
input_keys = sorted(inputs.keys())
|
||||
for key in input_keys:
|
||||
if is_link(inputs[key]):
|
||||
ancestor_id = inputs[key][0]
|
||||
if ancestor_id not in order_mapping:
|
||||
ancestors.append(ancestor_id)
|
||||
order_mapping[ancestor_id] = len(ancestors) - 1
|
||||
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
|
||||
|
||||
class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
|
||||
def include_node_id_in_input(self):
|
||||
return True
|
||||
|
||||
class BasicCache:
|
||||
def __init__(self, key_class):
|
||||
self.key_class = key_class
|
||||
self.dynprompt = None
|
||||
self.cache_key_set = None
|
||||
self.cache = {}
|
||||
self.subcaches = {}
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.dynprompt = dynprompt
|
||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||
self.is_changed_cache = is_changed_cache
|
||||
|
||||
def all_node_ids(self):
|
||||
assert self.cache_key_set is not None
|
||||
node_ids = self.cache_key_set.all_node_ids()
|
||||
for subcache in self.subcaches.values():
|
||||
node_ids = node_ids.union(subcache.all_node_ids())
|
||||
return node_ids
|
||||
|
||||
def clean_unused(self):
|
||||
assert self.cache_key_set is not None
|
||||
preserve_keys = set(self.cache_key_set.get_used_keys())
|
||||
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
|
||||
to_remove = []
|
||||
for key in self.cache:
|
||||
if key not in preserve_keys:
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
del self.cache[key]
|
||||
|
||||
to_remove = []
|
||||
for key in self.subcaches:
|
||||
if key not in preserve_subcaches:
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
del self.subcaches[key]
|
||||
|
||||
def _set_immediate(self, node_id, value):
|
||||
assert self.cache_key_set is not None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
|
||||
def _get_immediate(self, node_id):
|
||||
assert self.cache_key_set is not None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _ensure_subcache(self, node_id, children_ids):
|
||||
assert self.cache_key_set is not None
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
subcache = self.subcaches.get(subcache_key, None)
|
||||
if subcache is None:
|
||||
subcache = BasicCache(self.key_class)
|
||||
self.subcaches[subcache_key] = subcache
|
||||
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||
return subcache
|
||||
|
||||
def _get_subcache(self, node_id):
|
||||
assert self.cache_key_set is not None
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
if subcache_key in self.subcaches:
|
||||
return self.subcaches[subcache_key]
|
||||
else:
|
||||
return None
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
result = []
|
||||
for key in self.cache:
|
||||
result.append({"key": key, "value": self.cache[key]})
|
||||
for key in self.subcaches:
|
||||
result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
|
||||
return result
|
||||
|
||||
class HierarchicalCache(BasicCache):
|
||||
def __init__(self, key_class):
|
||||
super().__init__(key_class)
|
||||
|
||||
def _get_cache_for(self, node_id):
|
||||
parent_id = self.dynprompt.get_parent_node_id(node_id)
|
||||
if parent_id is None:
|
||||
return self
|
||||
|
||||
hierarchy = []
|
||||
while parent_id is not None:
|
||||
hierarchy.append(parent_id)
|
||||
parent_id = self.dynprompt.get_parent_node_id(parent_id)
|
||||
|
||||
cache = self
|
||||
for parent_id in reversed(hierarchy):
|
||||
cache = cache._get_subcache(parent_id)
|
||||
if cache is None:
|
||||
return None
|
||||
return cache
|
||||
|
||||
def get(self, node_id):
|
||||
cache = self._get_cache_for(node_id)
|
||||
if cache is None:
|
||||
return None
|
||||
return cache._get_immediate(node_id)
|
||||
|
||||
def set(self, node_id, value):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
cache._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
return cache._ensure_subcache(node_id, children_ids)
|
||||
|
||||
def all_active_values(self):
|
||||
active_nodes = self.all_node_ids()
|
||||
result = []
|
||||
for node_id in active_nodes:
|
||||
value = self.get(node_id)
|
||||
if value is not None:
|
||||
result.append(value)
|
||||
return result
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100):
|
||||
super().__init__(key_class)
|
||||
self.max_size = max_size
|
||||
self.min_generation = 0
|
||||
self.generation = 0
|
||||
self.used_generation = {}
|
||||
self.children = {}
|
||||
|
||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
self.generation += 1
|
||||
for node_id in node_ids:
|
||||
self._mark_used(node_id)
|
||||
print("LRUCache: Now at generation %d" % self.generation)
|
||||
|
||||
def clean_unused(self):
|
||||
print("LRUCache: Cleaning unused. Current size: %d/%d" % (len(self.cache), self.max_size))
|
||||
while len(self.cache) > self.max_size and self.min_generation < self.generation:
|
||||
print("LRUCache: Evicting generation %d" % self.min_generation)
|
||||
self.min_generation += 1
|
||||
to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
|
||||
for key in to_remove:
|
||||
del self.cache[key]
|
||||
del self.used_generation[key]
|
||||
if key in self.children:
|
||||
del self.children[key]
|
||||
|
||||
def get(self, node_id):
|
||||
self._mark_used(node_id)
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
def _mark_used(self, node_id):
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key is not None:
|
||||
self.used_generation[cache_key] = self.generation
|
||||
|
||||
def set(self, node_id, value):
|
||||
self._mark_used(node_id)
|
||||
return self._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
self.cache_key_set.add_keys(children_ids)
|
||||
self._mark_used(node_id)
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.children[cache_key] = []
|
||||
for child_id in children_ids:
|
||||
self._mark_used(child_id)
|
||||
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
||||
return self
|
||||
|
||||
def all_active_values(self):
|
||||
explored = set()
|
||||
to_explore = set(self.cache_key_set.get_used_keys())
|
||||
while len(to_explore) > 0:
|
||||
cache_key = to_explore.pop()
|
||||
if cache_key not in explored:
|
||||
self.used_generation[cache_key] = self.generation
|
||||
explored.add(cache_key)
|
||||
if cache_key in self.children:
|
||||
to_explore.update(self.children[cache_key])
|
||||
return [self.cache[key] for key in explored if key in self.cache]
|
||||
|
||||
@ -87,6 +87,10 @@ class LatentPreviewMethod(enum.Enum):
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
|
||||
|
||||
172
comfy/graph.py
Normal file
172
comfy/graph.py
Normal file
@ -0,0 +1,172 @@
|
||||
import nodes
|
||||
|
||||
from comfy.graph_utils import is_link
|
||||
|
||||
class DynamicPrompt:
|
||||
def __init__(self, original_prompt):
|
||||
# The original prompt provided by the user
|
||||
self.original_prompt = original_prompt
|
||||
# Any extra pieces of the graph created during execution
|
||||
self.ephemeral_prompt = {}
|
||||
self.ephemeral_parents = {}
|
||||
self.ephemeral_display = {}
|
||||
|
||||
def get_node(self, node_id):
|
||||
if node_id in self.ephemeral_prompt:
|
||||
return self.ephemeral_prompt[node_id]
|
||||
if node_id in self.original_prompt:
|
||||
return self.original_prompt[node_id]
|
||||
return None
|
||||
|
||||
def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
|
||||
self.ephemeral_prompt[node_id] = node_info
|
||||
self.ephemeral_parents[node_id] = parent_id
|
||||
self.ephemeral_display[node_id] = display_id
|
||||
|
||||
def get_real_node_id(self, node_id):
|
||||
while node_id in self.ephemeral_parents:
|
||||
node_id = self.ephemeral_parents[node_id]
|
||||
return node_id
|
||||
|
||||
def get_parent_node_id(self, node_id):
|
||||
return self.ephemeral_parents.get(node_id, None)
|
||||
|
||||
def get_display_node_id(self, node_id):
|
||||
while node_id in self.ephemeral_display:
|
||||
node_id = self.ephemeral_display[node_id]
|
||||
return node_id
|
||||
|
||||
def all_node_ids(self):
|
||||
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
|
||||
|
||||
def get_input_info(class_def, input_name):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_info = None
|
||||
input_category = None
|
||||
if "required" in valid_inputs and input_name in valid_inputs["required"]:
|
||||
input_category = "required"
|
||||
input_info = valid_inputs["required"][input_name]
|
||||
elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
|
||||
input_category = "optional"
|
||||
input_info = valid_inputs["optional"][input_name]
|
||||
elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
|
||||
input_category = "hidden"
|
||||
input_info = valid_inputs["hidden"][input_name]
|
||||
if input_info is None:
|
||||
return None, None, None
|
||||
input_type = input_info[0]
|
||||
if len(input_info) > 1:
|
||||
extra_info = input_info[1]
|
||||
else:
|
||||
extra_info = {}
|
||||
return input_type, input_category, extra_info
|
||||
|
||||
class TopologicalSort:
|
||||
def __init__(self, dynprompt):
|
||||
self.dynprompt = dynprompt
|
||||
self.pendingNodes = {}
|
||||
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||
self.blocking = {} # Which nodes are blocked by this node
|
||||
|
||||
def get_input_info(self, unique_id, input_name):
|
||||
class_type = self.dynprompt.get_node(unique_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return get_input_info(class_def, input_name)
|
||||
|
||||
def make_input_strong_link(self, to_node_id, to_input):
|
||||
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
|
||||
if to_input not in inputs:
|
||||
raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input))
|
||||
value = inputs[to_input]
|
||||
if not is_link(value):
|
||||
raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input))
|
||||
from_node_id, from_socket = value
|
||||
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
self.add_node(from_node_id)
|
||||
if to_node_id not in self.blocking[from_node_id]:
|
||||
self.blocking[from_node_id][to_node_id] = {}
|
||||
self.blockCount[to_node_id] += 1
|
||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||
|
||||
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
|
||||
if unique_id in self.pendingNodes:
|
||||
return
|
||||
self.pendingNodes[unique_id] = True
|
||||
self.blockCount[unique_id] = 0
|
||||
self.blocking[unique_id] = {}
|
||||
|
||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||
for input_name in inputs:
|
||||
value = inputs[input_name]
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||
continue
|
||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||
is_lazy = "lazy" in input_info and input_info["lazy"]
|
||||
if include_lazy or not is_lazy:
|
||||
self.add_strong_link(from_node_id, from_socket, unique_id)
|
||||
|
||||
def get_ready_nodes(self):
|
||||
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
||||
|
||||
def pop_node(self, unique_id):
|
||||
del self.pendingNodes[unique_id]
|
||||
for blocked_node_id in self.blocking[unique_id]:
|
||||
self.blockCount[blocked_node_id] -= 1
|
||||
del self.blocking[unique_id]
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.pendingNodes) == 0
|
||||
|
||||
# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
|
||||
# it can still be returned to the graph after having further dependencies added.
|
||||
class ExecutionList(TopologicalSort):
|
||||
def __init__(self, dynprompt, output_cache):
|
||||
super().__init__(dynprompt)
|
||||
self.output_cache = output_cache
|
||||
self.staged_node_id = None
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
if self.output_cache.get(from_node_id) is not None:
|
||||
# Nothing to do
|
||||
return
|
||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
|
||||
def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
if self.is_empty():
|
||||
return None
|
||||
available = self.get_ready_nodes()
|
||||
if len(available) == 0:
|
||||
raise Exception("Dependency cycle detected")
|
||||
next_node = available[0]
|
||||
# If an output node is available, do that first.
|
||||
# Technically this has no effect on the overall length of execution, but it feels better as a user
|
||||
# for a PreviewImage to display a result as soon as it can
|
||||
# Some other heuristics could probably be used here to improve the UX further.
|
||||
for node_id in available:
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||
next_node = node_id
|
||||
break
|
||||
self.staged_node_id = next_node
|
||||
return self.staged_node_id
|
||||
|
||||
def unstage_node_execution(self):
|
||||
assert self.staged_node_id is not None
|
||||
self.staged_node_id = None
|
||||
|
||||
def complete_node_execution(self):
|
||||
node_id = self.staged_node_id
|
||||
self.pop_node(node_id)
|
||||
self.staged_node_id = None
|
||||
|
||||
# Return this from a node and any users will be blocked with the given error message.
|
||||
class ExecutionBlocker:
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
||||
140
comfy/graph_utils.py
Normal file
140
comfy/graph_utils.py
Normal file
@ -0,0 +1,140 @@
|
||||
def is_link(obj):
|
||||
if not isinstance(obj, list):
|
||||
return False
|
||||
if len(obj) != 2:
|
||||
return False
|
||||
if not isinstance(obj[0], str):
|
||||
return False
|
||||
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
|
||||
return False
|
||||
return True
|
||||
|
||||
# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
|
||||
class GraphBuilder:
|
||||
_default_prefix_root = ""
|
||||
_default_prefix_call_index = 0
|
||||
_default_prefix_graph_index = 0
|
||||
|
||||
def __init__(self, prefix = None):
|
||||
if prefix is None:
|
||||
self.prefix = GraphBuilder.alloc_prefix()
|
||||
else:
|
||||
self.prefix = prefix
|
||||
self.nodes = {}
|
||||
self.id_gen = 1
|
||||
|
||||
@classmethod
|
||||
def set_default_prefix(cls, prefix_root, call_index, graph_index = 0):
|
||||
cls._default_prefix_root = prefix_root
|
||||
cls._default_prefix_call_index = call_index
|
||||
if graph_index is not None:
|
||||
cls._default_prefix_graph_index = graph_index
|
||||
|
||||
@classmethod
|
||||
def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
|
||||
if root is None:
|
||||
root = GraphBuilder._default_prefix_root
|
||||
if call_index is None:
|
||||
call_index = GraphBuilder._default_prefix_call_index
|
||||
if graph_index is None:
|
||||
graph_index = GraphBuilder._default_prefix_graph_index
|
||||
result = "%s.%d.%d." % (root, call_index, graph_index)
|
||||
GraphBuilder._default_prefix_graph_index += 1
|
||||
return result
|
||||
|
||||
def node(self, class_type, id=None, **kwargs):
|
||||
if id is None:
|
||||
id = str(self.id_gen)
|
||||
self.id_gen += 1
|
||||
id = self.prefix + id
|
||||
if id in self.nodes:
|
||||
return self.nodes[id]
|
||||
|
||||
node = Node(id, class_type, kwargs)
|
||||
self.nodes[id] = node
|
||||
return node
|
||||
|
||||
def lookup_node(self, id):
|
||||
id = self.prefix + id
|
||||
return self.nodes.get(id)
|
||||
|
||||
def finalize(self):
|
||||
output = {}
|
||||
for node_id, node in self.nodes.items():
|
||||
output[node_id] = node.serialize()
|
||||
return output
|
||||
|
||||
def replace_node_output(self, node_id, index, new_value):
|
||||
node_id = self.prefix + node_id
|
||||
to_remove = []
|
||||
for node in self.nodes.values():
|
||||
for key, value in node.inputs.items():
|
||||
if is_link(value) and value[0] == node_id and value[1] == index:
|
||||
if new_value is None:
|
||||
to_remove.append((node, key))
|
||||
else:
|
||||
node.inputs[key] = new_value
|
||||
for node, key in to_remove:
|
||||
del node.inputs[key]
|
||||
|
||||
def remove_node(self, id):
|
||||
id = self.prefix + id
|
||||
del self.nodes[id]
|
||||
|
||||
class Node:
|
||||
def __init__(self, id, class_type, inputs):
|
||||
self.id = id
|
||||
self.class_type = class_type
|
||||
self.inputs = inputs
|
||||
self.override_display_id = None
|
||||
|
||||
def out(self, index):
|
||||
return [self.id, index]
|
||||
|
||||
def set_input(self, key, value):
|
||||
if value is None:
|
||||
if key in self.inputs:
|
||||
del self.inputs[key]
|
||||
else:
|
||||
self.inputs[key] = value
|
||||
|
||||
def get_input(self, key):
|
||||
return self.inputs.get(key)
|
||||
|
||||
def set_override_display_id(self, override_display_id):
|
||||
self.override_display_id = override_display_id
|
||||
|
||||
def serialize(self):
|
||||
serialized = {
|
||||
"class_type": self.class_type,
|
||||
"inputs": self.inputs
|
||||
}
|
||||
if self.override_display_id is not None:
|
||||
serialized["override_display_id"] = self.override_display_id
|
||||
return serialized
|
||||
|
||||
def add_graph_prefix(graph, outputs, prefix):
|
||||
# Change the node IDs and any internal links
|
||||
new_graph = {}
|
||||
for node_id, node_info in graph.items():
|
||||
# Make sure the added nodes have unique IDs
|
||||
new_node_id = prefix + node_id
|
||||
new_node = { "class_type": node_info["class_type"], "inputs": {} }
|
||||
for input_name, input_value in node_info.get("inputs", {}).items():
|
||||
if is_link(input_value):
|
||||
new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
|
||||
else:
|
||||
new_node["inputs"][input_name] = input_value
|
||||
new_graph[new_node_id] = new_node
|
||||
|
||||
# Change the node IDs in the outputs
|
||||
new_outputs = []
|
||||
for n in range(len(outputs)):
|
||||
output = outputs[n]
|
||||
if is_link(output):
|
||||
new_outputs.append([prefix + output[0], output[1]])
|
||||
else:
|
||||
new_outputs.append(output)
|
||||
|
||||
return new_graph, tuple(new_outputs)
|
||||
|
||||
598
execution.py
598
execution.py
@ -4,6 +4,7 @@ import logging
|
||||
import threading
|
||||
import heapq
|
||||
import traceback
|
||||
from enum import Enum
|
||||
import inspect
|
||||
from typing import List, Literal, NamedTuple, Optional
|
||||
|
||||
@ -11,29 +12,97 @@ import torch
|
||||
import nodes
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.graph_utils
|
||||
from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||
from comfy.graph_utils import is_link, GraphBuilder
|
||||
from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetInputSignatureWithID, CacheKeySetID
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
||||
class ExecutionResult(Enum):
|
||||
SUCCESS = 0
|
||||
FAILURE = 1
|
||||
SLEEPING = 2
|
||||
|
||||
class IsChangedCache:
|
||||
def __init__(self, dynprompt, outputs_cache):
|
||||
self.dynprompt = dynprompt
|
||||
self.outputs_cache = outputs_cache
|
||||
self.is_changed = {}
|
||||
|
||||
def get(self, node_id):
|
||||
if node_id not in self.is_changed:
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, "IS_CHANGED"):
|
||||
if "is_changed" in node:
|
||||
self.is_changed[node_id] = node["is_changed"]
|
||||
else:
|
||||
input_data_all = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
|
||||
try:
|
||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
self.is_changed[node_id] = node["is_changed"]
|
||||
except:
|
||||
node["is_changed"] = float("NaN")
|
||||
self.is_changed[node_id] = node["is_changed"]
|
||||
else:
|
||||
self.is_changed[node_id] = False
|
||||
return self.is_changed[node_id]
|
||||
|
||||
class CacheSet:
|
||||
def __init__(self, lru_size=None):
|
||||
if lru_size is None or lru_size == 0:
|
||||
self.init_classic_cache()
|
||||
else:
|
||||
self.init_lru_cache(lru_size)
|
||||
self.all = [self.outputs, self.ui, self.objects]
|
||||
|
||||
# Useful for those with ample RAM/VRAM -- allows experimenting without
|
||||
# blowing away the cache every time
|
||||
def init_lru_cache(self, cache_size):
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.ui = LRUCache(CacheKeySetInputSignatureWithID, max_size=cache_size)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
# Performs like the old cache -- dump data ASAP
|
||||
def init_classic_cache(self):
|
||||
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.ui = HierarchicalCache(CacheKeySetInputSignatureWithID)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
result = {
|
||||
"outputs": self.outputs.recursive_debug_dump(),
|
||||
"ui": self.ui.recursive_debug_dump(),
|
||||
}
|
||||
return result
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynprompt=None, extra_data={}):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_data_all = {}
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
if isinstance(input_data, list):
|
||||
input_type, input_category, input_info = get_input_info(class_def, x)
|
||||
if is_link(input_data) and not input_info.get("rawLink", False):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
input_data_all[x] = (None,)
|
||||
if outputs is None:
|
||||
continue # This might be a lazily-evaluated input
|
||||
cached_output = outputs.get(input_unique_id)
|
||||
if cached_output is None:
|
||||
continue
|
||||
obj = outputs[input_unique_id][output_index]
|
||||
obj = cached_output[output_index]
|
||||
input_data_all[x] = obj
|
||||
else:
|
||||
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
||||
input_data_all[x] = [input_data]
|
||||
elif input_category is not None:
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
for x in h:
|
||||
if h[x] == "PROMPT":
|
||||
input_data_all[x] = [prompt]
|
||||
if h[x] == "DYNPROMPT":
|
||||
input_data_all[x] = [dynprompt]
|
||||
if h[x] == "EXTRA_PNGINFO":
|
||||
if "extra_pnginfo" in extra_data:
|
||||
input_data_all[x] = [extra_data['extra_pnginfo']]
|
||||
@ -41,7 +110,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
input_data_all[x] = [unique_id]
|
||||
return input_data_all
|
||||
|
||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# check if node wants the lists
|
||||
input_is_list = False
|
||||
if hasattr(obj, "INPUT_IS_LIST"):
|
||||
@ -63,51 +132,97 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
if input_is_list:
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**input_data_all))
|
||||
execution_block = None
|
||||
for k, v in input_data_all.items():
|
||||
for input in v:
|
||||
if isinstance(v, ExecutionBlocker):
|
||||
execution_block = execution_block_cb(v) if execution_block_cb is not None else v
|
||||
break
|
||||
|
||||
if execution_block is None:
|
||||
if pre_execute_cb is not None:
|
||||
pre_execute_cb(0)
|
||||
results.append(getattr(obj, func)(**input_data_all))
|
||||
else:
|
||||
results.append(execution_block)
|
||||
elif max_len_input == 0:
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)())
|
||||
else:
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||
input_dict = slice_dict(input_data_all, i)
|
||||
execution_block = None
|
||||
for k, v in input_dict.items():
|
||||
if isinstance(v, ExecutionBlocker):
|
||||
execution_block = execution_block_cb(v) if execution_block_cb is not None else v
|
||||
break
|
||||
if execution_block is None:
|
||||
if pre_execute_cb is not None:
|
||||
pre_execute_cb(i)
|
||||
results.append(getattr(obj, func)(**input_dict))
|
||||
else:
|
||||
results.append(execution_block)
|
||||
return results
|
||||
|
||||
def get_output_data(obj, input_data_all):
|
||||
def merge_result_data(results, obj):
|
||||
# check which outputs need concatenating
|
||||
output = []
|
||||
output_is_list = [False] * len(results[0])
|
||||
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||
output_is_list = obj.OUTPUT_IS_LIST
|
||||
|
||||
# merge node execution results
|
||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||
if is_list:
|
||||
output.append([x for o in results for x in o[i]])
|
||||
else:
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
|
||||
|
||||
results = []
|
||||
uis = []
|
||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
||||
|
||||
for r in return_values:
|
||||
subgraph_results = []
|
||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
has_subgraph = False
|
||||
for i in range(len(return_values)):
|
||||
r = return_values[i]
|
||||
if isinstance(r, dict):
|
||||
if 'ui' in r:
|
||||
uis.append(r['ui'])
|
||||
if 'result' in r:
|
||||
results.append(r['result'])
|
||||
if 'expand' in r:
|
||||
# Perform an expansion, but do not append results
|
||||
has_subgraph = True
|
||||
new_graph = r['expand']
|
||||
result = r.get("result", None)
|
||||
if isinstance(result, ExecutionBlocker):
|
||||
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||
subgraph_results.append((new_graph, result))
|
||||
elif 'result' in r:
|
||||
result = r.get("result", None)
|
||||
if isinstance(result, ExecutionBlocker):
|
||||
result = tuple([result] * len(obj.RETURN_TYPES))
|
||||
results.append(result)
|
||||
subgraph_results.append((None, result))
|
||||
else:
|
||||
if isinstance(r, ExecutionBlocker):
|
||||
r = tuple([r] * len(obj.RETURN_TYPES))
|
||||
results.append(r)
|
||||
|
||||
output = []
|
||||
if len(results) > 0:
|
||||
# check which outputs need concatenating
|
||||
output_is_list = [False] * len(results[0])
|
||||
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||
output_is_list = obj.OUTPUT_IS_LIST
|
||||
|
||||
# merge node execution results
|
||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||
if is_list:
|
||||
output.append([x for o in results for x in o[i]])
|
||||
else:
|
||||
output.append([o[i] for o in results])
|
||||
|
||||
if has_subgraph:
|
||||
output = subgraph_results
|
||||
elif len(results) > 0:
|
||||
output = merge_result_data(results, obj)
|
||||
else:
|
||||
output = []
|
||||
ui = dict()
|
||||
if len(uis) > 0:
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui
|
||||
return output, ui, has_subgraph
|
||||
|
||||
def format_value(x):
|
||||
if x is None:
|
||||
@ -117,53 +232,144 @@ def format_value(x):
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
|
||||
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||
parent_node_id = dynprompt.get_parent_node_id(unique_id)
|
||||
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if unique_id in outputs:
|
||||
return (True, None, None)
|
||||
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
|
||||
if isinstance(input_data, list):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
|
||||
if result[0] is not True:
|
||||
# Another node failed further upstream
|
||||
return result
|
||||
if caches.outputs.get(unique_id) is not None:
|
||||
if server.client_id is not None:
|
||||
cached_output = caches.ui.get(unique_id) or {}
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
input_data_all = None
|
||||
try:
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = unique_id
|
||||
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
||||
if unique_id in pending_subgraph_results:
|
||||
cached_results = pending_subgraph_results[unique_id]
|
||||
resolved_outputs = []
|
||||
for is_subgraph, result in cached_results:
|
||||
if not is_subgraph:
|
||||
resolved_outputs.append(result)
|
||||
else:
|
||||
resolved_output = []
|
||||
for r in result:
|
||||
if is_link(r):
|
||||
source_node, source_output = r[0], r[1]
|
||||
node_output = caches.outputs.get(source_node)[source_output]
|
||||
for o in node_output:
|
||||
resolved_output.append(o)
|
||||
|
||||
obj = object_storage.get((unique_id, class_type), None)
|
||||
if obj is None:
|
||||
obj = class_def()
|
||||
object_storage[(unique_id, class_type)] = obj
|
||||
|
||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||
outputs[unique_id] = output_data
|
||||
if len(output_ui) > 0:
|
||||
outputs_ui[unique_id] = output_ui
|
||||
else:
|
||||
resolved_output.append(r)
|
||||
resolved_outputs.append(tuple(resolved_output))
|
||||
output_data = merge_result_data(resolved_outputs, class_def)
|
||||
output_ui = []
|
||||
has_subgraph = False
|
||||
else:
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt.original_prompt, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
|
||||
obj = caches.objects.get(unique_id)
|
||||
if obj is None:
|
||||
obj = class_def()
|
||||
caches.objects.set(unique_id, obj)
|
||||
|
||||
if hasattr(obj, "check_lazy_status"):
|
||||
required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and x not in input_data_all]
|
||||
if len(required_inputs) > 0:
|
||||
for i in required_inputs:
|
||||
execution_list.make_input_strong_link(unique_id, i)
|
||||
return (ExecutionResult.SLEEPING, None, None)
|
||||
|
||||
def execution_block_cb(block):
|
||||
if block.message is not None:
|
||||
mes = {
|
||||
"prompt_id": prompt_id,
|
||||
"node_id": unique_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
|
||||
"exception_message": "Execution Blocked: %s" % block.message,
|
||||
"exception_type": "ExecutionBlocked",
|
||||
"traceback": [],
|
||||
"current_inputs": [],
|
||||
"current_outputs": [],
|
||||
}
|
||||
server.send_sync("execution_error", mes, server.client_id)
|
||||
return ExecutionBlocker(None)
|
||||
else:
|
||||
return block
|
||||
def pre_execute_cb(call_index):
|
||||
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
|
||||
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
|
||||
if len(output_ui) > 0:
|
||||
caches.ui.set(unique_id, {
|
||||
"meta": {
|
||||
"node_id": unique_id,
|
||||
"display_node": display_node_id,
|
||||
"parent_node": parent_node_id,
|
||||
"real_node_id": real_node_id,
|
||||
},
|
||||
"output": output_ui
|
||||
})
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
if has_subgraph:
|
||||
cached_outputs = []
|
||||
new_node_ids = []
|
||||
new_output_ids = []
|
||||
new_output_links = []
|
||||
for i in range(len(output_data)):
|
||||
new_graph, node_outputs = output_data[i]
|
||||
if new_graph is None:
|
||||
cached_outputs.append((False, node_outputs))
|
||||
else:
|
||||
# Check for conflicts
|
||||
for node_id in new_graph.keys():
|
||||
if dynprompt.get_node(node_id) is not None:
|
||||
raise Exception("Attempt to add duplicate node %s" % node_id)
|
||||
break
|
||||
for node_id, node_info in new_graph.items():
|
||||
new_node_ids.append(node_id)
|
||||
display_id = node_info.get("override_display_id", unique_id)
|
||||
dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
|
||||
# Figure out if the newly created node is an output node
|
||||
class_type = node_info["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||
new_output_ids.append(node_id)
|
||||
for i in range(len(node_outputs)):
|
||||
if is_link(node_outputs[i]):
|
||||
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
|
||||
new_output_links.append((from_node_id, from_socket))
|
||||
cached_outputs.append((True, node_outputs))
|
||||
new_node_ids = set(new_node_ids)
|
||||
for cache in caches.all:
|
||||
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
|
||||
for node_id in new_output_ids:
|
||||
execution_list.add_node(node_id)
|
||||
for link in new_output_links:
|
||||
execution_list.add_strong_link(link[0], link[1], unique_id)
|
||||
pending_subgraph_results[unique_id] = cached_outputs
|
||||
return (ExecutionResult.SLEEPING, None, None)
|
||||
caches.outputs.set(unique_id, output_data)
|
||||
except comfy.model_management.InterruptProcessingException as iex:
|
||||
logging.info("Processing interrupted")
|
||||
|
||||
# skip formatting inputs/outputs
|
||||
error_details = {
|
||||
"node_id": unique_id,
|
||||
"node_id": real_node_id,
|
||||
}
|
||||
|
||||
return (False, error_details, iex)
|
||||
return (ExecutionResult.FAILURE, error_details, iex)
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
exception_type = full_type_name(typ)
|
||||
@ -173,109 +379,32 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
for name, inputs in input_data_all.items():
|
||||
input_data_formatted[name] = [format_value(x) for x in inputs]
|
||||
|
||||
output_data_formatted = {}
|
||||
for node_id, node_outputs in outputs.items():
|
||||
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
|
||||
|
||||
logging.error("!!! Exception during processing !!!")
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
error_details = {
|
||||
"node_id": unique_id,
|
||||
"node_id": real_node_id,
|
||||
"exception_message": str(ex),
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb),
|
||||
"current_inputs": input_data_formatted,
|
||||
"current_outputs": output_data_formatted
|
||||
"current_inputs": input_data_formatted
|
||||
}
|
||||
return (False, error_details, ex)
|
||||
return (ExecutionResult.FAILURE, error_details, ex)
|
||||
|
||||
executed.add(unique_id)
|
||||
|
||||
return (True, None, None)
|
||||
|
||||
def recursive_will_execute(prompt, outputs, current_item):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
will_execute = []
|
||||
if unique_id in outputs:
|
||||
return []
|
||||
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
if isinstance(input_data, list):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
will_execute += recursive_will_execute(prompt, outputs, input_unique_id)
|
||||
|
||||
return will_execute + [unique_id]
|
||||
|
||||
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
|
||||
is_changed_old = ''
|
||||
is_changed = ''
|
||||
to_delete = False
|
||||
if hasattr(class_def, 'IS_CHANGED'):
|
||||
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
|
||||
is_changed_old = old_prompt[unique_id]['is_changed']
|
||||
if 'is_changed' not in prompt[unique_id]:
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
||||
if input_data_all is not None:
|
||||
try:
|
||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
prompt[unique_id]['is_changed'] = is_changed
|
||||
except:
|
||||
to_delete = True
|
||||
else:
|
||||
is_changed = prompt[unique_id]['is_changed']
|
||||
|
||||
if unique_id not in outputs:
|
||||
return True
|
||||
|
||||
if not to_delete:
|
||||
if is_changed != is_changed_old:
|
||||
to_delete = True
|
||||
elif unique_id not in old_prompt:
|
||||
to_delete = True
|
||||
elif inputs == old_prompt[unique_id]['inputs']:
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
|
||||
if isinstance(input_data, list):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id in outputs:
|
||||
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
|
||||
else:
|
||||
to_delete = True
|
||||
if to_delete:
|
||||
break
|
||||
else:
|
||||
to_delete = True
|
||||
|
||||
if to_delete:
|
||||
d = outputs.pop(unique_id)
|
||||
del d
|
||||
return to_delete
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
class PromptExecutor:
|
||||
def __init__(self, server):
|
||||
def __init__(self, server, lru_size=None):
|
||||
self.lru_size = lru_size
|
||||
self.server = server
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.outputs = {}
|
||||
self.object_storage = {}
|
||||
self.outputs_ui = {}
|
||||
self.caches = CacheSet(self.lru_size)
|
||||
self.status_messages = []
|
||||
self.success = True
|
||||
self.old_prompt = {}
|
||||
|
||||
def add_message(self, event, data, broadcast: bool):
|
||||
self.status_messages.append((event, data))
|
||||
@ -302,7 +431,6 @@ class PromptExecutor:
|
||||
"node_id": node_id,
|
||||
"node_type": class_type,
|
||||
"executed": list(executed),
|
||||
|
||||
"exception_message": error["exception_message"],
|
||||
"exception_type": error["exception_type"],
|
||||
"traceback": error["traceback"],
|
||||
@ -311,18 +439,6 @@ class PromptExecutor:
|
||||
}
|
||||
self.add_message("execution_error", mes, broadcast=False)
|
||||
|
||||
# Next, remove the subsequent outputs since they will not be executed
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if (o not in current_outputs) and (o not in executed):
|
||||
to_delete += [o]
|
||||
if o in self.old_prompt:
|
||||
d = self.old_prompt.pop(o)
|
||||
del d
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
del d
|
||||
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
nodes.interrupt_processing(False)
|
||||
|
||||
@ -335,61 +451,45 @@ class PromptExecutor:
|
||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||
|
||||
with torch.inference_mode():
|
||||
#delete cached outputs if nodes don't exist for them
|
||||
to_delete = []
|
||||
for o in self.outputs:
|
||||
if o not in prompt:
|
||||
to_delete += [o]
|
||||
for o in to_delete:
|
||||
d = self.outputs.pop(o)
|
||||
del d
|
||||
to_delete = []
|
||||
for o in self.object_storage:
|
||||
if o[0] not in prompt:
|
||||
to_delete += [o]
|
||||
else:
|
||||
p = prompt[o[0]]
|
||||
if o[1] != p['class_type']:
|
||||
to_delete += [o]
|
||||
for o in to_delete:
|
||||
d = self.object_storage.pop(o)
|
||||
del d
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
cache.clean_unused()
|
||||
|
||||
for x in prompt:
|
||||
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
|
||||
|
||||
current_outputs = set(self.outputs.keys())
|
||||
for x in list(self.outputs_ui.keys()):
|
||||
if x not in current_outputs:
|
||||
d = self.outputs_ui.pop(x)
|
||||
del d
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
|
||||
comfy.model_management.cleanup_models()
|
||||
self.add_message("execution_cached",
|
||||
{ "nodes": list(current_outputs) , "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
executed = set()
|
||||
output_node_id = None
|
||||
to_execute = []
|
||||
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
for node_id in list(execute_outputs):
|
||||
to_execute += [(0, node_id)]
|
||||
execution_list.add_node(node_id)
|
||||
|
||||
while len(to_execute) > 0:
|
||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||
output_node_id = to_execute.pop(0)[-1]
|
||||
|
||||
# This call shouldn't raise anything if there's an error deep in
|
||||
# the actual SD code, instead it will report the node where the
|
||||
# error was raised
|
||||
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
|
||||
if self.success is not True:
|
||||
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
||||
while not execution_list.is_empty():
|
||||
node_id = execution_list.stage_node_execution()
|
||||
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
elif result == ExecutionResult.SLEEPING:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS:
|
||||
execution_list.complete_node_execution()
|
||||
|
||||
for x in executed:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
ui_outputs = {}
|
||||
meta_outputs = {}
|
||||
for ui_info in self.caches.ui.all_active_values():
|
||||
node_id = ui_info["meta"]["node_id"]
|
||||
ui_outputs[node_id] = ui_info["output"]
|
||||
meta_outputs[node_id] = ui_info["meta"]
|
||||
self.history_result = {
|
||||
"outputs": ui_outputs,
|
||||
"meta": meta_outputs,
|
||||
}
|
||||
self.server.last_node_id = None
|
||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||
comfy.model_management.unload_all_models()
|
||||
@ -406,7 +506,7 @@ def validate_inputs(prompt, item, validated):
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
|
||||
class_inputs = obj_class.INPUT_TYPES()
|
||||
required_inputs = class_inputs['required']
|
||||
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
|
||||
|
||||
errors = []
|
||||
valid = True
|
||||
@ -415,22 +515,23 @@ def validate_inputs(prompt, item, validated):
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
|
||||
|
||||
for x in required_inputs:
|
||||
for x in valid_inputs:
|
||||
type_input, input_category, extra_info = get_input_info(obj_class, x)
|
||||
if x not in inputs:
|
||||
error = {
|
||||
"type": "required_input_missing",
|
||||
"message": "Required input is missing",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x
|
||||
if input_category == "required":
|
||||
error = {
|
||||
"type": "required_input_missing",
|
||||
"message": "Required input is missing",
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x
|
||||
}
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
val = inputs[x]
|
||||
info = required_inputs[x]
|
||||
type_input = info[0]
|
||||
info = (type_input, extra_info)
|
||||
if isinstance(val, list):
|
||||
if len(val) != 2:
|
||||
error = {
|
||||
@ -501,6 +602,9 @@ def validate_inputs(prompt, item, validated):
|
||||
if type_input == "STRING":
|
||||
val = str(val)
|
||||
inputs[x] = val
|
||||
if type_input == "BOOLEAN":
|
||||
val = bool(val)
|
||||
inputs[x] = val
|
||||
except Exception as ex:
|
||||
error = {
|
||||
"type": "invalid_input_type",
|
||||
@ -516,33 +620,32 @@ def validate_inputs(prompt, item, validated):
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if len(info) > 1:
|
||||
if "min" in info[1] and val < info[1]["min"]:
|
||||
error = {
|
||||
"type": "value_smaller_than_min",
|
||||
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
if "min" in extra_info and val < extra_info["min"]:
|
||||
error = {
|
||||
"type": "value_smaller_than_min",
|
||||
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
if "max" in info[1] and val > info[1]["max"]:
|
||||
error = {
|
||||
"type": "value_bigger_than_max",
|
||||
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
if "max" in extra_info and val > extra_info["max"]:
|
||||
error = {
|
||||
"type": "value_bigger_than_max",
|
||||
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
|
||||
"details": f"{x}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if x not in validate_function_inputs:
|
||||
if isinstance(type_input, list):
|
||||
@ -582,7 +685,7 @@ def validate_inputs(prompt, item, validated):
|
||||
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
for x in input_filtered:
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True:
|
||||
if r is not True and not isinstance(r, ExecutionBlocker):
|
||||
details = f"{x}"
|
||||
if r is not False:
|
||||
details += f" - {str(r)}"
|
||||
@ -741,7 +844,7 @@ class PromptQueue:
|
||||
completed: bool
|
||||
messages: List[str]
|
||||
|
||||
def task_done(self, item_id, outputs,
|
||||
def task_done(self, item_id, history_result,
|
||||
status: Optional['PromptQueue.ExecutionStatus']):
|
||||
with self.mutex:
|
||||
prompt = self.currently_running.pop(item_id)
|
||||
@ -754,9 +857,10 @@ class PromptQueue:
|
||||
|
||||
self.history[prompt[1]] = {
|
||||
"prompt": prompt,
|
||||
"outputs": copy.deepcopy(outputs),
|
||||
"outputs": {},
|
||||
'status': status_dict,
|
||||
}
|
||||
self.history[prompt[1]].update(history_result)
|
||||
self.server.queue_updated()
|
||||
|
||||
def get_current_queue(self):
|
||||
|
||||
4
main.py
4
main.py
@ -91,7 +91,7 @@ def cuda_malloc_warning():
|
||||
print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||
|
||||
def prompt_worker(q, server):
|
||||
e = execution.PromptExecutor(server)
|
||||
e = execution.PromptExecutor(server, lru_size=args.cache_lru)
|
||||
last_gc_collect = 0
|
||||
need_gc = False
|
||||
gc_collect_interval = 10.0
|
||||
@ -111,7 +111,7 @@ def prompt_worker(q, server):
|
||||
e.execute(item[2], prompt_id, item[3], item[4])
|
||||
need_gc = True
|
||||
q.task_done(item_id,
|
||||
e.outputs_ui,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
|
||||
@ -396,6 +396,7 @@ class PromptServer():
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
|
||||
info['output'] = obj_class.RETURN_TYPES
|
||||
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
|
||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||
@ -632,6 +633,9 @@ class PromptServer():
|
||||
site = web.TCPSite(runner, address, port)
|
||||
await site.start()
|
||||
|
||||
self.address = address
|
||||
self.port = port
|
||||
|
||||
if verbose:
|
||||
print("Starting server\n")
|
||||
print("To see the GUI go to: http://{}:{}".format(address, port))
|
||||
|
||||
@ -443,6 +443,7 @@ describe("group node", () => {
|
||||
new CustomEvent("executed", {
|
||||
detail: {
|
||||
node: `${nodes.save.id}`,
|
||||
display_node: `${nodes.save.id}`,
|
||||
output: {
|
||||
images: [
|
||||
{
|
||||
@ -483,6 +484,7 @@ describe("group node", () => {
|
||||
new CustomEvent("executed", {
|
||||
detail: {
|
||||
node: `${group.id}:5`,
|
||||
display_node: `${group.id}:5`,
|
||||
output: {
|
||||
images: [
|
||||
{
|
||||
|
||||
@ -956,8 +956,8 @@ export class GroupNodeHandler {
|
||||
const executed = handleEvent.call(
|
||||
this,
|
||||
"executed",
|
||||
(d) => d?.node,
|
||||
(d, id, node) => ({ ...d, node: id, merge: !node.resetExecution })
|
||||
(d) => d?.display_node,
|
||||
(d, id, node) => ({ ...d, node: id, display_node: id, merge: !node.resetExecution })
|
||||
);
|
||||
|
||||
const onRemoved = node.onRemoved;
|
||||
|
||||
@ -3,7 +3,7 @@ import { app } from "../../scripts/app.js";
|
||||
import { applyTextReplacements } from "../../scripts/utils.js";
|
||||
|
||||
const CONVERTED_TYPE = "converted-widget";
|
||||
const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"];
|
||||
const VALID_TYPES = ["STRING", "combo", "number", "toggle", "BOOLEAN"];
|
||||
const CONFIG = Symbol();
|
||||
const GET_CONFIG = Symbol();
|
||||
const TARGET = Symbol(); // Used for reroutes to specify the real target widget
|
||||
|
||||
@ -126,7 +126,7 @@ class ComfyApi extends EventTarget {
|
||||
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
|
||||
break;
|
||||
case "executing":
|
||||
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
|
||||
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.display_node }));
|
||||
break;
|
||||
case "executed":
|
||||
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
|
||||
|
||||
@ -1255,7 +1255,7 @@ export class ComfyApp {
|
||||
});
|
||||
|
||||
api.addEventListener("executed", ({ detail }) => {
|
||||
const output = this.nodeOutputs[detail.node];
|
||||
const output = this.nodeOutputs[detail.display_node];
|
||||
if (detail.merge && output) {
|
||||
for (const k in detail.output ?? {}) {
|
||||
const v = output[k];
|
||||
@ -1266,9 +1266,9 @@ export class ComfyApp {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
this.nodeOutputs[detail.node] = detail.output;
|
||||
this.nodeOutputs[detail.display_node] = detail.output;
|
||||
}
|
||||
const node = this.graph.getNodeById(detail.node);
|
||||
const node = this.graph.getNodeById(detail.display_node);
|
||||
if (node) {
|
||||
if (node.onExecuted)
|
||||
node.onExecuted(detail.output);
|
||||
|
||||
@ -227,7 +227,14 @@ class ComfyList {
|
||||
onclick: async () => {
|
||||
await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
|
||||
if (item.outputs) {
|
||||
app.nodeOutputs = item.outputs;
|
||||
app.nodeOutputs = {};
|
||||
for (const [key, value] of Object.entries(item.outputs)) {
|
||||
if (item.meta && item.meta[key] && item.meta[key].display_node) {
|
||||
app.nodeOutputs[item.meta[key].display_node] = value;
|
||||
} else {
|
||||
app.nodeOutputs[key] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user