Execution Model Inversion

This PR inverts the execution model -- from recursively calling nodes to
using a topological sort of the nodes. This change allows for
modification of the node graph during execution. This allows for two
major advantages:

    1. The implementation of lazy evaluation in nodes. For example, if a
    "Mix Images" node has a mix factor of exactly 0.0, the second image
    input doesn't even need to be evaluated (and visa-versa if the mix
    factor is 1.0).

    2. Dynamic expansion of nodes. This allows for the creation of dynamic
    "node groups". Specifically, custom nodes can return subgraphs that
    replace the original node in the graph. This is an incredibly
    powerful concept. Using this functionality, it was easy to
    implement:
        a. Components (a.k.a. node groups)
        b. Flow control (i.e. while loops) via tail recursion
        c. All-in-one nodes that replicate the WebUI functionality
        d. and more
    All of those were able to be implemented entirely via custom nodes,
    so those features are *not* a part of this PR. (There are some
    front-end changes that should occur before that functionality is
    made widely available, particularly around variant sockets.)

The custom nodes associated with this PR can be found at:
https://github.com/BadCafeCode/execution-inversion-demo-comfyui

Note that some of them require that variant socket types ("*") be
enabled.
This commit is contained in:
Jacob Segal 2024-01-28 20:40:18 -08:00
parent 079dbf9198
commit 36b2214e30
13 changed files with 1006 additions and 257 deletions

316
comfy/caching.py Normal file
View File

