diff --git a/comfy/graph_utils.py b/comfy/graph_utils.py index 869f6154d..f7aacaa7c 100644 --- a/comfy/graph_utils.py +++ b/comfy/graph_utils.py @@ -14,16 +14,37 @@ def is_link(obj): # The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end class GraphBuilder: - def __init__(self, prefix = True): - if isinstance(prefix, str): - self.prefix = prefix - elif prefix: - self.prefix = "%d.%d." % (random.randint(0, 0xffffffffffffffff), random.randint(0, 0xffffffffffffffff)) + _default_prefix_root = "" + _default_prefix_call_index = 0 + _default_prefix_graph_index = 0 + + def __init__(self, prefix = None): + if prefix is None: + self.prefix = GraphBuilder.alloc_prefix() else: - self.prefix = "" + self.prefix = prefix self.nodes = {} self.id_gen = 1 + @classmethod + def set_default_prefix(cls, prefix_root, call_index, graph_index = 0): + cls._default_prefix_root = prefix_root + cls._default_prefix_call_index = call_index + if graph_index is not None: + cls._default_prefix_graph_index = graph_index + + @classmethod + def alloc_prefix(cls, root=None, call_index=None, graph_index=None): + if root is None: + root = GraphBuilder._default_prefix_root + if call_index is None: + call_index = GraphBuilder._default_prefix_call_index + if graph_index is None: + graph_index = GraphBuilder._default_prefix_graph_index + result = "%s.%d.%d." % (root, call_index, graph_index) + GraphBuilder._default_prefix_graph_index += 1 + return result + def node(self, class_type, id=None, **kwargs): if id is None: id = str(self.id_gen) @@ -73,7 +94,7 @@ class Node: self.id = id self.class_type = class_type self.inputs = inputs - self.override_parent_id = None + self.override_display_id = None def out(self, index): return [self.id, index] @@ -88,16 +109,16 @@ class Node: def get_input(self, key): return self.inputs.get(key) - def set_override_parent_id(self, override_parent_id): - self.override_parent_id = override_parent_id + def set_override_display_id(self, override_display_id): + self.override_display_id = override_display_id def serialize(self): serialized = { "class_type": self.class_type, "inputs": self.inputs } - if self.override_parent_id is not None: - serialized["override_parent_id"] = self.override_parent_id + if self.override_display_id is not None: + serialized["override_display_id"] = self.override_display_id return serialized def add_graph_prefix(graph, outputs, prefix): diff --git a/custom_nodes/execution-inversion-demo-comfyui/components.py b/custom_nodes/execution-inversion-demo-comfyui/components.py index f6d5b904a..fb017db3d 100644 --- a/custom_nodes/execution-inversion-demo-comfyui/components.py +++ b/custom_nodes/execution-inversion-demo-comfyui/components.py @@ -3,6 +3,7 @@ import shutil import folder_paths import json import copy +import comfy.graph_utils comfy_path = os.path.dirname(folder_paths.__file__) js_path = os.path.join(comfy_path, "web", "extensions") @@ -192,8 +193,10 @@ def LoadComponent(component_file): for input_node in component_inputs: if input_node["name"] in kwargs: new_graph[input_node["node_id"]]["inputs"]["default_value"] = kwargs[input_node["name"]] + outputs = tuple([[node["node_id"], 0] for node in component_outputs]) + new_graph, outputs = comfy.graph_utils.add_graph_prefix(new_graph, outputs, comfy.graph_utils.GraphBuilder.alloc_prefix()) return { - "result": tuple([[node["node_id"], 0] for node in component_outputs]), + "result": outputs, "expand": new_graph, } ComponentNode.__name__ = component_raw_name diff --git a/custom_nodes/execution-inversion-demo-comfyui/flow_control.py b/custom_nodes/execution-inversion-demo-comfyui/flow_control.py index 6dc23c810..fd71f88b5 100644 --- a/custom_nodes/execution-inversion-demo-comfyui/flow_control.py +++ b/custom_nodes/execution-inversion-demo-comfyui/flow_control.py @@ -99,14 +99,16 @@ class WhileLoopClose: contained[unique_id] = True contained[open_node] = True + # We'll use the default prefix, but to avoid having node names grow exponentially in size, + # we'll use "Recurse" for the name of the recursively-generated copy of this node. graph = GraphBuilder() for node_id in contained: original_node = dynprompt.get_node(node_id) - node = graph.node(original_node["class_type"], node_id) - node.set_override_parent_id(node_id) + node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id) + node.set_override_display_id(node_id) for node_id in contained: original_node = dynprompt.get_node(node_id) - node = graph.lookup_node(node_id) + node = graph.lookup_node("Recurse" if node_id == unique_id else node_id) for k, v in original_node["inputs"].items(): if is_link(v) and v[0] in contained: parent = graph.lookup_node(v[0]) @@ -117,7 +119,7 @@ class WhileLoopClose: for i in range(NUM_FLOW_SOCKETS): key = "initial_value%d" % i new_open.set_input(key, kwargs.get(key, None)) - my_clone = graph.lookup_node(unique_id) + my_clone = graph.lookup_node("Recurse" ) result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS)) return { "result": tuple(result), diff --git a/execution.py b/execution.py index 9fcd26b97..9d9cf8573 100644 --- a/execution.py +++ b/execution.py @@ -14,7 +14,7 @@ import nodes import comfy.model_management import comfy.graph_utils -from comfy.graph_utils import is_link, ExecutionBlocker +from comfy.graph_utils import is_link, ExecutionBlocker, GraphBuilder class ExecutionResult(Enum): SUCCESS = 0 @@ -45,11 +45,9 @@ def get_input_info(class_def, input_name): # ExecutionList implements a topological dissolve of the graph. After a node is staged for execution, # it can still be returned to the graph after having further dependencies added. -class ExecutionList: - def __init__(self, dynprompt, outputs): +class TopologicalSort: + def __init__(self, dynprompt): self.dynprompt = dynprompt - self.outputs = outputs - self.staged_node_id = None self.pendingNodes = {} self.blockCount = {} # Number of nodes this node is directly blocked by self.blocking = {} # Which nodes are blocked by this node @@ -70,9 +68,6 @@ class ExecutionList: self.add_strong_link(from_node_id, from_socket, to_node_id) def add_strong_link(self, from_node_id, from_socket, to_node_id): - if from_node_id in self.outputs: - # Nothing to do - return self.add_node(from_node_id) if to_node_id not in self.blocking[from_node_id]: self.blocking[from_node_id][to_node_id] = {} @@ -95,11 +90,35 @@ class ExecutionList: if "lazy" not in input_info or not input_info["lazy"]: self.add_strong_link(from_node_id, from_socket, unique_id) + def get_ready_nodes(self): + return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0] + + def pop_node(self, unique_id): + del self.pendingNodes[unique_id] + for blocked_node_id in self.blocking[unique_id]: + self.blockCount[blocked_node_id] -= 1 + del self.blocking[unique_id] + + def is_empty(self): + return len(self.pendingNodes) == 0 + +class ExecutionList(TopologicalSort): + def __init__(self, dynprompt, outputs): + super().__init__(dynprompt) + self.outputs = outputs + self.staged_node_id = None + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + if from_node_id in self.outputs: + # Nothing to do + return + super().add_strong_link(from_node_id, from_socket, to_node_id) + def stage_node_execution(self): assert self.staged_node_id is None if self.is_empty(): return None - available = [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0] + available = self.get_ready_nodes() if len(available) == 0: raise Exception("Dependency cycle detected") next_node = available[0] @@ -122,14 +141,9 @@ class ExecutionList: def complete_node_execution(self): node_id = self.staged_node_id - del self.pendingNodes[node_id] - for blocked_node_id in self.blocking[node_id]: - self.blockCount[blocked_node_id] -= 1 - del self.blocking[node_id] + self.pop_node(node_id) self.staged_node_id = None - def is_empty(self): - return len(self.pendingNodes) == 0 class DynamicPrompt: def __init__(self, original_prompt): @@ -138,6 +152,7 @@ class DynamicPrompt: # Any extra pieces of the graph created during execution self.ephemeral_prompt = {} self.ephemeral_parents = {} + self.ephemeral_display = {} def get_node(self, node_id): if node_id in self.ephemeral_prompt: @@ -146,7 +161,7 @@ class DynamicPrompt: return self.original_prompt[node_id] return None - def add_ephemeral_node(self, parent_id, node_id, node_info): + def add_ephemeral_node(self, node_id, node_info, parent_id, display_id): self.ephemeral_prompt[node_id] = node_info self.ephemeral_parents[node_id] = parent_id @@ -155,6 +170,14 @@ class DynamicPrompt: node_id = self.ephemeral_parents[node_id] return node_id + def get_parent_node_id(self, node_id): + return self.ephemeral_parents.get(node_id, None) + + def get_display_node_id(self, node_id): + while node_id in self.ephemeral_display: + node_id = self.ephemeral_display[node_id] + return node_id + def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynprompt=None, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} @@ -185,7 +208,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp input_data_all[x] = [unique_id] return input_data_all -def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None): +def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): # check if node wants the lists intput_is_list = False if hasattr(obj, "INPUT_IS_LIST"): @@ -212,6 +235,8 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, executi break if execution_block is None: + if pre_execute_cb is not None: + pre_execute_cb(0) results.append(getattr(obj, func)(**input_data_all)) else: results.append(execution_block) @@ -226,6 +251,8 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, executi execution_block = execution_block_cb(v) if execution_block_cb is not None else v break if execution_block is None: + if pre_execute_cb is not None: + pre_execute_cb(i) results.append(getattr(obj, func)(**input_dict)) else: results.append(execution_block) @@ -246,12 +273,12 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -def get_output_data(obj, input_data_all, execution_block_cb=None): +def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): results = [] uis = [] subgraph_results = [] - return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb) + return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) has_subgraph = False for i in range(len(return_values)): r = return_values[i] @@ -299,6 +326,7 @@ def format_value(x): def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage, execution_list, pending_subgraph_results): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) + display_node_id = dynprompt.get_display_node_id(unique_id) inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -331,8 +359,8 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, else: input_data_all = get_input_data(inputs, class_def, unique_id, outputs, dynprompt.original_prompt, dynprompt, extra_data) if server.client_id is not None: - server.last_node_id = real_node_id - server.send_sync("executing", { "node": real_node_id, "prompt_id": prompt_id }, server.client_id) + server.last_node_id = display_node_id + server.send_sync("executing", { "node": display_node_id, "prompt_id": prompt_id }, server.client_id) obj = object_storage.get((unique_id, class_type), None) if obj is None: @@ -366,11 +394,13 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, return ExecutionBlocker(None) else: return block - output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb) + def pre_execute_cb(call_index): + GraphBuilder.set_default_prefix(unique_id, call_index, 0) + output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) if len(output_ui) > 0: outputs_ui[unique_id] = output_ui if server.client_id is not None: - server.send_sync("executed", { "node": real_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executed", { "node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) if has_subgraph: cached_outputs = [] for i in range(len(output_data)): @@ -381,12 +411,12 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, # Check for conflicts for node_id in new_graph.keys(): if dynprompt.get_node(node_id) is not None: - new_graph, node_outputs = comfy.graph_utils.add_graph_prefix(new_graph, node_outputs, "%s.%d." % (unique_id, i)) + raise Exception("Attempt to add duplicate node %s" % node_id) break new_output_ids = [] for node_id, node_info in new_graph.items(): - parent_id = node_info.get("override_parent_id", real_node_id) - dynprompt.add_ephemeral_node(parent_id, node_id, node_info) + display_id = node_info.get("override_display_id", unique_id) + dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id) # Figure out if the newly created node is an output node class_type = node_info["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]