Code cleanup

While implementing caching for components, I did some cleanup. Despite
the fact that subgraph caching is put on hold for now (in favor of a
larger cache refactor later), these are the changes that I think are
worth keeping anyway.

* Makes subgraph node IDs deterministic
* Allows usage of the topological sort without execution
* Tracks parent nodes (i.e. those that caused a node to be created) and
  display nodes (i.e. the one we want to highlight while an ephemeral
  node is executing) separately.
This commit is contained in:
Jacob Segal 2023-07-30 22:10:28 -07:00
parent dbb5a3122a
commit a86d383ff3
4 changed files with 98 additions and 42 deletions

View File

@ -14,16 +14,37 @@ def is_link(obj):
# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
class GraphBuilder:
def __init__(self, prefix = True):
if isinstance(prefix, str):
self.prefix = prefix
elif prefix:
self.prefix = "%d.%d." % (random.randint(0, 0xffffffffffffffff), random.randint(0, 0xffffffffffffffff))
_default_prefix_root = ""
_default_prefix_call_index = 0
_default_prefix_graph_index = 0
def __init__(self, prefix = None):
if prefix is None:
self.prefix = GraphBuilder.alloc_prefix()
else:
self.prefix = ""
self.prefix = prefix
self.nodes = {}
self.id_gen = 1
@classmethod
def set_default_prefix(cls, prefix_root, call_index, graph_index = 0):
cls._default_prefix_root = prefix_root
cls._default_prefix_call_index = call_index
if graph_index is not None:
cls._default_prefix_graph_index = graph_index
@classmethod
def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
if root is None:
root = GraphBuilder._default_prefix_root
if call_index is None:
call_index = GraphBuilder._default_prefix_call_index
if graph_index is None:
graph_index = GraphBuilder._default_prefix_graph_index
result = "%s.%d.%d." % (root, call_index, graph_index)
GraphBuilder._default_prefix_graph_index += 1
return result
def node(self, class_type, id=None, **kwargs):
if id is None:
id = str(self.id_gen)
@ -73,7 +94,7 @@ class Node:
self.id = id
self.class_type = class_type
self.inputs = inputs
self.override_parent_id = None
self.override_display_id = None
def out(self, index):
return [self.id, index]
@ -88,16 +109,16 @@ class Node:
def get_input(self, key):
return self.inputs.get(key)
def set_override_parent_id(self, override_parent_id):
self.override_parent_id = override_parent_id
def set_override_display_id(self, override_display_id):
self.override_display_id = override_display_id
def serialize(self):
serialized = {
"class_type": self.class_type,
"inputs": self.inputs
}
if self.override_parent_id is not None:
serialized["override_parent_id"] = self.override_parent_id
if self.override_display_id is not None:
serialized["override_display_id"] = self.override_display_id
return serialized
def add_graph_prefix(graph, outputs, prefix):

View File

@ -3,6 +3,7 @@ import shutil
import folder_paths
import json
import copy
import comfy.graph_utils
comfy_path = os.path.dirname(folder_paths.__file__)
js_path = os.path.join(comfy_path, "web", "extensions")
@ -192,8 +193,10 @@ def LoadComponent(component_file):
for input_node in component_inputs:
if input_node["name"] in kwargs:
new_graph[input_node["node_id"]]["inputs"]["default_value"] = kwargs[input_node["name"]]
outputs = tuple([[node["node_id"], 0] for node in component_outputs])
new_graph, outputs = comfy.graph_utils.add_graph_prefix(new_graph, outputs, comfy.graph_utils.GraphBuilder.alloc_prefix())
return {
"result": tuple([[node["node_id"], 0] for node in component_outputs]),
"result": outputs,
"expand": new_graph,
}
ComponentNode.__name__ = component_raw_name

View File