@ -0,0 +1,316 @@
import itertools
from typing import Sequence, Mapping
import nodes
from comfy.graph_utils import is_link
class CacheKeySet:
def __init__(self, dynprompt, node_ids, is_changed_cache):
self.keys = {}
self.subcache_keys = {}
def add_keys(node_ids):
raise NotImplementedError()
def all_node_ids(self):
return set(self.keys.keys())
def get_used_keys(self):
return self.keys.values()
def get_used_subcache_keys(self):
return self.subcache_keys.values()
def get_data_key(self, node_id):
return self.keys.get(node_id, None)
def get_subcache_key(self, node_id):
return self.subcache_keys.get(node_id, None)
class Unhashable:
def __init__(self):
self.value = float("NaN")
def to_hashable(obj):
# So that we don't infinitely recurse since frozenset and tuples
# are Sequences.
if isinstance(obj, (int, float, str, bool, type(None))):
return obj
elif isinstance(obj, Mapping):
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
elif isinstance(obj, Sequence):
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
else:
# TODO - Support other objects like tensors?
return Unhashable()
class CacheKeySetID(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.add_keys(node_ids)
def add_keys(self, node_ids):
for node_id in node_ids:
if node_id in self.keys:
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = (node_id, node["class_type"])
self.subcache_keys[node_id] = (node_id, node["class_type"])
class CacheKeySetInputSignature(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache)
self.dynprompt = dynprompt
self.is_changed_cache = is_changed_cache
self.add_keys(node_ids)
def include_node_id_in_input(self):
return False
def add_keys(self, node_ids):
for node_id in node_ids:
if node_id in self.keys:
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
self.subcache_keys[node_id] = (node_id, node["class_type"])
def get_node_signature(self, dynprompt, node_id):
signature = []
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
for ancestor_id in ancestors:
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
return to_hashable(signature)
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
signature.append(node_id)
inputs = node["inputs"]
for key in sorted(inputs.keys()):
if is_link(inputs[key]):
(ancestor_id, ancestor_socket) = inputs[key]
ancestor_index = ancestor_order_mapping[ancestor_id]
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
else:
signature.append((key, inputs[key]))
return signature
# This function returns a list of all ancestors of the given node. The order of the list is
# deterministic based on which specific inputs the ancestor is connected by.
def get_ordered_ancestry(self, dynprompt, node_id):
ancestors = []
order_mapping = {}
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
return ancestors, order_mapping
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
inputs = dynprompt.get_node(node_id)["inputs"]
input_keys = sorted(inputs.keys())
for key in input_keys:
if is_link(inputs[key]):
ancestor_id = inputs[key][0]
if ancestor_id not in order_mapping:
ancestors.append(ancestor_id)
order_mapping[ancestor_id] = len(ancestors) - 1
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature):
def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache)
def include_node_id_in_input(self):
return True
class BasicCache:
def __init__(self, key_class):
self.key_class = key_class
self.dynprompt = None
self.cache_key_set = None
self.cache = {}
self.subcaches = {}
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
self.is_changed_cache = is_changed_cache
def all_node_ids(self):
assert self.cache_key_set is not None
node_ids = self.cache_key_set.all_node_ids()
for subcache in self.subcaches.values():
node_ids = node_ids.union(subcache.all_node_ids())
return node_ids
def clean_unused(self):
assert self.cache_key_set is not None
preserve_keys = set(self.cache_key_set.get_used_keys())
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
to_remove = []
for key in self.cache:
if key not in preserve_keys:
to_remove.append(key)
for key in to_remove:
del self.cache[key]
to_remove = []
for key in self.subcaches:
if key not in preserve_subcaches:
to_remove.append(key)
for key in to_remove:
del self.subcaches[key]
def _set_immediate(self, node_id, value):
assert self.cache_key_set is not None
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
def _get_immediate(self, node_id):
assert self.cache_key_set is not None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
else:
return None
def _ensure_subcache(self, node_id, children_ids):
assert self.cache_key_set is not None
subcache_key = self.cache_key_set.get_subcache_key(node_id)
subcache = self.subcaches.get(subcache_key, None)
if subcache is None:
subcache = BasicCache(self.key_class)
self.subcaches[subcache_key] = subcache
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
return subcache
def _get_subcache(self, node_id):
assert self.cache_key_set is not None
subcache_key = self.cache_key_set.get_subcache_key(node_id)
if subcache_key in self.subcaches:
return self.subcaches[subcache_key]
else:
return None
def recursive_debug_dump(self):
result = []
for key in self.cache:
result.append({"key": key, "value": self.cache[key]})
for key in self.subcaches:
result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
return result
class HierarchicalCache(BasicCache):
def __init__(self, key_class):
super().__init__(key_class)
def _get_cache_for(self, node_id):
parent_id = self.dynprompt.get_parent_node_id(node_id)
if parent_id is None:
return self
hierarchy = []
while parent_id is not None:
hierarchy.append(parent_id)
parent_id = self.dynprompt.get_parent_node_id(parent_id)
cache = self
for parent_id in reversed(hierarchy):
cache = cache._get_subcache(parent_id)
if cache is None:
return None
return cache
def get(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
return cache._get_immediate(node_id)
def set(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
cache._set_immediate(node_id, value)
def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
assert cache is not None
return cache._ensure_subcache(node_id, children_ids)
def all_active_values(self):
active_nodes = self.all_node_ids()
result = []
for node_id in active_nodes:
value = self.get(node_id)
if value is not None:
result.append(value)
return result
class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100):
super().__init__(key_class)
self.max_size = max_size
self.min_generation = 0
self.generation = 0
self.used_generation = {}
self.children = {}
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
super().set_prompt(dynprompt, node_ids, is_changed_cache)
self.generation += 1
for node_id in node_ids:
self._mark_used(node_id)
print("LRUCache: Now at generation %d" % self.generation)
def clean_unused(self):
print("LRUCache: Cleaning unused. Current size: %d/%d" % (len(self.cache), self.max_size))
while len(self.cache) > self.max_size and self.min_generation < self.generation:
print("LRUCache: Evicting generation %d" % self.min_generation)
self.min_generation += 1
to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
for key in to_remove:
del self.cache[key]
del self.used_generation[key]
if key in self.children:
del self.children[key]
def get(self, node_id):
self._mark_used(node_id)
return self._get_immediate(node_id)
def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None:
self.used_generation[cache_key] = self.generation
def set(self, node_id, value):
self._mark_used(node_id)
return self._set_immediate(node_id, value)
def ensure_subcache_for(self, node_id, children_ids):
self.cache_key_set.add_keys(children_ids)
self._mark_used(node_id)
cache_key = self.cache_key_set.get_data_key(node_id)
self.children[cache_key] = []
for child_id in children_ids:
self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self
def all_active_values(self):
explored = set()
to_explore = set(self.cache_key_set.get_used_keys())
while len(to_explore) > 0:
cache_key = to_explore.pop()
if cache_key not in explored:
self.used_generation[cache_key] = self.generation
explored.add(cache_key)
if cache_key in self.children:
to_explore.update(self.children[cache_key])
return [self.cache[key] for key in explored if key in self.cache]

View File

@ -87,6 +87,10 @@ class LatentPreviewMethod(enum.Enum):
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")

172
comfy/graph.py Normal file
View File

@ -0,0 +1,172 @@
import nodes
from comfy.graph_utils import is_link
class DynamicPrompt:
def __init__(self, original_prompt):
# The original prompt provided by the user
self.original_prompt = original_prompt
# Any extra pieces of the graph created during execution
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
self.ephemeral_display = {}
def get_node(self, node_id):
if node_id in self.ephemeral_prompt:
return self.ephemeral_prompt[node_id]
if node_id in self.original_prompt:
return self.original_prompt[node_id]
return None
def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
self.ephemeral_prompt[node_id] = node_info
self.ephemeral_parents[node_id] = parent_id
self.ephemeral_display[node_id] = display_id
def get_real_node_id(self, node_id):
while node_id in self.ephemeral_parents:
node_id = self.ephemeral_parents[node_id]
return node_id
def get_parent_node_id(self, node_id):
return self.ephemeral_parents.get(node_id, None)
def get_display_node_id(self, node_id):
while node_id in self.ephemeral_display:
node_id = self.ephemeral_display[node_id]
return node_id
def all_node_ids(self):
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES()
input_info = None
input_category = None
if "required" in valid_inputs and input_name in valid_inputs["required"]:
input_category = "required"
input_info = valid_inputs["required"][input_name]
elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
input_category = "optional"
input_info = valid_inputs["optional"][input_name]
elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
input_category = "hidden"
input_info = valid_inputs["hidden"][input_name]
if input_info is None:
return None, None, None
input_type = input_info[0]
if len(input_info) > 1:
extra_info = input_info[1]
else:
extra_info = {}
return input_type, input_category, extra_info
class TopologicalSort:
def __init__(self, dynprompt):
self.dynprompt = dynprompt
self.pendingNodes = {}
self.blockCount = {} # Number of nodes this node is directly blocked by
self.blocking = {} # Which nodes are blocked by this node
def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return get_input_info(class_def, input_name)
def make_input_strong_link(self, to_node_id, to_input):
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
if to_input not in inputs:
raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input))
value = inputs[to_input]
if not is_link(value):
raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input))
from_node_id, from_socket = value
self.add_strong_link(from_node_id, from_socket, to_node_id)
def add_strong_link(self, from_node_id, from_socket, to_node_id):
self.add_node(from_node_id)
if to_node_id not in self.blocking[from_node_id]:
self.blocking[from_node_id][to_node_id] = {}
self.blockCount[to_node_id] += 1
self.blocking[from_node_id][to_node_id][from_socket] = True
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
if unique_id in self.pendingNodes:
return
self.pendingNodes[unique_id] = True
self.blockCount[unique_id] = 0
self.blocking[unique_id] = {}
inputs = self.dynprompt.get_node(unique_id)["inputs"]
for input_name in inputs:
value = inputs[input_name]
if is_link(value):
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
is_lazy = "lazy" in input_info and input_info["lazy"]
if include_lazy or not is_lazy:
self.add_strong_link(from_node_id, from_socket, unique_id)
def get_ready_nodes(self):
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
def pop_node(self, unique_id):
del self.pendingNodes[unique_id]
for blocked_node_id in self.blocking[unique_id]:
self.blockCount[blocked_node_id] -= 1
del self.blocking[unique_id]
def is_empty(self):
return len(self.pendingNodes) == 0
# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
# it can still be returned to the graph after having further dependencies added.
class ExecutionList(TopologicalSort):
def __init__(self, dynprompt, output_cache):
super().__init__(dynprompt)
self.output_cache = output_cache
self.staged_node_id = None
def add_strong_link(self, from_node_id, from_socket, to_node_id):
if self.output_cache.get(from_node_id) is not None:
# Nothing to do
return
super().add_strong_link(from_node_id, from_socket, to_node_id)
def stage_node_execution(self):
assert self.staged_node_id is None
if self.is_empty():
return None
available = self.get_ready_nodes()
if len(available) == 0:
raise Exception("Dependency cycle detected")
next_node = available[0]
# If an output node is available, do that first.
# Technically this has no effect on the overall length of execution, but it feels better as a user
# for a PreviewImage to display a result as soon as it can
# Some other heuristics could probably be used here to improve the UX further.
for node_id in available:
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
next_node = node_id
break
self.staged_node_id = next_node
return self.staged_node_id
def unstage_node_execution(self):
assert self.staged_node_id is not None
self.staged_node_id = None
def complete_node_execution(self):
node_id = self.staged_node_id
self.pop_node(node_id)
self.staged_node_id = None
# Return this from a node and any users will be blocked with the given error message.
class ExecutionBlocker:
def __init__(self, message):
self.message = message

