ComfyUI/custom_nodes/execution-inversion-demo-comfyui/components.py
Jacob Segal b234baee2c Add lazy evaluation and dynamic node expansion
This PR inverts the execution model -- from recursively calling nodes to
using a topological sort of the nodes. This change allows for
modification of the node graph during execution. This allows for two
major advantages:
1. The implementation of lazy evaluation in nodes. For example, if a
   "Mix Images" node has a mix factor of exactly 0.0, the second image
   input doesn't even need to be evaluated (and vice versa if the mix
   factor is 1.0).
2. Dynamic expansion of nodes. This allows for the creation of dynamic
   "node groups". Specifically, custom nodes can return subgraphs that
   replace the original node in the graph. This is an *incredibly*
   powerful concept. Using this functionality, it was easy to
   implement:
   a. Components (a.k.a. node groups)
   b. Flow control (i.e. while loops) via tail recursion
   c. All-in-one nodes that replicate the WebUI functionality
   d. and more
All of those were able to be implemented entirely via custom nodes
without hooking or replacing any core functionality. Within this PR,
I've included all of these proof-of-concepts within a custom node pack.
In reality, I would expect some number of them to be merged into the
core node set (with the rest left to be implemented by custom nodes).

I made very few changes to the front-end, so there are probably some
easy UX wins for someone who is more willing to wade into .js land. The
user experience is a lot better than I expected though -- progress shows
correctly in the UI over the nodes that are being expanded.
2023-07-18 20:08:12 -07:00

209 lines
7.4 KiB
Python

import os
import shutil
import folder_paths
import json
import copy
# Directory that contains the folder_paths module (presumably the ComfyUI
# root -- confirm against the install layout).
comfy_path = os.path.dirname(folder_paths.__file__)
# The "web/extensions" folder under comfy_path; setup_js() copies into it.
js_path = os.path.join(comfy_path, "web", "extensions")
# Directory containing this module; setup_js() reads its "js" subfolder.
inversion_demo_path = os.path.dirname(__file__)
def setup_js():
    """Copy this pack's front-end script into the web extensions folder.

    Creates the "inversion-demo-components" directory under ``js_path`` when
    it does not exist yet, then copies "inversion-demo-components.js" from
    the "js" folder next to this module into it.

    Raises:
        OSError: If the directory cannot be created or the copy fails.
    """
    js_dest_path = os.path.join(js_path, "inversion-demo-components")
    # exist_ok replaces the separate exists() check, so a directory created
    # between the check and the call can no longer make this fail.
    os.makedirs(js_dest_path, exist_ok=True)
    js_src_path = os.path.join(inversion_demo_path, "js", "inversion-demo-components.js")
    shutil.copy(js_src_path, js_dest_path)
class ComponentInput:
    """Node that marks one input of a component graph.

    LoadComponent scans a saved graph for nodes with this class_type and
    exposes each one as an input of the generated component node. When the
    component is expanded, the caller's value is written into this node's
    "default_value" input, which component_input() simply passes through.
    """

    RETURN_TYPES = ("*",)
    FUNCTION = "component_input"
    CATEGORY = "Component Creation"

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the inputs of this node.

        Required:
            name: Name of the input on the generated component node.
            data_type: Type of the input; LoadComponent tries to parse a
                value starting with "[" as JSON.
            extra_args: JSON text with option overrides for the input.
            explicit_input_order: Primary sort key for the component's
                inputs (ties are broken by name).
            optional: Whether the input is listed as optional.

        Optional:
            default_value: Value returned by this node.
        """
        required = {
            "name": ("STRING", {"multiline": False}),
            "data_type": ("STRING", {"multiline": False, "default": "IMAGE"}),
            "extra_args": ("STRING", {"multiline": False}),
            "explicit_input_order": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
            "optional": ([False, True],),
        }
        optional = {"default_value": ("*",)}
        return {"required": required, "optional": optional}

    def component_input(self, name, data_type, extra_args, explicit_input_order, optional, default_value = None):
        """Return ``default_value`` as a one-element tuple.

        The remaining arguments are not used here; they are metadata that
        LoadComponent reads from the saved graph.
        """
        result = (default_value,)
        return result
class ComponentOutput:
    """Node that marks one output of a component graph.

    LoadComponent collects nodes with this class_type, orders them by
    "index" and uses their "data_type" values as the RETURN_TYPES of the
    generated component node. At run time the node forwards its value.
    """

    RETURN_TYPES = ("*",)
    FUNCTION = "component_output"
    CATEGORY = "Component Creation"

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the inputs of this node: index, data_type and value."""
        required = {}
        required["index"] = ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1})
        required["data_type"] = ("STRING", {"multiline": False, "default": "IMAGE"})
        required["value"] = ("*",)
        return {"required": required}

    def component_output(self, index, data_type, value):
        """Return ``value`` as a one-element tuple; the other arguments are unused."""
        result = (value,)
        return result