@ -99,14 +99,16 @@ class WhileLoopClose:
contained[unique_id] = True
contained[open_node] = True
# We'll use the default prefix, but to avoid having node names grow exponentially in size,
# we'll use "Recurse" for the name of the recursively-generated copy of this node.
graph = GraphBuilder()
for node_id in contained:
original_node = dynprompt.get_node(node_id)
node = graph.node(original_node["class_type"], node_id)
node.set_override_parent_id(node_id)
node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id)
node.set_override_display_id(node_id)
for node_id in contained:
original_node = dynprompt.get_node(node_id)
node = graph.lookup_node(node_id)
node = graph.lookup_node("Recurse" if node_id == unique_id else node_id)
for k, v in original_node["inputs"].items():
if is_link(v) and v[0] in contained:
parent = graph.lookup_node(v[0])
@ -117,7 +119,7 @@ class WhileLoopClose:
for i in range(NUM_FLOW_SOCKETS):
key = "initial_value%d" % i
new_open.set_input(key, kwargs.get(key, None))
my_clone = graph.lookup_node(unique_id)
my_clone = graph.lookup_node("Recurse" )
result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS))
return {
"result": tuple(result),

View File

@ -14,7 +14,7 @@ import nodes
import comfy.model_management
import comfy.graph_utils
from comfy.graph_utils import is_link, ExecutionBlocker
from comfy.graph_utils import is_link, ExecutionBlocker, GraphBuilder
class ExecutionResult(Enum):
SUCCESS = 0
@ -45,11 +45,9 @@ def get_input_info(class_def, input_name):
# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
# it can still be returned to the graph after having further dependencies added.
class ExecutionList:
def __init__(self, dynprompt, outputs):
class TopologicalSort:
def __init__(self, dynprompt):
self.dynprompt = dynprompt
self.outputs = outputs
self.staged_node_id = None
self.pendingNodes = {}
self.blockCount = {} # Number of nodes this node is directly blocked by
self.blocking = {} # Which nodes are blocked by this node
@ -70,9 +68,6 @@ class ExecutionList:
self.add_strong_link(from_node_id, from_socket, to_node_id)
def add_strong_link(self, from_node_id, from_socket, to_node_id):
if from_node_id in self.outputs:
# Nothing to do
return
self.add_node(from_node_id)
if to_node_id not in self.blocking[from_node_id]:
self.blocking[from_node_id][to_node_id] = {}
@ -95,11 +90,35 @@ class ExecutionList:
if "lazy" not in input_info or not input_info["lazy"]:
self.add_strong_link(from_node_id, from_socket, unique_id)
def get_ready_nodes(self):
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
def pop_node(self, unique_id):
del self.pendingNodes[unique_id]
for blocked_node_id in self.blocking[unique_id]:
self.blockCount[blocked_node_id] -= 1
del self.blocking[unique_id]
def is_empty(self):
return len(self.pendingNodes) == 0
class ExecutionList(TopologicalSort):
def __init__(self, dynprompt, outputs):
super().__init__(dynprompt)
self.outputs = outputs
self.staged_node_id = None
def add_strong_link(self, from_node_id, from_socket, to_node_id):
if from_node_id in self.outputs:
# Nothing to do
return
super().add_strong_link(from_node_id, from_socket, to_node_id)
def stage_node_execution(self):
assert self.staged_node_id is None
if self.is_empty():
return None
available = [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
available = self.get_ready_nodes()
if len(available) == 0:
raise Exception("Dependency cycle detected")
next_node = available[0]
@ -122,14 +141,9 @@ class ExecutionList:
def complete_node_execution(self):
node_id = self.staged_node_id
del self.pendingNodes[node_id]
for blocked_node_id in self.blocking[node_id]:
self.blockCount[blocked_node_id] -= 1
del self.blocking[node_id]
self.pop_node(node_id)
self.staged_node_id = None
def is_empty(self):
return len(self.pendingNodes) == 0
class DynamicPrompt:
def __init__(self, original_prompt):
@ -138,6 +152,7 @@ class DynamicPrompt:
# Any extra pieces of the graph created during execution
self.ephemeral_prompt = {}
self.ephemeral_parents = {}
self.ephemeral_display = {}
def get_node(self, node_id):
if node_id in self.ephemeral_prompt:
@ -146,7 +161,7 @@ class DynamicPrompt:
return self.original_prompt[node_id]
return None
def add_ephemeral_node(self, parent_id, node_id, node_info):
def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
self.ephemeral_prompt[node_id] = node_info
self.ephemeral_parents[node_id] = parent_id
@ -155,6 +170,14 @@ class DynamicPrompt:
node_id = self.ephemeral_parents[node_id]
return node_id
def get_parent_node_id(self, node_id):
return self.ephemeral_parents.get(node_id, None)
def get_display_node_id(self, node_id):
while node_id in self.ephemeral_display:
node_id = self.ephemeral_display[node_id]
return node_id
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynprompt=None, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
@ -185,7 +208,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp
input_data_all[x] = [unique_id]
return input_data_all
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None):
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# check if node wants the lists
intput_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
@ -212,6 +235,8 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, executi
break
if execution_block is None:
if pre_execute_cb is not None:
pre_execute_cb(0)
results.append(getattr(obj, func)(**input_data_all))
else:
results.append(execution_block)
@ -226,6 +251,8 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, executi
execution_block = execution_block_cb(v) if execution_block_cb is not None else v
break
if execution_block is None:
if pre_execute_cb is not None:
pre_execute_cb(i)
results.append(getattr(obj, func)(**input_dict))
else:
results.append(execution_block)
@ -246,12 +273,12 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
def get_output_data(obj, input_data_all, execution_block_cb=None):
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
results = []
uis = []
subgraph_results = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb)
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
has_subgraph = False
for i in range(len(return_values)):
r = return_values[i]
@ -299,6 +326,7 @@ def format_value(x):
def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage, execution_list, pending_subgraph_results):
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
@ -331,8 +359,8 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
else:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, dynprompt.original_prompt, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = real_node_id
server.send_sync("executing", { "node": real_node_id, "prompt_id": prompt_id }, server.client_id)
server.last_node_id = display_node_id
server.send_sync("executing", { "node": display_node_id, "prompt_id": prompt_id }, server.client_id)
obj = object_storage.get((unique_id, class_type), None)
if obj is None:
@ -366,11 +394,13 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
return ExecutionBlocker(None)
else:
return block
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb)
def pre_execute_cb(call_index):
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
if len(output_ui) > 0:
outputs_ui[unique_id] = output_ui
if server.client_id is not None:
server.send_sync("executed", { "node": real_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
server.send_sync("executed", { "node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph:
cached_outputs = []
for i in range(len(output_data)):
@ -381,12 +411,12 @@ def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data,
# Check for conflicts
for node_id in new_graph.keys():
if dynprompt.get_node(node_id) is not None:
new_graph, node_outputs = comfy.graph_utils.add_graph_prefix(new_graph, node_outputs, "%s.%d." % (unique_id, i))
raise Exception("Attempt to add duplicate node %s" % node_id)
break
new_output_ids = []
for node_id, node_info in new_graph.items():
parent_id = node_info.get("override_parent_id", real_node_id)
dynprompt.add_ephemeral_node(parent_id, node_id, node_info)
display_id = node_info.get("override_display_id", unique_id)
dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
# Figure out if the newly created node is an output node
class_type = node_info["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]