140
comfy/graph_utils.py Normal file
View File

@ -0,0 +1,140 @@
def is_link(obj):
if not isinstance(obj, list):
return False
if len(obj) != 2:
return False
if not isinstance(obj[0], str):
return False
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
return False
return True
# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
class GraphBuilder:
_default_prefix_root = ""
_default_prefix_call_index = 0
_default_prefix_graph_index = 0
def __init__(self, prefix = None):
if prefix is None:
self.prefix = GraphBuilder.alloc_prefix()
else:
self.prefix = prefix
self.nodes = {}
self.id_gen = 1
@classmethod
def set_default_prefix(cls, prefix_root, call_index, graph_index = 0):
cls._default_prefix_root = prefix_root
cls._default_prefix_call_index = call_index
if graph_index is not None:
cls._default_prefix_graph_index = graph_index
@classmethod
def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
if root is None:
root = GraphBuilder._default_prefix_root
if call_index is None:
call_index = GraphBuilder._default_prefix_call_index
if graph_index is None:
graph_index = GraphBuilder._default_prefix_graph_index
result = "%s.%d.%d." % (root, call_index, graph_index)
GraphBuilder._default_prefix_graph_index += 1
return result
def node(self, class_type, id=None, **kwargs):
if id is None:
id = str(self.id_gen)
self.id_gen += 1
id = self.prefix + id
if id in self.nodes:
return self.nodes[id]
node = Node(id, class_type, kwargs)
self.nodes[id] = node
return node
def lookup_node(self, id):
id = self.prefix + id
return self.nodes.get(id)
def finalize(self):
output = {}
for node_id, node in self.nodes.items():
output[node_id] = node.serialize()
return output
def replace_node_output(self, node_id, index, new_value):
node_id = self.prefix + node_id
to_remove = []
for node in self.nodes.values():
for key, value in node.inputs.items():
if is_link(value) and value[0] == node_id and value[1] == index:
if new_value is None:
to_remove.append((node, key))
else:
node.inputs[key] = new_value
for node, key in to_remove:
del node.inputs[key]
def remove_node(self, id):
id = self.prefix + id
del self.nodes[id]
class Node:
def __init__(self, id, class_type, inputs):
self.id = id
self.class_type = class_type
self.inputs = inputs
self.override_display_id = None
def out(self, index):
return [self.id, index]
def set_input(self, key, value):
if value is None:
if key in self.inputs:
del self.inputs[key]
else:
self.inputs[key] = value
def get_input(self, key):
return self.inputs.get(key)
def set_override_display_id(self, override_display_id):
self.override_display_id = override_display_id
def serialize(self):
serialized = {
"class_type": self.class_type,
"inputs": self.inputs
}
if self.override_display_id is not None:
serialized["override_display_id"] = self.override_display_id
return serialized
def add_graph_prefix(graph, outputs, prefix):
# Change the node IDs and any internal links
new_graph = {}
for node_id, node_info in graph.items():
# Make sure the added nodes have unique IDs
new_node_id = prefix + node_id
new_node = { "class_type": node_info["class_type"], "inputs": {} }
for input_name, input_value in node_info.get("inputs", {}).items():
if is_link(input_value):
new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
else:
new_node["inputs"][input_name] = input_value
new_graph[new_node_id] = new_node
# Change the node IDs in the outputs
new_outputs = []
for n in range(len(outputs)):
output = outputs[n]
if is_link(output):
new_outputs.append([prefix + output[0], output[1]])
else:
new_outputs.append(output)
return new_graph, tuple(new_outputs)