class ComponentMetadata:
    """Node that carries component-wide settings inside a component graph.

    LoadComponent reads two inputs from a node with this class_type:
    "name" becomes the display name of the generated component node and
    "always_output" becomes its OUTPUT_NODE flag. The node has no outputs
    and its function does nothing.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the inputs of this node: name and always_output."""
        return {
            "required": {
                "name": ("STRING", {"multiline": False}),
                "always_output": ([False, True],),
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "nop"
    CATEGORY = "Component Creation"

    def nop(self, name, always_output=False):
        """Do nothing and return an empty dict.

        ``always_output`` is accepted because INPUT_TYPES declares it as a
        required input; without the parameter, a call that passes every
        declared input by keyword fails with a TypeError. The default keeps
        existing single-argument calls working.
        """
        return {}
# Node classes defined by this module, keyed by class name. LoadComponent
# adds one entry per component file it loads successfully.
COMPONENT_NODE_CLASS_MAPPINGS = {
    "ComponentInput": ComponentInput,
    "ComponentOutput": ComponentOutput,
    "ComponentMetadata": ComponentMetadata,
}

# Display names for the classes above, using the same keys. LoadComponent
# adds the display name of every loaded component here as well.
COMPONENT_NODE_DISPLAY_NAME_MAPPINGS = {
    "ComponentInput": "Component Input",
    "ComponentOutput": "Component Output",
    "ComponentMetadata": "Component Metadata",
}
# Option defaults per primitive type. NOTE(review): nothing in the visible
# part of this file reads this table, and its INT/FLOAT ranges differ from
# the ones default_extra_data() applies -- confirm whether it is still needed.
DEFAULT_EXTRA_DATA = {
    "STRING": dict(multiline=False),
    "INT": dict(default=0, min=0, max=1000, step=1),
    "FLOAT": dict(default=0.0, min=0.0, max=1000.0, step=0.1),
}
def default_extra_data(data_type, extra_args):
    """Build the option dict for a component input of the given type.

    Starts from built-in defaults for "STRING", "INT" and "FLOAT" (any other
    type starts empty) and lays ``extra_args`` on top, so explicit values
    override the defaults.

    Args:
        data_type: Type of the input. Only the three string names above get
            defaults; every other value, including a list, gets none.
        extra_args: Dict of option overrides. A value that is not a dict
            (for example ``None``, or a JSON number or list typed into the
            "extra_args" widget) is ignored instead of raising.

    Returns:
        A new dict on each call; ``extra_args`` itself is not modified.
    """
    # Kept as an if/elif chain on purpose: data_type may be a list, which
    # could not be used as a key in a lookup table.
    if data_type == "STRING":
        args = {"multiline": False}
    elif data_type == "INT":
        args = {"default": 0, "min": -1000000, "max": 1000000, "step": 1}
    elif data_type == "FLOAT":
        args = {"default": 0.0, "min": -1000000.0, "max": 1000000.0, "step": 0.1}
    else:
        args = {}
    # LoadComponent already treats unparseable extra_args as "no overrides";
    # do the same for parseable values of the wrong shape.
    if isinstance(extra_args, dict):
        args.update(extra_args)
    return args
def LoadComponent(component_file):
    """Load a component graph from a JSON file and register it as a node class.

    The file must contain a JSON object whose "output" entry maps node ids to
    node descriptions with "class_type" and "inputs". Three class types are
    interpreted here:

    * "ComponentMetadata" supplies the display name and the OUTPUT_NODE flag.
    * "ComponentInput" nodes become the inputs of the generated class,
      ordered by explicit_input_order and then by name.
    * "ComponentOutput" nodes become its outputs, ordered by index.

    On success a new class is stored in COMPONENT_NODE_CLASS_MAPPINGS under
    the part of the file's base name before the first ".", and its display
    name is stored in COMPONENT_NODE_DISPLAY_NAME_MAPPINGS.

    Args:
        component_file: Path of the component JSON file.

    Returns:
        None in every case. Any error while reading or validating the file
        is printed and the component is skipped rather than raised.
    """
    try:
        # Binary mode lets json.load detect the encoding itself instead of
        # relying on the platform's default text encoding.
        with open(component_file, "rb") as f:
            graph = json.load(f)["output"]

        component_raw_name = os.path.basename(component_file).split(".")[0]
        component_display_name = component_raw_name
        component_inputs = []
        component_outputs = []
        is_output_component = False
        for node_id, data in graph.items():
            class_type = data["class_type"]
            if class_type == "ComponentMetadata":
                component_display_name = data["inputs"].get("name", component_raw_name)
                is_output_component = data["inputs"].get("always_output", False)
            elif class_type == "ComponentInput":
                data_type = data["inputs"]["data_type"]
                # Text starting with "[" is tried as a JSON list (presumably
                # a list of choices); if it is not valid JSON the literal
                # text is kept as the type name.
                if isinstance(data_type, str) and data_type.startswith("["):
                    try:
                        data_type = json.loads(data_type)
                    except ValueError:
                        pass
                # Missing, non-text or unparseable extra_args all mean
                # "no overrides".
                try:
                    extra_args = json.loads(data["inputs"]["extra_args"])
                except (KeyError, TypeError, ValueError):
                    extra_args = {}
                # Valid JSON that is not an object cannot be merged into the
                # option dict later, so it is dropped as well.
                if not isinstance(extra_args, dict):
                    extra_args = {}
                component_inputs.append({
                    "node_id": node_id,
                    "name": data["inputs"]["name"],
                    "data_type": data_type,
                    "extra_args": extra_args,
                    "explicit_input_order": data["inputs"]["explicit_input_order"],
                    "optional": data["inputs"]["optional"],
                })
            elif class_type == "ComponentOutput":
                component_outputs.append({
                    "node_id": node_id,
                    "index": data["inputs"]["index"],
                    "data_type": data["inputs"]["data_type"],
                })
        component_inputs.sort(key=lambda x: (x["explicit_input_order"], x["name"]))
        component_outputs.sort(key=lambda x: x["index"])

        # Inputs are sorted by order first, so two inputs sharing a name are
        # not necessarily neighbours; track every name that was seen.
        seen_names = set()
        for node in component_inputs:
            if node["name"] in seen_names:
                raise ValueError("Component input name is not unique: {}".format(node["name"]))
            seen_names.add(node["name"])
        # Outputs are sorted by index alone, so equal indices are adjacent.
        for previous, current in zip(component_outputs, component_outputs[1:]):
            if current["index"] == previous["index"]:
                raise ValueError("Component output index is not unique: {}".format(current["index"]))
    except Exception as e:
        print("Error loading component file: {}: {}".format(component_file, e))
        return None

    class ComponentNode:
        """Node generated from one component file; expands into its graph."""

        def __init__(self):
            pass

        @classmethod
        def INPUT_TYPES(cls):
            def input_spec(node):
                return (node["data_type"], default_extra_data(node["data_type"], node["extra_args"]))

            return {
                "required": {node["name"]: input_spec(node) for node in component_inputs if not node["optional"]},
                "optional": {node["name"]: input_spec(node) for node in component_inputs if node["optional"]},
            }

        RETURN_TYPES = tuple(node["data_type"] for node in component_outputs)
        FUNCTION = "expand_component"
        CATEGORY = "Custom Components"
        OUTPUT_NODE = is_output_component

        def expand_component(self, **kwargs):
            # Work on a copy so the loaded graph stays pristine between runs.
            new_graph = copy.deepcopy(graph)
            # Feed each supplied value into the matching ComponentInput node;
            # inputs that were not supplied keep what the file contained.
            for input_node in component_inputs:
                if input_node["name"] in kwargs:
                    new_graph[input_node["node_id"]]["inputs"]["default_value"] = kwargs[input_node["name"]]
            return {
                # One [node_id, 0] reference per ComponentOutput node, in index order.
                "result": tuple([node["node_id"], 0] for node in component_outputs),
                "expand": new_graph,
            }

    ComponentNode.__name__ = component_raw_name
    COMPONENT_NODE_CLASS_MAPPINGS[component_raw_name] = ComponentNode
    COMPONENT_NODE_DISPLAY_NAME_MAPPINGS[component_raw_name] = component_display_name
    print("Loaded component: {}".format(component_display_name))
def load_components():
    """Load every ".json" file in the "components" folder under comfy_path.

    Each file is handed to LoadComponent, which registers it or prints why
    it was skipped. Files are processed in sorted name order so the result
    does not depend on the order the operating system lists them in. A
    missing folder is reported and treated as "no components" instead of
    raising.
    """
    component_dir = os.path.join(comfy_path, "components")
    # os.listdir would raise here and abort the import of this module.
    if not os.path.isdir(component_dir):
        print("Component directory not found, skipping: %s" % component_dir)
        return
    files = sorted(
        f for f in os.listdir(component_dir)
        if f.endswith(".json") and os.path.isfile(os.path.join(component_dir, f))
    )
    for f in files:
        print("Loading component file %s" % f)
        LoadComponent(os.path.join(component_dir, f))
# Runs at import time: loads and registers the component files found on disk.
load_components()