Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-02-10 13:32:36 +08:00
Code cleanup
While implementing caching for components, I did some cleanup. Despite the fact that subgraph caching is put on hold for now (in favor of a larger cache refactor later), these are the changes that I think are worth keeping anyway.

* Makes subgraph node IDs deterministic
* Allows usage of the topological sort without execution
* Tracks parent nodes (i.e. those that caused a node to be created) and display nodes (i.e. the one we want to highlight while an ephemeral node is executing) separately.
parent dbb5a3122a
commit a86d383ff3
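
As a rough illustration of the "deterministic subgraph node IDs" point above (purely illustrative, not part of the commit): the prefix used for expanded nodes now comes from the expanding node's own ID plus two counters rather than from random.randint. A minimal sketch, assuming the comfy.graph_utils module from this tree is importable and using a made-up node ID "7":

# Illustrative only: exercises set_default_prefix / alloc_prefix exactly as added in the diff below.
from comfy.graph_utils import GraphBuilder

GraphBuilder.set_default_prefix("7", 0)   # execution.py does this in pre_execute_cb(call_index)
print(GraphBuilder.alloc_prefix())        # "7.0.0."
print(GraphBuilder.alloc_prefix())        # "7.0.1."  (graph_index auto-increments)

graph = GraphBuilder()                    # no explicit prefix -> picks up the next allocated one
print(graph.prefix)                       # "7.0.2."
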
@@ -14,16 +14,37 @@ def is_link(obj):
 
 # The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
 class GraphBuilder:
-    def __init__(self, prefix = True):
-        if isinstance(prefix, str):
-            self.prefix = prefix
-        elif prefix:
-            self.prefix = "%d.%d." % (random.randint(0, 0xffffffffffffffff), random.randint(0, 0xffffffffffffffff))
+    _default_prefix_root = ""
+    _default_prefix_call_index = 0
+    _default_prefix_graph_index = 0
+
+    def __init__(self, prefix = None):
+        if prefix is None:
+            self.prefix = GraphBuilder.alloc_prefix()
         else:
-            self.prefix = ""
+            self.prefix = prefix
         self.nodes = {}
         self.id_gen = 1
 
+    @classmethod
+    def set_default_prefix(cls, prefix_root, call_index, graph_index = 0):
+        cls._default_prefix_root = prefix_root
+        cls._default_prefix_call_index = call_index
+        if graph_index is not None:
+            cls._default_prefix_graph_index = graph_index
+
+    @classmethod
+    def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
+        if root is None:
+            root = GraphBuilder._default_prefix_root
+        if call_index is None:
+            call_index = GraphBuilder._default_prefix_call_index
+        if graph_index is None:
+            graph_index = GraphBuilder._default_prefix_graph_index
+        result = "%s.%d.%d." % (root, call_index, graph_index)
+        GraphBuilder._default_prefix_graph_index += 1
+        return result
+
     def node(self, class_type, id=None, **kwargs):
         if id is None:
             id = str(self.id_gen)
@@ -73,7 +94,7 @@ class Node:
         self.id = id
         self.class_type = class_type
         self.inputs = inputs
-        self.override_parent_id = None
+        self.override_display_id = None
 
     def out(self, index):
         return [self.id, index]
@@ -88,16 +109,16 @@ class Node:
     def get_input(self, key):
         return self.inputs.get(key)
 
-    def set_override_parent_id(self, override_parent_id):
-        self.override_parent_id = override_parent_id
+    def set_override_display_id(self, override_display_id):
+        self.override_display_id = override_display_id
 
     def serialize(self):
         serialized = {
             "class_type": self.class_type,
             "inputs": self.inputs
         }
-        if self.override_parent_id is not None:
-            serialized["override_parent_id"] = self.override_parent_id
+        if self.override_display_id is not None:
+            serialized["override_display_id"] = self.override_display_id
         return serialized
 
 def add_graph_prefix(graph, outputs, prefix):
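
To make the GraphBuilder/Node API above concrete, here is a small usage sketch. It only uses calls that appear in this diff; the import assumes a ComfyUI checkout, and the class_type names are made up:

from comfy.graph_utils import GraphBuilder

g = GraphBuilder()                                            # prefix comes from alloc_prefix()
loader = g.node("ExampleLoader", id="load", path="input.png") # hypothetical node class
blur = g.node("ExampleBlur", id="blur", image=loader.out(0))  # out(0) -> [loader.id, 0] link
blur.set_override_display_id(loader.id)                       # UI highlights the loader instead

# Each Node serializes to the {"class_type": ..., "inputs": ...} dict the back-end expects;
# "override_display_id" is included only when it was set.
for node_id, node in g.nodes.items():
    print(node_id, node.serialize())
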
@@ -3,6 +3,7 @@ import shutil
 import folder_paths
 import json
 import copy
+import comfy.graph_utils
 
 comfy_path = os.path.dirname(folder_paths.__file__)
 js_path = os.path.join(comfy_path, "web", "extensions")
@@ -192,8 +193,10 @@ def LoadComponent(component_file):
         for input_node in component_inputs:
             if input_node["name"] in kwargs:
                 new_graph[input_node["node_id"]]["inputs"]["default_value"] = kwargs[input_node["name"]]
+        outputs = tuple([[node["node_id"], 0] for node in component_outputs])
+        new_graph, outputs = comfy.graph_utils.add_graph_prefix(new_graph, outputs, comfy.graph_utils.GraphBuilder.alloc_prefix())
         return {
-            "result": tuple([[node["node_id"], 0] for node in component_outputs]),
+            "result": outputs,
             "expand": new_graph,
         }
     ComponentNode.__name__ = component_raw_name
@@ -99,14 +99,16 @@ class WhileLoopClose:
         contained[unique_id] = True
         contained[open_node] = True
 
+        # We'll use the default prefix, but to avoid having node names grow exponentially in size,
+        # we'll use "Recurse" for the name of the recursively-generated copy of this node.
         graph = GraphBuilder()
         for node_id in contained:
             original_node = dynprompt.get_node(node_id)
-            node = graph.node(original_node["class_type"], node_id)
-            node.set_override_parent_id(node_id)
+            node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id)
+            node.set_override_display_id(node_id)
         for node_id in contained:
             original_node = dynprompt.get_node(node_id)
-            node = graph.lookup_node(node_id)
+            node = graph.lookup_node("Recurse" if node_id == unique_id else node_id)
             for k, v in original_node["inputs"].items():
                 if is_link(v) and v[0] in contained:
                     parent = graph.lookup_node(v[0])
@@ -117,7 +119,7 @@ class WhileLoopClose:
         for i in range(NUM_FLOW_SOCKETS):
             key = "initial_value%d" % i
             new_open.set_input(key, kwargs.get(key, None))
-        my_clone = graph.lookup_node(unique_id)
+        my_clone = graph.lookup_node("Recurse" )
         result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS))
         return {
             "result": tuple(result),
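
A rough worked example of the "Recurse" naming used above (purely illustrative, with a made-up node ID "7" and the call/graph indices fixed at 0): with the deterministic prefixes introduced in this commit the prefix root is the expanding node's own ID, so a copy named after that same node would roughly double the ID length every loop iteration, while the fixed "Recurse" name keeps growth linear and set_override_display_id keeps the UI pointing at the original node.

# Not ComfyUI code - just compares the two naming schemes for the loop-close node's ID.
old_style = new_style = "7"
for _ in range(3):
    old_style = "%s.0.0.%s" % (old_style, old_style)   # copy named after itself: length ~doubles
    new_style = "%s.0.0.%s" % (new_style, "Recurse")   # copy named "Recurse": constant suffix
    print(len(old_style), len(new_style))
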
execution.py (82 changed lines)
@@ -14,7 +14,7 @@ import nodes
 
 import comfy.model_management
 import comfy.graph_utils
-from comfy.graph_utils import is_link, ExecutionBlocker
+from comfy.graph_utils import is_link, ExecutionBlocker, GraphBuilder
 
 class ExecutionResult(Enum):
     SUCCESS = 0
@@ -45,11 +45,9 @@ def get_input_info(class_def, input_name):
 
 # ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
 # it can still be returned to the graph after having further dependencies added.
-class ExecutionList:
-    def __init__(self, dynprompt, outputs):
+class TopologicalSort:
+    def __init__(self, dynprompt):
         self.dynprompt = dynprompt
-        self.outputs = outputs
-        self.staged_node_id = None
         self.pendingNodes = {}
         self.blockCount = {} # Number of nodes this node is directly blocked by
         self.blocking = {} # Which nodes are blocked by this node
@@ -70,9 +68,6 @@ class ExecutionList:
         self.add_strong_link(from_node_id, from_socket, to_node_id)
 
     def add_strong_link(self, from_node_id, from_socket, to_node_id):
-        if from_node_id in self.outputs:
-            # Nothing to do
-            return
         self.add_node(from_node_id)
         if to_node_id not in self.blocking[from_node_id]:
             self.blocking[from_node_id][to_node_id] = {}
@@ -95,11 +90,35 @@
                 if "lazy" not in input_info or not input_info["lazy"]:
                     self.add_strong_link(from_node_id, from_socket, unique_id)
 
+    def get_ready_nodes(self):
+        return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
+
+    def pop_node(self, unique_id):
+        del self.pendingNodes[unique_id]
+        for blocked_node_id in self.blocking[unique_id]:
+            self.blockCount[blocked_node_id] -= 1
+        del self.blocking[unique_id]
+
+    def is_empty(self):
+        return len(self.pendingNodes) == 0
+
+class ExecutionList(TopologicalSort):
+    def __init__(self, dynprompt, outputs):
+        super().__init__(dynprompt)
+        self.outputs = outputs
+        self.staged_node_id = None
+
+    def add_strong_link(self, from_node_id, from_socket, to_node_id):
+        if from_node_id in self.outputs:
+            # Nothing to do
+            return
+        super().add_strong_link(from_node_id, from_socket, to_node_id)
+
     def stage_node_execution(self):
         assert self.staged_node_id is None
         if self.is_empty():
             return None
-        available = [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
+        available = self.get_ready_nodes()
         if len(available) == 0:
             raise Exception("Dependency cycle detected")
         next_node = available[0]
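
The TopologicalSort base class above is what makes "usage of the topological sort without execution" possible: the pendingNodes / blockCount / blocking bookkeeping no longer depends on ExecutionList's staging or its outputs cache. A self-contained toy version of that bookkeeping (plain dicts instead of the real class, which pulls its edges from a DynamicPrompt):

# blockCount[n] = number of nodes n still waits on; blocking[n] = nodes that wait on n.
deps = {"c": ["a", "b"], "b": ["a"], "a": []}    # hypothetical dependency graph

pending = {n: True for n in deps}
block_count = {n: 0 for n in deps}
blocking = {n: {} for n in deps}
for node, parents in deps.items():
    for parent in parents:                        # cf. add_strong_link()
        if node not in blocking[parent]:
            blocking[parent][node] = {}
            block_count[node] += 1

order = []
while pending:
    ready = [n for n in pending if block_count[n] == 0]   # cf. get_ready_nodes()
    assert ready, "Dependency cycle detected"
    node = ready[0]
    order.append(node)
    for blocked in blocking.pop(node):                    # cf. pop_node()
        block_count[blocked] -= 1
    del pending[node]

print(order)   # ['a', 'b', 'c']
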
@@ -122,14 +141,9 @@
 
     def complete_node_execution(self):
         node_id = self.staged_node_id
-        del self.pendingNodes[node_id]
-        for blocked_node_id in self.blocking[node_id]:
-            self.blockCount[blocked_node_id] -= 1
-        del self.blocking[node_id]
+        self.pop_node(node_id)
         self.staged_node_id = None
 
-    def is_empty(self):
-        return len(self.pendingNodes) == 0
 
 class DynamicPrompt:
     def __init__(self, original_prompt):
@@ -138,6 +152,7 @@ class DynamicPrompt:
         # Any extra pieces of the graph created during execution
         self.ephemeral_prompt = {}
         self.ephemeral_parents = {}
+        self.ephemeral_display = {}
 
     def get_node(self, node_id):
         if node_id in self.ephemeral_prompt:
@@ -146,7 +161,7 @@
             return self.original_prompt[node_id]
         return None
 
-    def add_ephemeral_node(self, parent_id, node_id, node_info):
+    def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
         self.ephemeral_prompt[node_id] = node_info
         self.ephemeral_parents[node_id] = parent_id
 
@@ -155,6 +170,14 @@
             node_id = self.ephemeral_parents[node_id]
         return node_id
 
+    def get_parent_node_id(self, node_id):
+        return self.ephemeral_parents.get(node_id, None)
+
+    def get_display_node_id(self, node_id):
+        while node_id in self.ephemeral_display:
+            node_id = self.ephemeral_display[node_id]
+        return node_id
+
 def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynprompt=None, extra_data={}):
     valid_inputs = class_def.INPUT_TYPES()
     input_data_all = {}
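
The two accessors added above separate "who created this ephemeral node" from "which node the UI should highlight": get_parent_node_id returns only the immediate creator, while get_display_node_id (like get_real_node_id) walks its chain back to a node from the original prompt. A dict-only sketch of that chain-walking, with made-up IDs:

# Hypothetical state after node "7" has expanded its subgraph twice.
ephemeral_parents = {
    "7.0.0.Recurse": "7",
    "7.0.0.Recurse.0.0.Recurse": "7.0.0.Recurse",
}
ephemeral_display = dict(ephemeral_parents)   # here the display chain happens to match

def get_parent_node_id(node_id):
    return ephemeral_parents.get(node_id, None)           # immediate creator only

def get_display_node_id(node_id):
    while node_id in ephemeral_display:                   # walk up to an original-prompt node
        node_id = ephemeral_display[node_id]
    return node_id

print(get_parent_node_id("7.0.0.Recurse.0.0.Recurse"))    # 7.0.0.Recurse
print(get_display_node_id("7.0.0.Recurse.0.0.Recurse"))   # 7
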
@@ -185,7 +208,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp
                 input_data_all[x] = [unique_id]
     return input_data_all
 
-def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None):
+def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
     # check if node wants the lists
     intput_is_list = False
     if hasattr(obj, "INPUT_IS_LIST"):
@@ -212,6 +235,8 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, executi
                     break
 
         if execution_block is None:
+            if pre_execute_cb is not None:
+                pre_execute_cb(0)
             results.append(getattr(obj, func)(**input_data_all))
         else:
             results.append(execution_block)
@@ -226,6 +251,8 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, executi
                     execution_block = execution_block_cb(v) if execution_block_cb is not None else v
                     break
             if execution_block is None:
+                if pre_execute_cb is not None:
+                    pre_execute_cb(i)
                 results.append(getattr(obj, func)(**input_dict))
             else:
                 results.append(execution_block)
@@ -246,12 +273,12 @@ def merge_result_data(results, obj):
         output.append([o[i] for o in results])
     return output
 
-def get_output_data(obj, input_data_all, execution_block_cb=None):
+def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
 
     results = []
     uis = []
     subgraph_results = []
-    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb)
+    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
     has_subgraph = False
     for i in range(len(return_values)):
         r = return_values[i]
@@ -299,6 +326,7 @@ def format_value(x):
 def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage, execution_list, pending_subgraph_results):
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
+    display_node_id = dynprompt.get_display_node_id(unique_id)
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
@@ -331,8 +359,8 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
         else:
             input_data_all = get_input_data(inputs, class_def, unique_id, outputs, dynprompt.original_prompt, dynprompt, extra_data)
             if server.client_id is not None:
-                server.last_node_id = real_node_id
-                server.send_sync("executing", { "node": real_node_id, "prompt_id": prompt_id }, server.client_id)
+                server.last_node_id = display_node_id
+                server.send_sync("executing", { "node": display_node_id, "prompt_id": prompt_id }, server.client_id)
 
             obj = object_storage.get((unique_id, class_type), None)
             if obj is None:
@@ -366,11 +394,13 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
                     return ExecutionBlocker(None)
                 else:
                     return block
-            output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb)
+            def pre_execute_cb(call_index):
+                GraphBuilder.set_default_prefix(unique_id, call_index, 0)
+            output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
         if len(output_ui) > 0:
             outputs_ui[unique_id] = output_ui
             if server.client_id is not None:
-                server.send_sync("executed", { "node": real_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
+                server.send_sync("executed", { "node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
         if has_subgraph:
             cached_outputs = []
             for i in range(len(output_data)):
@@ -381,12 +411,12 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
                     # Check for conflicts
                     for node_id in new_graph.keys():
                         if dynprompt.get_node(node_id) is not None:
-                            new_graph, node_outputs = comfy.graph_utils.add_graph_prefix(new_graph, node_outputs, "%s.%d." % (unique_id, i))
+                            raise Exception("Attempt to add duplicate node %s" % node_id)
                             break
                     new_output_ids = []
                     for node_id, node_info in new_graph.items():
-                        parent_id = node_info.get("override_parent_id", real_node_id)
-                        dynprompt.add_ephemeral_node(parent_id, node_id, node_info)
+                        display_id = node_info.get("override_display_id", unique_id)
+                        dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
                         # Figure out if the newly created node is an output node
                         class_type = node_info["class_type"]
                         class_def = nodes.NODE_CLASS_MAPPINGS[class_type]