From 36b2214e30db955a10b27ae0d58453bab99dac96 Mon Sep 17 00:00:00 2001
From: Jacob Segal
Date: Sun, 28 Jan 2024 20:40:18 -0800
Subject: [PATCH] Execution Model Inversion

This PR inverts the execution model -- from recursively calling nodes
to using a topological sort of the nodes. This change allows
modification of the node graph during execution, which provides two
major advantages:

1. The implementation of lazy evaluation in nodes. For example, if a
   "Mix Images" node has a mix factor of exactly 0.0, the second image
   input doesn't even need to be evaluated (and vice versa if the mix
   factor is 1.0).

2. Dynamic expansion of nodes. This allows for the creation of dynamic
   "node groups". Specifically, custom nodes can return subgraphs that
   replace the original node in the graph. This is an incredibly
   powerful concept. Using this functionality, it was easy to
   implement:
   a. Components (a.k.a. node groups)
   b. Flow control (e.g. while loops) via tail recursion
   c. All-in-one nodes that replicate the WebUI functionality
   d. and more

All of these were implemented entirely via custom nodes, so those
features are *not* a part of this PR. (There are some front-end changes
that should occur before that functionality is made widely available,
particularly around variant sockets.)

The custom nodes associated with this PR can be found at:
https://github.com/BadCafeCode/execution-inversion-demo-comfyui

Note that some of them require that variant socket types ("*") be
enabled.
---
 comfy/caching.py                    | 316 +++++++++++++++
 comfy/cli_args.py                   |   4 +
 comfy/graph.py                      | 172 ++++++
 comfy/graph_utils.py                | 140 +++++++
 execution.py                        | 598 ++++++++++++++++------------
 main.py                             |   4 +-
 server.py                           |   4 +
 tests-ui/tests/groupNode.test.js    |   2 +
 web/extensions/core/groupNode.js    |   4 +-
 web/extensions/core/widgetInputs.js |   2 +-
 web/scripts/api.js                  |   2 +-
 web/scripts/app.js                  |   6 +-
 web/scripts/ui.js                   |   9 +-
 13 files changed, 1006 insertions(+), 257 deletions(-)
 create mode 100644 comfy/caching.py
 create mode 100644 comfy/graph.py
 create mode 100644 comfy/graph_utils.py

diff --git a/comfy/caching.py b/comfy/caching.py
new file mode 100644
index 000000000..ef047dcc5
--- /dev/null
+++ b/comfy/caching.py
@@ -0,0 +1,316 @@
+import itertools
+from typing import Sequence, Mapping
+
+import nodes
+
+from comfy.graph_utils import is_link
+
+class CacheKeySet:
+    def __init__(self, dynprompt, node_ids, is_changed_cache):
+        self.keys = {}
+        self.subcache_keys = {}
+
+    def add_keys(self, node_ids):
+        raise NotImplementedError()
+
+    def all_node_ids(self):
+        return set(self.keys.keys())
+
+    def get_used_keys(self):
+        return self.keys.values()
+
+    def get_used_subcache_keys(self):
+        return self.subcache_keys.values()
+
+    def get_data_key(self, node_id):
+        return self.keys.get(node_id, None)
+
+    def get_subcache_key(self, node_id):
+        return self.subcache_keys.get(node_id, None)
+
+class Unhashable:
+    def __init__(self):
+        self.value = float("NaN")
+
+def to_hashable(obj):
+    # So that we don't infinitely recurse since frozenset and tuples
+    # are Sequences.
+    if isinstance(obj, (int, float, str, bool, type(None))):
+        return obj
+    elif isinstance(obj, Mapping):
+        return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
+    elif isinstance(obj, Sequence):
+        return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
+    else:
+        # TODO - Support other objects like tensors?
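+        # Each Unhashable instance compares equal only to itself, so a
+        # signature containing one never matches a later signature and
+        # the node is always re-executed.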
+ return Unhashable() + +class CacheKeySetID(CacheKeySet): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + self.dynprompt = dynprompt + self.add_keys(node_ids) + + def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = (node_id, node["class_type"]) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + +class CacheKeySetInputSignature(CacheKeySet): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + self.dynprompt = dynprompt + self.is_changed_cache = is_changed_cache + self.add_keys(node_ids) + + def include_node_id_in_input(self): + return False + + def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + + def get_node_signature(self, dynprompt, node_id): + signature = [] + ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) + signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) + for ancestor_id in ancestors: + signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) + return to_hashable(signature) + + def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + node = dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + signature = [class_type, self.is_changed_cache.get(node_id)] + if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT): + signature.append(node_id) + inputs = node["inputs"] + for key in sorted(inputs.keys()): + if is_link(inputs[key]): + (ancestor_id, ancestor_socket) = inputs[key] + ancestor_index = ancestor_order_mapping[ancestor_id] + signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) + else: + signature.append((key, inputs[key])) + return signature + + # This function returns a list of all ancestors of the given node. The order of the list is + # deterministic based on which specific inputs the ancestor is connected by. 
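+    # Because signatures reference ancestors by this traversal order rather
+    # than by node ID, structurally identical graphs produce identical cache
+    # keys even when their node IDs differ (unless NOT_IDEMPOTENT is set or
+    # the key set includes node IDs).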
+ def get_ordered_ancestry(self, dynprompt, node_id): + ancestors = [] + order_mapping = {} + self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping) + return ancestors, order_mapping + + def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping): + inputs = dynprompt.get_node(node_id)["inputs"] + input_keys = sorted(inputs.keys()) + for key in input_keys: + if is_link(inputs[key]): + ancestor_id = inputs[key][0] + if ancestor_id not in order_mapping: + ancestors.append(ancestor_id) + order_mapping[ancestor_id] = len(ancestors) - 1 + self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) + +class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + + def include_node_id_in_input(self): + return True + +class BasicCache: + def __init__(self, key_class): + self.key_class = key_class + self.dynprompt = None + self.cache_key_set = None + self.cache = {} + self.subcaches = {} + + def set_prompt(self, dynprompt, node_ids, is_changed_cache): + self.dynprompt = dynprompt + self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) + self.is_changed_cache = is_changed_cache + + def all_node_ids(self): + assert self.cache_key_set is not None + node_ids = self.cache_key_set.all_node_ids() + for subcache in self.subcaches.values(): + node_ids = node_ids.union(subcache.all_node_ids()) + return node_ids + + def clean_unused(self): + assert self.cache_key_set is not None + preserve_keys = set(self.cache_key_set.get_used_keys()) + preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys()) + to_remove = [] + for key in self.cache: + if key not in preserve_keys: + to_remove.append(key) + for key in to_remove: + del self.cache[key] + + to_remove = [] + for key in self.subcaches: + if key not in preserve_subcaches: + to_remove.append(key) + for key in to_remove: + del self.subcaches[key] + + def _set_immediate(self, node_id, value): + assert self.cache_key_set is not None + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + def _get_immediate(self, node_id): + assert self.cache_key_set is not None + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key in self.cache: + return self.cache[cache_key] + else: + return None + + def _ensure_subcache(self, node_id, children_ids): + assert self.cache_key_set is not None + subcache_key = self.cache_key_set.get_subcache_key(node_id) + subcache = self.subcaches.get(subcache_key, None) + if subcache is None: + subcache = BasicCache(self.key_class) + self.subcaches[subcache_key] = subcache + subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) + return subcache + + def _get_subcache(self, node_id): + assert self.cache_key_set is not None + subcache_key = self.cache_key_set.get_subcache_key(node_id) + if subcache_key in self.subcaches: + return self.subcaches[subcache_key] + else: + return None + + def recursive_debug_dump(self): + result = [] + for key in self.cache: + result.append({"key": key, "value": self.cache[key]}) + for key in self.subcaches: + result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()}) + return result + +class HierarchicalCache(BasicCache): + def __init__(self, key_class): + super().__init__(key_class) + + def _get_cache_for(self, node_id): + parent_id = self.dynprompt.get_parent_node_id(node_id) + if parent_id 
is None: + return self + + hierarchy = [] + while parent_id is not None: + hierarchy.append(parent_id) + parent_id = self.dynprompt.get_parent_node_id(parent_id) + + cache = self + for parent_id in reversed(hierarchy): + cache = cache._get_subcache(parent_id) + if cache is None: + return None + return cache + + def get(self, node_id): + cache = self._get_cache_for(node_id) + if cache is None: + return None + return cache._get_immediate(node_id) + + def set(self, node_id, value): + cache = self._get_cache_for(node_id) + assert cache is not None + cache._set_immediate(node_id, value) + + def ensure_subcache_for(self, node_id, children_ids): + cache = self._get_cache_for(node_id) + assert cache is not None + return cache._ensure_subcache(node_id, children_ids) + + def all_active_values(self): + active_nodes = self.all_node_ids() + result = [] + for node_id in active_nodes: + value = self.get(node_id) + if value is not None: + result.append(value) + return result + +class LRUCache(BasicCache): + def __init__(self, key_class, max_size=100): + super().__init__(key_class) + self.max_size = max_size + self.min_generation = 0 + self.generation = 0 + self.used_generation = {} + self.children = {} + + def set_prompt(self, dynprompt, node_ids, is_changed_cache): + super().set_prompt(dynprompt, node_ids, is_changed_cache) + self.generation += 1 + for node_id in node_ids: + self._mark_used(node_id) + print("LRUCache: Now at generation %d" % self.generation) + + def clean_unused(self): + print("LRUCache: Cleaning unused. Current size: %d/%d" % (len(self.cache), self.max_size)) + while len(self.cache) > self.max_size and self.min_generation < self.generation: + print("LRUCache: Evicting generation %d" % self.min_generation) + self.min_generation += 1 + to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation] + for key in to_remove: + del self.cache[key] + del self.used_generation[key] + if key in self.children: + del self.children[key] + + def get(self, node_id): + self._mark_used(node_id) + return self._get_immediate(node_id) + + def _mark_used(self, node_id): + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key is not None: + self.used_generation[cache_key] = self.generation + + def set(self, node_id, value): + self._mark_used(node_id) + return self._set_immediate(node_id, value) + + def ensure_subcache_for(self, node_id, children_ids): + self.cache_key_set.add_keys(children_ids) + self._mark_used(node_id) + cache_key = self.cache_key_set.get_data_key(node_id) + self.children[cache_key] = [] + for child_id in children_ids: + self._mark_used(child_id) + self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) + return self + + def all_active_values(self): + explored = set() + to_explore = set(self.cache_key_set.get_used_keys()) + while len(to_explore) > 0: + cache_key = to_explore.pop() + if cache_key not in explored: + self.used_generation[cache_key] = self.generation + explored.add(cache_key) + if cache_key in self.children: + to_explore.update(self.children[cache_key]) + return [self.cache[key] for key in explored if key in self.cache] + diff --git a/comfy/cli_args.py b/comfy/cli_args.py index b4bbfbfab..2cbefefeb 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -87,6 +87,10 @@ class LatentPreviewMethod(enum.Enum): parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) +cache_group = 
parser.add_mutually_exclusive_group() +cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") +cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") + attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.") diff --git a/comfy/graph.py b/comfy/graph.py new file mode 100644 index 000000000..2612317f9 --- /dev/null +++ b/comfy/graph.py @@ -0,0 +1,172 @@ +import nodes + +from comfy.graph_utils import is_link + +class DynamicPrompt: + def __init__(self, original_prompt): + # The original prompt provided by the user + self.original_prompt = original_prompt + # Any extra pieces of the graph created during execution + self.ephemeral_prompt = {} + self.ephemeral_parents = {} + self.ephemeral_display = {} + + def get_node(self, node_id): + if node_id in self.ephemeral_prompt: + return self.ephemeral_prompt[node_id] + if node_id in self.original_prompt: + return self.original_prompt[node_id] + return None + + def add_ephemeral_node(self, node_id, node_info, parent_id, display_id): + self.ephemeral_prompt[node_id] = node_info + self.ephemeral_parents[node_id] = parent_id + self.ephemeral_display[node_id] = display_id + + def get_real_node_id(self, node_id): + while node_id in self.ephemeral_parents: + node_id = self.ephemeral_parents[node_id] + return node_id + + def get_parent_node_id(self, node_id): + return self.ephemeral_parents.get(node_id, None) + + def get_display_node_id(self, node_id): + while node_id in self.ephemeral_display: + node_id = self.ephemeral_display[node_id] + return node_id + + def all_node_ids(self): + return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys())) + +def get_input_info(class_def, input_name): + valid_inputs = class_def.INPUT_TYPES() + input_info = None + input_category = None + if "required" in valid_inputs and input_name in valid_inputs["required"]: + input_category = "required" + input_info = valid_inputs["required"][input_name] + elif "optional" in valid_inputs and input_name in valid_inputs["optional"]: + input_category = "optional" + input_info = valid_inputs["optional"][input_name] + elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]: + input_category = "hidden" + input_info = valid_inputs["hidden"][input_name] + if input_info is None: + return None, None, None + input_type = input_info[0] + if len(input_info) > 1: + extra_info = input_info[1] + else: + extra_info = {} + return input_type, input_category, extra_info + +class TopologicalSort: + def __init__(self, dynprompt): + self.dynprompt = dynprompt + self.pendingNodes = {} + self.blockCount = {} # Number of nodes this node is directly blocked by + self.blocking = {} # Which nodes are blocked by this node + + def get_input_info(self, unique_id, input_name): + class_type = self.dynprompt.get_node(unique_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + return get_input_info(class_def, input_name) + + def make_input_strong_link(self, to_node_id, to_input): + inputs = self.dynprompt.get_node(to_node_id)["inputs"] + if to_input not in inputs: + raise Exception("Node %s says it needs 
input %s, but there is no input to that node at all" % (to_node_id, to_input)) + value = inputs[to_input] + if not is_link(value): + raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input)) + from_node_id, from_socket = value + self.add_strong_link(from_node_id, from_socket, to_node_id) + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + self.add_node(from_node_id) + if to_node_id not in self.blocking[from_node_id]: + self.blocking[from_node_id][to_node_id] = {} + self.blockCount[to_node_id] += 1 + self.blocking[from_node_id][to_node_id][from_socket] = True + + def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None): + if unique_id in self.pendingNodes: + return + self.pendingNodes[unique_id] = True + self.blockCount[unique_id] = 0 + self.blocking[unique_id] = {} + + inputs = self.dynprompt.get_node(unique_id)["inputs"] + for input_name in inputs: + value = inputs[input_name] + if is_link(value): + from_node_id, from_socket = value + if subgraph_nodes is not None and from_node_id not in subgraph_nodes: + continue + input_type, input_category, input_info = self.get_input_info(unique_id, input_name) + is_lazy = "lazy" in input_info and input_info["lazy"] + if include_lazy or not is_lazy: + self.add_strong_link(from_node_id, from_socket, unique_id) + + def get_ready_nodes(self): + return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0] + + def pop_node(self, unique_id): + del self.pendingNodes[unique_id] + for blocked_node_id in self.blocking[unique_id]: + self.blockCount[blocked_node_id] -= 1 + del self.blocking[unique_id] + + def is_empty(self): + return len(self.pendingNodes) == 0 + +# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution, +# it can still be returned to the graph after having further dependencies added. +class ExecutionList(TopologicalSort): + def __init__(self, dynprompt, output_cache): + super().__init__(dynprompt) + self.output_cache = output_cache + self.staged_node_id = None + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + if self.output_cache.get(from_node_id) is not None: + # Nothing to do + return + super().add_strong_link(from_node_id, from_socket, to_node_id) + + def stage_node_execution(self): + assert self.staged_node_id is None + if self.is_empty(): + return None + available = self.get_ready_nodes() + if len(available) == 0: + raise Exception("Dependency cycle detected") + next_node = available[0] + # If an output node is available, do that first. + # Technically this has no effect on the overall length of execution, but it feels better as a user + # for a PreviewImage to display a result as soon as it can + # Some other heuristics could probably be used here to improve the UX further. + for node_id in available: + class_type = self.dynprompt.get_node(node_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + next_node = node_id + break + self.staged_node_id = next_node + return self.staged_node_id + + def unstage_node_execution(self): + assert self.staged_node_id is not None + self.staged_node_id = None + + def complete_node_execution(self): + node_id = self.staged_node_id + self.pop_node(node_id) + self.staged_node_id = None + +# Return this from a node and any users will be blocked with the given error message. 
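+# A blocker with message=None propagates silently -- downstream nodes are
+# still blocked, but no "execution_error" message is reported to the client.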
+class ExecutionBlocker: + def __init__(self, message): + self.message = message + diff --git a/comfy/graph_utils.py b/comfy/graph_utils.py new file mode 100644 index 000000000..a0042e078 --- /dev/null +++ b/comfy/graph_utils.py @@ -0,0 +1,140 @@ +def is_link(obj): + if not isinstance(obj, list): + return False + if len(obj) != 2: + return False + if not isinstance(obj[0], str): + return False + if not isinstance(obj[1], int) and not isinstance(obj[1], float): + return False + return True + +# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end +class GraphBuilder: + _default_prefix_root = "" + _default_prefix_call_index = 0 + _default_prefix_graph_index = 0 + + def __init__(self, prefix = None): + if prefix is None: + self.prefix = GraphBuilder.alloc_prefix() + else: + self.prefix = prefix + self.nodes = {} + self.id_gen = 1 + + @classmethod + def set_default_prefix(cls, prefix_root, call_index, graph_index = 0): + cls._default_prefix_root = prefix_root + cls._default_prefix_call_index = call_index + if graph_index is not None: + cls._default_prefix_graph_index = graph_index + + @classmethod + def alloc_prefix(cls, root=None, call_index=None, graph_index=None): + if root is None: + root = GraphBuilder._default_prefix_root + if call_index is None: + call_index = GraphBuilder._default_prefix_call_index + if graph_index is None: + graph_index = GraphBuilder._default_prefix_graph_index + result = "%s.%d.%d." % (root, call_index, graph_index) + GraphBuilder._default_prefix_graph_index += 1 + return result + + def node(self, class_type, id=None, **kwargs): + if id is None: + id = str(self.id_gen) + self.id_gen += 1 + id = self.prefix + id + if id in self.nodes: + return self.nodes[id] + + node = Node(id, class_type, kwargs) + self.nodes[id] = node + return node + + def lookup_node(self, id): + id = self.prefix + id + return self.nodes.get(id) + + def finalize(self): + output = {} + for node_id, node in self.nodes.items(): + output[node_id] = node.serialize() + return output + + def replace_node_output(self, node_id, index, new_value): + node_id = self.prefix + node_id + to_remove = [] + for node in self.nodes.values(): + for key, value in node.inputs.items(): + if is_link(value) and value[0] == node_id and value[1] == index: + if new_value is None: + to_remove.append((node, key)) + else: + node.inputs[key] = new_value + for node, key in to_remove: + del node.inputs[key] + + def remove_node(self, id): + id = self.prefix + id + del self.nodes[id] + +class Node: + def __init__(self, id, class_type, inputs): + self.id = id + self.class_type = class_type + self.inputs = inputs + self.override_display_id = None + + def out(self, index): + return [self.id, index] + + def set_input(self, key, value): + if value is None: + if key in self.inputs: + del self.inputs[key] + else: + self.inputs[key] = value + + def get_input(self, key): + return self.inputs.get(key) + + def set_override_display_id(self, override_display_id): + self.override_display_id = override_display_id + + def serialize(self): + serialized = { + "class_type": self.class_type, + "inputs": self.inputs + } + if self.override_display_id is not None: + serialized["override_display_id"] = self.override_display_id + return serialized + +def add_graph_prefix(graph, outputs, prefix): + # Change the node IDs and any internal links + new_graph = {} + for node_id, node_info in graph.items(): + # Make sure the added nodes have unique IDs + new_node_id = prefix + node_id + new_node = { 
"class_type": node_info["class_type"], "inputs": {} } + for input_name, input_value in node_info.get("inputs", {}).items(): + if is_link(input_value): + new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]] + else: + new_node["inputs"][input_name] = input_value + new_graph[new_node_id] = new_node + + # Change the node IDs in the outputs + new_outputs = [] + for n in range(len(outputs)): + output = outputs[n] + if is_link(output): + new_outputs.append([prefix + output[0], output[1]]) + else: + new_outputs.append(output) + + return new_graph, tuple(new_outputs) + diff --git a/execution.py b/execution.py index 00908eadd..3b18d2a7a 100644 --- a/execution.py +++ b/execution.py @@ -4,6 +4,7 @@ import logging import threading import heapq import traceback +from enum import Enum import inspect from typing import List, Literal, NamedTuple, Optional @@ -11,29 +12,97 @@ import torch import nodes import comfy.model_management +import comfy.graph_utils +from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker +from comfy.graph_utils import is_link, GraphBuilder +from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetInputSignatureWithID, CacheKeySetID -def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): +class ExecutionResult(Enum): + SUCCESS = 0 + FAILURE = 1 + SLEEPING = 2 + +class IsChangedCache: + def __init__(self, dynprompt, outputs_cache): + self.dynprompt = dynprompt + self.outputs_cache = outputs_cache + self.is_changed = {} + + def get(self, node_id): + if node_id not in self.is_changed: + node = self.dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, "IS_CHANGED"): + if "is_changed" in node: + self.is_changed[node_id] = node["is_changed"] + else: + input_data_all = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache) + try: + is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") + node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] + self.is_changed[node_id] = node["is_changed"] + except: + node["is_changed"] = float("NaN") + self.is_changed[node_id] = node["is_changed"] + else: + self.is_changed[node_id] = False + return self.is_changed[node_id] + +class CacheSet: + def __init__(self, lru_size=None): + if lru_size is None or lru_size == 0: + self.init_classic_cache() + else: + self.init_lru_cache(lru_size) + self.all = [self.outputs, self.ui, self.objects] + + # Useful for those with ample RAM/VRAM -- allows experimenting without + # blowing away the cache every time + def init_lru_cache(self, cache_size): + self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.ui = LRUCache(CacheKeySetInputSignatureWithID, max_size=cache_size) + self.objects = HierarchicalCache(CacheKeySetID) + + # Performs like the old cache -- dump data ASAP + def init_classic_cache(self): + self.outputs = HierarchicalCache(CacheKeySetInputSignature) + self.ui = HierarchicalCache(CacheKeySetInputSignatureWithID) + self.objects = HierarchicalCache(CacheKeySetID) + + def recursive_debug_dump(self): + result = { + "outputs": self.outputs.recursive_debug_dump(), + "ui": self.ui.recursive_debug_dump(), + } + return result + +def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynprompt=None, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} for x in inputs: input_data = 
inputs[x]
-        if isinstance(input_data, list):
+        input_type, input_category, input_info = get_input_info(class_def, x)
+        if is_link(input_data) and not input_info.get("rawLink", False):
             input_unique_id = input_data[0]
             output_index = input_data[1]
-            if input_unique_id not in outputs:
-                input_data_all[x] = (None,)
+            if outputs is None:
+                continue # This might be a lazily-evaluated input
+            cached_output = outputs.get(input_unique_id)
+            if cached_output is None:
                 continue
-            obj = outputs[input_unique_id][output_index]
+            obj = cached_output[output_index]
             input_data_all[x] = obj
-        else:
-            if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
-                input_data_all[x] = [input_data]
+        elif input_category is not None:
+            input_data_all[x] = [input_data]
     if "hidden" in valid_inputs:
         h = valid_inputs["hidden"]
         for x in h:
             if h[x] == "PROMPT":
                 input_data_all[x] = [prompt]
+            if h[x] == "DYNPROMPT":
+                input_data_all[x] = [dynprompt]
             if h[x] == "EXTRA_PNGINFO":
                 if "extra_pnginfo" in extra_data:
                     input_data_all[x] = [extra_data['extra_pnginfo']]
@@ -41,7 +110,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
             input_data_all[x] = [unique_id]
     return input_data_all
 
-def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
+def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
     # check if node wants the lists
     input_is_list = False
     if hasattr(obj, "INPUT_IS_LIST"):
@@ -63,51 +132,97 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
     if input_is_list:
         if allow_interrupt:
             nodes.before_node_execution()
-        results.append(getattr(obj, func)(**input_data_all))
+        execution_block = None
+        for k, v in input_data_all.items():
+            for input in v:
+                if isinstance(input, ExecutionBlocker):
+                    execution_block = execution_block_cb(input) if execution_block_cb is not None else input
+                    break
+
+        if execution_block is None:
+            if pre_execute_cb is not None:
+                pre_execute_cb(0)
+            results.append(getattr(obj, func)(**input_data_all))
+        else:
+            results.append(execution_block)
     elif max_len_input == 0:
         if allow_interrupt:
             nodes.before_node_execution()
         results.append(getattr(obj, func)())
-    else:
+    else:
         for i in range(max_len_input):
             if allow_interrupt:
                 nodes.before_node_execution()
-            results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
+            input_dict = slice_dict(input_data_all, i)
+            execution_block = None
+            for k, v in input_dict.items():
+                if isinstance(v, ExecutionBlocker):
+                    execution_block = execution_block_cb(v) if execution_block_cb is not None else v
+                    break
+            if execution_block is None:
+                if pre_execute_cb is not None:
+                    pre_execute_cb(i)
+                results.append(getattr(obj, func)(**input_dict))
+            else:
+                results.append(execution_block)
     return results
 
-def get_output_data(obj, input_data_all):
+def merge_result_data(results, obj):
+    # check which outputs need concatenating
+    output = []
+    output_is_list = [False] * len(results[0])
+    if hasattr(obj, "OUTPUT_IS_LIST"):
+        output_is_list = obj.OUTPUT_IS_LIST
+
+    # merge node execution results
+    for i, is_list in zip(range(len(results[0])), output_is_list):
+        if is_list:
+            output.append([x for o in results for x in o[i]])
+        else:
+            output.append([o[i] for o in results])
+    return output
+
+def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
     results = []
     uis = []
-    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, 
allow_interrupt=True) - - for r in return_values: + subgraph_results = [] + return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + has_subgraph = False + for i in range(len(return_values)): + r = return_values[i] if isinstance(r, dict): if 'ui' in r: uis.append(r['ui']) - if 'result' in r: - results.append(r['result']) + if 'expand' in r: + # Perform an expansion, but do not append results + has_subgraph = True + new_graph = r['expand'] + result = r.get("result", None) + if isinstance(result, ExecutionBlocker): + result = tuple([result] * len(obj.RETURN_TYPES)) + subgraph_results.append((new_graph, result)) + elif 'result' in r: + result = r.get("result", None) + if isinstance(result, ExecutionBlocker): + result = tuple([result] * len(obj.RETURN_TYPES)) + results.append(result) + subgraph_results.append((None, result)) else: + if isinstance(r, ExecutionBlocker): + r = tuple([r] * len(obj.RETURN_TYPES)) results.append(r) - output = [] - if len(results) > 0: - # check which outputs need concatenating - output_is_list = [False] * len(results[0]) - if hasattr(obj, "OUTPUT_IS_LIST"): - output_is_list = obj.OUTPUT_IS_LIST - - # merge node execution results - for i, is_list in zip(range(len(results[0])), output_is_list): - if is_list: - output.append([x for o in results for x in o[i]]) - else: - output.append([o[i] for o in results]) - + if has_subgraph: + output = subgraph_results + elif len(results) > 0: + output = merge_result_data(results, obj) + else: + output = [] ui = dict() if len(uis) > 0: ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} - return output, ui + return output, ui, has_subgraph def format_value(x): if x is None: @@ -117,53 +232,144 @@ def format_value(x): else: return str(x) -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage): +def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results): unique_id = current_item - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] + real_node_id = dynprompt.get_real_node_id(unique_id) + display_node_id = dynprompt.get_display_node_id(unique_id) + parent_node_id = dynprompt.get_parent_node_id(unique_id) + inputs = dynprompt.get_node(unique_id)['inputs'] + class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if unique_id in outputs: - return (True, None, None) - - for x in inputs: - input_data = inputs[x] - - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage) - if result[0] is not True: - # Another node failed further upstream - return result + if caches.outputs.get(unique_id) is not None: + if server.client_id is not None: + cached_output = caches.ui.get(unique_id) or {} + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) + return (ExecutionResult.SUCCESS, None, None) input_data_all = None try: - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) - if server.client_id is not None: - server.last_node_id = unique_id - 
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) + if unique_id in pending_subgraph_results: + cached_results = pending_subgraph_results[unique_id] + resolved_outputs = [] + for is_subgraph, result in cached_results: + if not is_subgraph: + resolved_outputs.append(result) + else: + resolved_output = [] + for r in result: + if is_link(r): + source_node, source_output = r[0], r[1] + node_output = caches.outputs.get(source_node)[source_output] + for o in node_output: + resolved_output.append(o) - obj = object_storage.get((unique_id, class_type), None) - if obj is None: - obj = class_def() - object_storage[(unique_id, class_type)] = obj - - output_data, output_ui = get_output_data(obj, input_data_all) - outputs[unique_id] = output_data - if len(output_ui) > 0: - outputs_ui[unique_id] = output_ui + else: + resolved_output.append(r) + resolved_outputs.append(tuple(resolved_output)) + output_data = merge_result_data(resolved_outputs, class_def) + output_ui = [] + has_subgraph = False + else: + input_data_all = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt.original_prompt, dynprompt, extra_data) if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.last_node_id = display_node_id + server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) + + obj = caches.objects.get(unique_id) + if obj is None: + obj = class_def() + caches.objects.set(unique_id, obj) + + if hasattr(obj, "check_lazy_status"): + required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True) + required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) + required_inputs = [x for x in required_inputs if isinstance(x,str) and x not in input_data_all] + if len(required_inputs) > 0: + for i in required_inputs: + execution_list.make_input_strong_link(unique_id, i) + return (ExecutionResult.SLEEPING, None, None) + + def execution_block_cb(block): + if block.message is not None: + mes = { + "prompt_id": prompt_id, + "node_id": unique_id, + "node_type": class_type, + "executed": list(executed), + + "exception_message": "Execution Blocked: %s" % block.message, + "exception_type": "ExecutionBlocked", + "traceback": [], + "current_inputs": [], + "current_outputs": [], + } + server.send_sync("execution_error", mes, server.client_id) + return ExecutionBlocker(None) + else: + return block + def pre_execute_cb(call_index): + GraphBuilder.set_default_prefix(unique_id, call_index, 0) + output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + if len(output_ui) > 0: + caches.ui.set(unique_id, { + "meta": { + "node_id": unique_id, + "display_node": display_node_id, + "parent_node": parent_node_id, + "real_node_id": real_node_id, + }, + "output": output_ui + }) + if server.client_id is not None: + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + if has_subgraph: + cached_outputs = [] + new_node_ids = [] + new_output_ids = [] + new_output_links = [] + for i in range(len(output_data)): + new_graph, node_outputs = output_data[i] + if new_graph is None: + cached_outputs.append((False, node_outputs)) + else: + # Check for conflicts + for node_id in new_graph.keys(): + 
if dynprompt.get_node(node_id) is not None:
+                            raise Exception("Attempt to add duplicate node %s" % node_id)
+                    for node_id, node_info in new_graph.items():
+                        new_node_ids.append(node_id)
+                        display_id = node_info.get("override_display_id", unique_id)
+                        dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
+                        # Figure out if the newly created node is an output node
+                        class_type = node_info["class_type"]
+                        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+                        if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
+                            new_output_ids.append(node_id)
+                    for i in range(len(node_outputs)):
+                        if is_link(node_outputs[i]):
+                            from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
+                            new_output_links.append((from_node_id, from_socket))
+                    cached_outputs.append((True, node_outputs))
+            new_node_ids = set(new_node_ids)
+            for cache in caches.all:
+                cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
+            for node_id in new_output_ids:
+                execution_list.add_node(node_id)
+            for link in new_output_links:
+                execution_list.add_strong_link(link[0], link[1], unique_id)
+            pending_subgraph_results[unique_id] = cached_outputs
+            return (ExecutionResult.SLEEPING, None, None)
+        caches.outputs.set(unique_id, output_data)
     except comfy.model_management.InterruptProcessingException as iex:
         logging.info("Processing interrupted")
 
         # skip formatting inputs/outputs
         error_details = {
-            "node_id": unique_id,
+            "node_id": real_node_id,
         }
 
-        return (False, error_details, iex)
+        return (ExecutionResult.FAILURE, error_details, iex)
     except Exception as ex:
         typ, _, tb = sys.exc_info()
         exception_type = full_type_name(typ)
@@ -173,109 +379,32 @@
         for name, inputs in input_data_all.items():
             input_data_formatted[name] = [format_value(x) for x in inputs]
 
-        output_data_formatted = {}
-        for node_id, node_outputs in outputs.items():
-            output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
-
         logging.error("!!! 
Exception during processing !!!") logging.error(traceback.format_exc()) error_details = { - "node_id": unique_id, + "node_id": real_node_id, "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), - "current_inputs": input_data_formatted, - "current_outputs": output_data_formatted + "current_inputs": input_data_formatted } - return (False, error_details, ex) + return (ExecutionResult.FAILURE, error_details, ex) executed.add(unique_id) - return (True, None, None) - -def recursive_will_execute(prompt, outputs, current_item): - unique_id = current_item - inputs = prompt[unique_id]['inputs'] - will_execute = [] - if unique_id in outputs: - return [] - - for x in inputs: - input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - will_execute += recursive_will_execute(prompt, outputs, input_unique_id) - - return will_execute + [unique_id] - -def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item): - unique_id = current_item - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - - is_changed_old = '' - is_changed = '' - to_delete = False - if hasattr(class_def, 'IS_CHANGED'): - if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: - is_changed_old = old_prompt[unique_id]['is_changed'] - if 'is_changed' not in prompt[unique_id]: - input_data_all = get_input_data(inputs, class_def, unique_id, outputs) - if input_data_all is not None: - try: - #is_changed = class_def.IS_CHANGED(**input_data_all) - is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") - prompt[unique_id]['is_changed'] = is_changed - except: - to_delete = True - else: - is_changed = prompt[unique_id]['is_changed'] - - if unique_id not in outputs: - return True - - if not to_delete: - if is_changed != is_changed_old: - to_delete = True - elif unique_id not in old_prompt: - to_delete = True - elif inputs == old_prompt[unique_id]['inputs']: - for x in inputs: - input_data = inputs[x] - - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id in outputs: - to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) - else: - to_delete = True - if to_delete: - break - else: - to_delete = True - - if to_delete: - d = outputs.pop(unique_id) - del d - return to_delete + return (ExecutionResult.SUCCESS, None, None) class PromptExecutor: - def __init__(self, server): + def __init__(self, server, lru_size=None): + self.lru_size = lru_size self.server = server self.reset() def reset(self): - self.outputs = {} - self.object_storage = {} - self.outputs_ui = {} + self.caches = CacheSet(self.lru_size) self.status_messages = [] self.success = True - self.old_prompt = {} def add_message(self, event, data, broadcast: bool): self.status_messages.append((event, data)) @@ -302,7 +431,6 @@ class PromptExecutor: "node_id": node_id, "node_type": class_type, "executed": list(executed), - "exception_message": error["exception_message"], "exception_type": error["exception_type"], "traceback": error["traceback"], @@ -311,18 +439,6 @@ class PromptExecutor: } self.add_message("execution_error", mes, broadcast=False) - # Next, remove the subsequent outputs since they will not be executed - to_delete = [] - for o in self.outputs: - if (o not in current_outputs) and (o not in 
executed): - to_delete += [o] - if o in self.old_prompt: - d = self.old_prompt.pop(o) - del d - for o in to_delete: - d = self.outputs.pop(o) - del d - def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) @@ -335,61 +451,45 @@ class PromptExecutor: self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): - #delete cached outputs if nodes don't exist for them - to_delete = [] - for o in self.outputs: - if o not in prompt: - to_delete += [o] - for o in to_delete: - d = self.outputs.pop(o) - del d - to_delete = [] - for o in self.object_storage: - if o[0] not in prompt: - to_delete += [o] - else: - p = prompt[o[0]] - if o[1] != p['class_type']: - to_delete += [o] - for o in to_delete: - d = self.object_storage.pop(o) - del d + dynamic_prompt = DynamicPrompt(prompt) + is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs) + for cache in self.caches.all: + cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + cache.clean_unused() - for x in prompt: - recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) - - current_outputs = set(self.outputs.keys()) - for x in list(self.outputs_ui.keys()): - if x not in current_outputs: - d = self.outputs_ui.pop(x) - del d + current_outputs = self.caches.outputs.all_node_ids() comfy.model_management.cleanup_models() self.add_message("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, broadcast=False) + pending_subgraph_results = {} executed = set() - output_node_id = None - to_execute = [] - + execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) for node_id in list(execute_outputs): - to_execute += [(0, node_id)] + execution_list.add_node(node_id) - while len(to_execute) > 0: - #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) - output_node_id = to_execute.pop(0)[-1] - - # This call shouldn't raise anything if there's an error deep in - # the actual SD code, instead it will report the node where the - # error was raised - self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) - if self.success is not True: - self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) + while not execution_list.is_empty(): + node_id = execution_list.stage_node_execution() + result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) + if result == ExecutionResult.FAILURE: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) break + elif result == ExecutionResult.SLEEPING: + execution_list.unstage_node_execution() + else: # result == ExecutionResult.SUCCESS: + execution_list.complete_node_execution() - for x in executed: - self.old_prompt[x] = copy.deepcopy(prompt[x]) + ui_outputs = {} + meta_outputs = {} + for ui_info in self.caches.ui.all_active_values(): + node_id = ui_info["meta"]["node_id"] + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] + self.history_result = { + "outputs": ui_outputs, + "meta": meta_outputs, + } self.server.last_node_id = None if comfy.model_management.DISABLE_SMART_MEMORY: 
comfy.model_management.unload_all_models() @@ -406,7 +506,7 @@ def validate_inputs(prompt, item, validated): obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] class_inputs = obj_class.INPUT_TYPES() - required_inputs = class_inputs['required'] + valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) errors = [] valid = True @@ -415,22 +515,23 @@ def validate_inputs(prompt, item, validated): if hasattr(obj_class, "VALIDATE_INPUTS"): validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args - for x in required_inputs: + for x in valid_inputs: + type_input, input_category, extra_info = get_input_info(obj_class, x) if x not in inputs: - error = { - "type": "required_input_missing", - "message": "Required input is missing", - "details": f"{x}", - "extra_info": { - "input_name": x + if input_category == "required": + error = { + "type": "required_input_missing", + "message": "Required input is missing", + "details": f"{x}", + "extra_info": { + "input_name": x + } } - } - errors.append(error) + errors.append(error) continue val = inputs[x] - info = required_inputs[x] - type_input = info[0] + info = (type_input, extra_info) if isinstance(val, list): if len(val) != 2: error = { @@ -501,6 +602,9 @@ def validate_inputs(prompt, item, validated): if type_input == "STRING": val = str(val) inputs[x] = val + if type_input == "BOOLEAN": + val = bool(val) + inputs[x] = val except Exception as ex: error = { "type": "invalid_input_type", @@ -516,33 +620,32 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if len(info) > 1: - if "min" in info[1] and val < info[1]["min"]: - error = { - "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } + if "min" in extra_info and val < extra_info["min"]: + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, extra_info["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, } - errors.append(error) - continue - if "max" in info[1] and val > info[1]["max"]: - error = { - "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } + } + errors.append(error) + continue + if "max" in extra_info and val > extra_info["max"]: + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, extra_info["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, } - errors.append(error) - continue + } + errors.append(error) + continue if x not in validate_function_inputs: if isinstance(type_input, list): @@ -582,7 +685,7 @@ def validate_inputs(prompt, item, validated): ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") for x in input_filtered: for i, r in enumerate(ret): - if r is not True: + if r is not True and not isinstance(r, ExecutionBlocker): details = f"{x}" if r is not False: details += f" - {str(r)}" @@ -741,7 +844,7 @@ class PromptQueue: completed: bool messages: List[str] - def task_done(self, item_id, outputs, + def task_done(self, item_id, history_result, status: Optional['PromptQueue.ExecutionStatus']): with self.mutex: prompt = 
self.currently_running.pop(item_id) @@ -754,9 +857,10 @@ class PromptQueue: self.history[prompt[1]] = { "prompt": prompt, - "outputs": copy.deepcopy(outputs), + "outputs": {}, 'status': status_dict, } + self.history[prompt[1]].update(history_result) self.server.queue_updated() def get_current_queue(self): diff --git a/main.py b/main.py index 69d9bce6c..8cd869e48 100644 --- a/main.py +++ b/main.py @@ -91,7 +91,7 @@ def cuda_malloc_warning(): print("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") def prompt_worker(q, server): - e = execution.PromptExecutor(server) + e = execution.PromptExecutor(server, lru_size=args.cache_lru) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 @@ -111,7 +111,7 @@ def prompt_worker(q, server): e.execute(item[2], prompt_id, item[3], item[4]) need_gc = True q.task_done(item_id, - e.outputs_ui, + e.history_result, status=execution.PromptQueue.ExecutionStatus( status_str='success' if e.success else 'error', completed=e.success, diff --git a/server.py b/server.py index dca06f6fc..8f2896b1b 100644 --- a/server.py +++ b/server.py @@ -396,6 +396,7 @@ class PromptServer(): obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] info = {} info['input'] = obj_class.INPUT_TYPES() + info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} info['output'] = obj_class.RETURN_TYPES info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] @@ -632,6 +633,9 @@ class PromptServer(): site = web.TCPSite(runner, address, port) await site.start() + self.address = address + self.port = port + if verbose: print("Starting server\n") print("To see the GUI go to: http://{}:{}".format(address, port)) diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js index e6ebedd91..15b784d67 100644 --- a/tests-ui/tests/groupNode.test.js +++ b/tests-ui/tests/groupNode.test.js @@ -443,6 +443,7 @@ describe("group node", () => { new CustomEvent("executed", { detail: { node: `${nodes.save.id}`, + display_node: `${nodes.save.id}`, output: { images: [ { @@ -483,6 +484,7 @@ describe("group node", () => { new CustomEvent("executed", { detail: { node: `${group.id}:5`, + display_node: `${group.id}:5`, output: { images: [ { diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 0f041fcd2..b78d33aac 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -956,8 +956,8 @@ export class GroupNodeHandler { const executed = handleEvent.call( this, "executed", - (d) => d?.node, - (d, id, node) => ({ ...d, node: id, merge: !node.resetExecution }) + (d) => d?.display_node, + (d, id, node) => ({ ...d, node: id, display_node: id, merge: !node.resetExecution }) ); const onRemoved = node.onRemoved; diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 3f1c1f8c1..f89c731e6 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -3,7 +3,7 @@ import { app } from "../../scripts/app.js"; import { applyTextReplacements } from "../../scripts/utils.js"; const CONVERTED_TYPE = "converted-widget"; -const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; +const VALID_TYPES = ["STRING", "combo", "number", "toggle", "BOOLEAN"]; const CONFIG = Symbol(); 
const GET_CONFIG = Symbol(); const TARGET = Symbol(); // Used for reroutes to specify the real target widget diff --git a/web/scripts/api.js b/web/scripts/api.js index 3a9bcc87a..ae3fbd13a 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -126,7 +126,7 @@ class ComfyApi extends EventTarget { this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); break; case "executing": - this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node })); + this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.display_node })); break; case "executed": this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); diff --git a/web/scripts/app.js b/web/scripts/app.js index 6df393ba6..d16878454 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1255,7 +1255,7 @@ export class ComfyApp { }); api.addEventListener("executed", ({ detail }) => { - const output = this.nodeOutputs[detail.node]; + const output = this.nodeOutputs[detail.display_node]; if (detail.merge && output) { for (const k in detail.output ?? {}) { const v = output[k]; @@ -1266,9 +1266,9 @@ export class ComfyApp { } } } else { - this.nodeOutputs[detail.node] = detail.output; + this.nodeOutputs[detail.display_node] = detail.output; } - const node = this.graph.getNodeById(detail.node); + const node = this.graph.getNodeById(detail.display_node); if (node) { if (node.onExecuted) node.onExecuted(detail.output); diff --git a/web/scripts/ui.js b/web/scripts/ui.js index d4835c6e4..d69434993 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -227,7 +227,14 @@ class ComfyList { onclick: async () => { await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow); if (item.outputs) { - app.nodeOutputs = item.outputs; + app.nodeOutputs = {}; + for (const [key, value] of Object.entries(item.outputs)) { + if (item.meta && item.meta[key] && item.meta[key].display_node) { + app.nodeOutputs[item.meta[key].display_node] = value; + } else { + app.nodeOutputs[key] = value; + } + } } }, }),