View File

@ -4,6 +4,7 @@ import logging
import threading
import heapq
import traceback
from enum import Enum
import inspect
from typing import List, Literal, NamedTuple, Optional
@ -11,29 +12,97 @@ import torch
import nodes
import comfy.model_management
import comfy.graph_utils
from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy.graph_utils import is_link, GraphBuilder
from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetInputSignatureWithID, CacheKeySetID
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
class ExecutionResult(Enum):
SUCCESS = 0
FAILURE = 1
SLEEPING = 2
class IsChangedCache:
def __init__(self, dynprompt, outputs_cache):
self.dynprompt = dynprompt
self.outputs_cache = outputs_cache
self.is_changed = {}
def get(self, node_id):
if node_id not in self.is_changed:
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, "IS_CHANGED"):
if "is_changed" in node:
self.is_changed[node_id] = node["is_changed"]
else:
input_data_all = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
try:
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
self.is_changed[node_id] = node["is_changed"]
except:
node["is_changed"] = float("NaN")
self.is_changed[node_id] = node["is_changed"]
else:
self.is_changed[node_id] = False
return self.is_changed[node_id]
class CacheSet:
def __init__(self, lru_size=None):
if lru_size is None or lru_size == 0:
self.init_classic_cache()
else:
self.init_lru_cache(lru_size)
self.all = [self.outputs, self.ui, self.objects]
# Useful for those with ample RAM/VRAM -- allows experimenting without
# blowing away the cache every time
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignatureWithID, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignatureWithID)
self.objects = HierarchicalCache(CacheKeySetID)
def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
}
return result
def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynprompt=None, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_type, input_category, input_info = get_input_info(class_def, x)
if is_link(input_data) and not input_info.get("rawLink", False):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
input_data_all[x] = (None,)
if outputs is None:
continue # This might be a lazily-evaluated input
cached_output = outputs.get(input_unique_id)
if cached_output is None:
continue
obj = outputs[input_unique_id][output_index]
obj = cached_output[output_index]
input_data_all[x] = obj
else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = [input_data]
elif input_category is not None:
input_data_all[x] = [input_data]
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
for x in h:
if h[x] == "PROMPT":
input_data_all[x] = [prompt]
if h[x] == "DYNPROMPT":
input_data_all[x] = [dynprompt]
if h[x] == "EXTRA_PNGINFO":
if "extra_pnginfo" in extra_data:
input_data_all[x] = [extra_data['extra_pnginfo']]
@ -41,7 +110,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = [unique_id]
return input_data_all
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# check if node wants the lists
input_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
@ -63,51 +132,97 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
if input_is_list:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all))
execution_block = None
for k, v in input_data_all.items():
for input in v:
if isinstance(v, ExecutionBlocker):
execution_block = execution_block_cb(v) if execution_block_cb is not None else v
break
if execution_block is None:
if pre_execute_cb is not None:
pre_execute_cb(0)
results.append(getattr(obj, func)(**input_data_all))
else:
results.append(execution_block)
elif max_len_input == 0:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)())
else:
else:
for i in range(max_len_input):
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
input_dict = slice_dict(input_data_all, i)
execution_block = None
for k, v in input_dict.items():
if isinstance(v, ExecutionBlocker):
execution_block = execution_block_cb(v) if execution_block_cb is not None else v
break
if execution_block is None:
if pre_execute_cb is not None:
pre_execute_cb(i)
results.append(getattr(obj, func)(**input_dict))
else:
results.append(execution_block)
return results
def get_output_data(obj, input_data_all):
def merge_result_data(results, obj):
# check which outputs need concatenating
output = []
output_is_list = [False] * len(results[0])
if hasattr(obj, "OUTPUT_IS_LIST"):
output_is_list = obj.OUTPUT_IS_LIST
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
output.append([x for o in results for x in o[i]])
else:
output.append([o[i] for o in results])
return output
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
results = []
uis = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
for r in return_values:
subgraph_results = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
has_subgraph = False
for i in range(len(return_values)):
r = return_values[i]
if isinstance(r, dict):
if 'ui' in r:
uis.append(r['ui'])
if 'result' in r:
results.append(r['result'])
if 'expand' in r:
# Perform an expansion, but do not append results
has_subgraph = True
new_graph = r['expand']
result = r.get("result", None)
if isinstance(result, ExecutionBlocker):
result = tuple([result] * len(obj.RETURN_TYPES))
subgraph_results.append((new_graph, result))
elif 'result' in r:
result = r.get("result", None)
if isinstance(result, ExecutionBlocker):
result = tuple([result] * len(obj.RETURN_TYPES))
results.append(result)
subgraph_results.append((None, result))
else:
if isinstance(r, ExecutionBlocker):
r = tuple([r] * len(obj.RETURN_TYPES))
results.append(r)
output = []
if len(results) > 0:
# check which outputs need concatenating
output_is_list = [False] * len(results[0])
if hasattr(obj, "OUTPUT_IS_LIST"):
output_is_list = obj.OUTPUT_IS_LIST
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
output.append([x for o in results for x in o[i]])
else:
output.append([o[i] for o in results])
if has_subgraph:
output = subgraph_results
elif len(results) > 0:
output = merge_result_data(results, obj)
else:
output = []
ui = dict()
if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui
return output, ui, has_subgraph
def format_value(x):
if x is None:
@ -117,53 +232,144 @@ def format_value(x):
else:
return str(x)
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
parent_node_id = dynprompt.get_parent_node_id(unique_id)
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if unique_id in outputs:
return (True, None, None)
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
if result[0] is not True:
# Another node failed further upstream
return result
if caches.outputs.get(unique_id) is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
return (ExecutionResult.SUCCESS, None, None)
input_data_all = None
try:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
if unique_id in pending_subgraph_results:
cached_results = pending_subgraph_results[unique_id]
resolved_outputs = []
for is_subgraph, result in cached_results:
if not is_subgraph:
resolved_outputs.append(result)
else:
resolved_output = []
for r in result:
if is_link(r):
source_node, source_output = r[0], r[1]
node_output = caches.outputs.get(source_node)[source_output]
for o in node_output:
resolved_output.append(o)
obj = object_storage.get((unique_id, class_type), None)
if obj is None:
obj = class_def()
object_storage[(unique_id, class_type)] = obj
output_data, output_ui = get_output_data(obj, input_data_all)
outputs[unique_id] = output_data
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
else:
resolved_output.append(r)
resolved_outputs.append(tuple(resolved_output))
output_data = merge_result_data(resolved_outputs, class_def)
output_ui = []
has_subgraph = False
else:
input_data_all = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt.original_prompt, dynprompt, extra_data)
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
obj = caches.objects.get(unique_id)
if obj is None:
obj = class_def()
caches.objects.set(unique_id, obj)
if hasattr(obj, "check_lazy_status"):
required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and x not in input_data_all]
if len(required_inputs) > 0:
for i in required_inputs:
execution_list.make_input_strong_link(unique_id, i)
return (ExecutionResult.SLEEPING, None, None)
def execution_block_cb(block):
if block.message is not None:
mes = {
"prompt_id": prompt_id,
"node_id": unique_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": "Execution Blocked: %s" % block.message,
"exception_type": "ExecutionBlocked",
"traceback": [],
"current_inputs": [],
"current_outputs": [],
}
server.send_sync("execution_error", mes, server.client_id)
return ExecutionBlocker(None)
else:
return block
def pre_execute_cb(call_index):
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
"meta": {
"node_id": unique_id,
"display_node": display_node_id,
"parent_node": parent_node_id,
"real_node_id": real_node_id,
},
"output": output_ui
})
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph:
cached_outputs = []
new_node_ids = []
new_output_ids = []
new_output_links = []
for i in range(len(output_data)):
new_graph, node_outputs = output_data[i]
if new_graph is None:
cached_outputs.append((False, node_outputs))
else:
# Check for conflicts
for node_id in new_graph.keys():
if dynprompt.get_node(node_id) is not None:
raise Exception("Attempt to add duplicate node %s" % node_id)
break
for node_id, node_info in new_graph.items():
new_node_ids.append(node_id)
display_id = node_info.get("override_display_id", unique_id)
dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
# Figure out if the newly created node is an output node
class_type = node_info["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
new_output_ids.append(node_id)
for i in range(len(node_outputs)):
if is_link(node_outputs[i]):
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
new_output_links.append((from_node_id, from_socket))
cached_outputs.append((True, node_outputs))
new_node_ids = set(new_node_ids)
for cache in caches.all:
cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
for node_id in new_output_ids:
execution_list.add_node(node_id)
for link in new_output_links:
execution_list.add_strong_link(link[0], link[1], unique_id)
pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.SLEEPING, None, None)
caches.outputs.set(unique_id, output_data)
except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
# skip formatting inputs/outputs
error_details = {
"node_id": unique_id,
"node_id": real_node_id,
}
return (False, error_details, iex)
return (ExecutionResult.FAILURE, error_details, iex)
except Exception as ex:
typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ)
@ -173,109 +379,32 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
for name, inputs in input_data_all.items():
input_data_formatted[name] = [format_value(x) for x in inputs]
output_data_formatted = {}
for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
logging.error("!!! Exception during processing !!!")
logging.error(traceback.format_exc())
error_details = {
"node_id": unique_id,
"node_id": real_node_id,
"exception_message": str(ex),
"exception_type": exception_type,
"traceback": traceback.format_tb(tb),
"current_inputs": input_data_formatted,
"current_outputs": output_data_formatted
"current_inputs": input_data_formatted
}
return (False, error_details, ex)
return (ExecutionResult.FAILURE, error_details, ex)
executed.add(unique_id)
return (True, None, None)
def recursive_will_execute(prompt, outputs, current_item):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
will_execute = []
if unique_id in outputs:
return []
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
will_execute += recursive_will_execute(prompt, outputs, input_unique_id)
return will_execute + [unique_id]
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
is_changed_old = ''
is_changed = ''
to_delete = False
if hasattr(class_def, 'IS_CHANGED'):
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
is_changed_old = old_prompt[unique_id]['is_changed']
if 'is_changed' not in prompt[unique_id]:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
if input_data_all is not None:
try:
#is_changed = class_def.IS_CHANGED(**input_data_all)
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
prompt[unique_id]['is_changed'] = is_changed
except:
to_delete = True
else:
is_changed = prompt[unique_id]['is_changed']
if unique_id not in outputs:
return True
if not to_delete:
if is_changed != is_changed_old:
to_delete = True
elif unique_id not in old_prompt:
to_delete = True
elif inputs == old_prompt[unique_id]['inputs']:
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id in outputs:
to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
else:
to_delete = True
if to_delete:
break
else:
to_delete = True
if to_delete:
d = outputs.pop(unique_id)
del d
return to_delete
return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor:
def __init__(self, server):
def __init__(self, server, lru_size=None):
self.lru_size = lru_size
self.server = server
self.reset()
def reset(self):
self.outputs = {}
self.object_storage = {}
self.outputs_ui = {}
self.caches = CacheSet(self.lru_size)
self.status_messages = []
self.success = True
self.old_prompt = {}
def add_message(self, event, data, broadcast: bool):
self.status_messages.append((event, data))
@ -302,7 +431,6 @@ class PromptExecutor:
"node_id": node_id,
"node_type": class_type,
"executed": list(executed),
"exception_message": error["exception_message"],
"exception_type": error["exception_type"],
"traceback": error["traceback"],
@ -311,18 +439,6 @@ class PromptExecutor:
}
self.add_message("execution_error", mes, broadcast=False)
# Next, remove the subsequent outputs since they will not be executed
to_delete = []
for o in self.outputs:
if (o not in current_outputs) and (o not in executed):
to_delete += [o]
if o in self.old_prompt:
d = self.old_prompt.pop(o)
del d
for o in to_delete:
d = self.outputs.pop(o)
del d
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
nodes.interrupt_processing(False)
@ -335,61 +451,45 @@ class PromptExecutor:
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode():
#delete cached outputs if nodes don't exist for them
to_delete = []
for o in self.outputs:
if o not in prompt:
to_delete += [o]
for o in to_delete:
d = self.outputs.pop(o)
del d
to_delete = []
for o in self.object_storage:
if o[0] not in prompt:
to_delete += [o]
else:
p = prompt[o[0]]
if o[1] != p['class_type']:
to_delete += [o]
for o in to_delete:
d = self.object_storage.pop(o)
del d
dynamic_prompt = DynamicPrompt(prompt)
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused()
for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
current_outputs = set(self.outputs.keys())
for x in list(self.outputs_ui.keys()):
if x not in current_outputs:
d = self.outputs_ui.pop(x)
del d
current_outputs = self.caches.outputs.all_node_ids()
comfy.model_management.cleanup_models()
self.add_message("execution_cached",
{ "nodes": list(current_outputs) , "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {}
executed = set()
output_node_id = None
to_execute = []
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
for node_id in list(execute_outputs):
to_execute += [(0, node_id)]
execution_list.add_node(node_id)
while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
output_node_id = to_execute.pop(0)[-1]
# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if self.success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
while not execution_list.is_empty():
node_id = execution_list.stage_node_execution()
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
elif result == ExecutionResult.SLEEPING:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
ui_outputs = {}
meta_outputs = {}
for ui_info in self.caches.ui.all_active_values():
node_id = ui_info["meta"]["node_id"]
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
}
self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
@ -406,7 +506,7 @@ def validate_inputs(prompt, item, validated):
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
class_inputs = obj_class.INPUT_TYPES()
required_inputs = class_inputs['required']
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
errors = []
valid = True
@ -415,22 +515,23 @@ def validate_inputs(prompt, item, validated):
if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
for x in required_inputs:
for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x)
if x not in inputs:
error = {
"type": "required_input_missing",
"message": "Required input is missing",
"details": f"{x}",
"extra_info": {
"input_name": x
if input_category == "required":
error = {
"type": "required_input_missing",
"message": "Required input is missing",
"details": f"{x}",
"extra_info": {
"input_name": x
}
}
}
errors.append(error)
errors.append(error)
continue
val = inputs[x]
info = required_inputs[x]
type_input = info[0]
info = (type_input, extra_info)
if isinstance(val, list):
if len(val) != 2:
error = {
@ -501,6 +602,9 @@ def validate_inputs(prompt, item, validated):
if type_input == "STRING":
val = str(val)
inputs[x] = val
if type_input == "BOOLEAN":
val = bool(val)
inputs[x] = val
except Exception as ex:
error = {
"type": "invalid_input_type",
@ -516,33 +620,32 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if len(info) > 1:
if "min" in info[1] and val < info[1]["min"]:
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
if "min" in extra_info and val < extra_info["min"]:
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
errors.append(error)
continue
if "max" in info[1] and val > info[1]["max"]:
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if "max" in extra_info and val > extra_info["max"]:
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
errors.append(error)
continue
}
errors.append(error)
continue
if x not in validate_function_inputs:
if isinstance(type_input, list):
@ -582,7 +685,7 @@ def validate_inputs(prompt, item, validated):
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
for x in input_filtered:
for i, r in enumerate(ret):
if r is not True:
if r is not True and not isinstance(r, ExecutionBlocker):
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
@ -741,7 +844,7 @@ class PromptQueue:
completed: bool
messages: List[str]
def task_done(self, item_id, outputs,
def task_done(self, item_id, history_result,
status: Optional['PromptQueue.ExecutionStatus']):
with self.mutex:
prompt = self.currently_running.pop(item_id)
@ -754,9 +857,10 @@ class PromptQueue:
self.history[prompt[1]] = {
"prompt": prompt,
"outputs": copy.deepcopy(outputs),
"outputs": {},
'status': status_dict,
}
self.history[prompt[1]].update(history_result)
self.server.queue_updated()
def get_current_queue(self):

View File

@ -91,7 +91,7 @@ def cuda_malloc_warning():
print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server):
e = execution.PromptExecutor(server)
e = execution.PromptExecutor(server, lru_size=args.cache_lru)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
@ -111,7 +111,7 @@ def prompt_worker(q, server):
e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True
q.task_done(item_id,
e.outputs_ui,
e.history_result,
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if e.success else 'error',
completed=e.success,

View File

@ -396,6 +396,7 @@ class PromptServer():
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
@ -632,6 +633,9 @@ class PromptServer():
site = web.TCPSite(runner, address, port)
await site.start()
self.address = address
self.port = port
if verbose:
print("Starting server\n")
print("To see the GUI go to: http://{}:{}".format(address, port))

View File

@ -443,6 +443,7 @@ describe("group node", () => {
new CustomEvent("executed", {
detail: {
node: `${nodes.save.id}`,
display_node: `${nodes.save.id}`,
output: {
images: [
{
@ -483,6 +484,7 @@ describe("group node", () => {
new CustomEvent("executed", {
detail: {
node: `${group.id}:5`,
display_node: `${group.id}:5`,
output: {
images: [
{

View File

@ -956,8 +956,8 @@ export class GroupNodeHandler {
const executed = handleEvent.call(
this,
"executed",
(d) => d?.node,
(d, id, node) => ({ ...d, node: id, merge: !node.resetExecution })
(d) => d?.display_node,
(d, id, node) => ({ ...d, node: id, display_node: id, merge: !node.resetExecution })
);
const onRemoved = node.onRemoved;

View File

@ -3,7 +3,7 @@ import { app } from "../../scripts/app.js";
import { applyTextReplacements } from "../../scripts/utils.js";
const CONVERTED_TYPE = "converted-widget";
const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"];
const VALID_TYPES = ["STRING", "combo", "number", "toggle", "BOOLEAN"];
const CONFIG = Symbol();
const GET_CONFIG = Symbol();
const TARGET = Symbol(); // Used for reroutes to specify the real target widget

View File

@ -126,7 +126,7 @@ class ComfyApi extends EventTarget {
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
break;
case "executing":
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.display_node }));
break;
case "executed":
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));

View File

@ -1255,7 +1255,7 @@ export class ComfyApp {
});
api.addEventListener("executed", ({ detail }) => {
const output = this.nodeOutputs[detail.node];
const output = this.nodeOutputs[detail.display_node];
if (detail.merge && output) {
for (const k in detail.output ?? {}) {
const v = output[k];
@ -1266,9 +1266,9 @@ export class ComfyApp {
}
}
} else {
this.nodeOutputs[detail.node] = detail.output;
this.nodeOutputs[detail.display_node] = detail.output;
}
const node = this.graph.getNodeById(detail.node);
const node = this.graph.getNodeById(detail.display_node);
if (node) {
if (node.onExecuted)
node.onExecuted(detail.output);

View File

@ -227,7 +227,14 @@ class ComfyList {
onclick: async () => {
await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow);
if (item.outputs) {
app.nodeOutputs = item.outputs;
app.nodeOutputs = {};
for (const [key, value] of Object.entries(item.outputs)) {
if (item.meta && item.meta[key] && item.meta[key].display_node) {
app.nodeOutputs[item.meta[key].display_node] = value;
} else {
app.nodeOutputs[key] = value;
}
}
}
},
}),