Add lazy evaluation and dynamic node expansion

This PR inverts the execution model -- from recursively calling nodes to
using a topological sort of the nodes. This change allows for
modification of the node graph during execution. This allows for two
major advantages:
1. The implementation of lazy evaluation in nodes. For example, if a
   "Mix Images" node has a mix factor of exactly 0.0, the second image
   input doesn't even need to be evaluated (and vice versa if the mix
   factor is 1.0).
2. Dynamic expansion of nodes. This allows for the creation of dynamic
   "node groups". Specifically, custom nodes can return subgraphs that
   replace the original node in the graph. This is an *incredibly*
   powerful concept. Using this functionality, it was easy to
   implement:
   a. Components (a.k.a. node groups)
   b. Flow control (i.e. while loops) via tail recursion
   c. All-in-one nodes that replicate the WebUI functionality
   d. and more
All of those were able to be implemented entirely via custom nodes
without hooking or replacing any core functionality. Within this PR,
I've included all of these proof-of-concepts within a custom node pack.
In reality, I would expect some number of them to be merged into the
core node set (with the rest left to be implemented by custom nodes).

I made very few changes to the front-end, so there are probably some
easy UX wins for someone who is more willing to wade into .js land. The
user experience is a lot better than I expected though -- progress shows
correctly in the UI over the nodes that are being expanded.
This commit is contained in:
Jacob Segal 2023-07-17 19:40:27 -07:00
parent 7e4bc4451b
commit b234baee2c
7 changed files with 1002 additions and 85 deletions

104
comfy/graph_utils.py Normal file
View File

@ -0,0 +1,104 @@
import json
import random
# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
class GraphBuilder:
    """Utility that assembles graphs in the dict form the ComfyUI back-end expects."""

    def __init__(self, prefix = True):
        # prefix: explicit string -> used as-is; True -> random unique prefix;
        # any falsy value -> no prefix at all.
        if isinstance(prefix, str):
            self.prefix = prefix
        elif prefix:
            self.prefix = "%d.%d." % (
                random.randint(0, 0xffffffffffffffff),
                random.randint(0, 0xffffffffffffffff),
            )
        else:
            self.prefix = ""
        self.nodes = {}
        self.id_gen = 1

    def node(self, class_type, id=None, **kwargs):
        # Auto-generate a sequential id when the caller did not supply one.
        if id is None:
            id = str(self.id_gen)
            self.id_gen += 1
        full_id = self.prefix + id
        existing = self.nodes.get(full_id)
        if existing is not None:
            return existing
        created = Node(full_id, class_type, kwargs)
        self.nodes[full_id] = created
        return created

    def lookup_node(self, id):
        # Returns None when no node with that (prefixed) id exists.
        return self.nodes.get(self.prefix + id)

    def finalize(self):
        # Serialize every node into the back-end's {id: node_dict} format.
        return {node_id: node.serialize() for node_id, node in self.nodes.items()}

    def replace_node_output(self, node_id, index, new_value):
        """Rewire every input reading (node_id, index) to new_value.

        Passing new_value=None disconnects those inputs entirely.
        """
        target = self.prefix + node_id
        doomed = []
        for node in self.nodes.values():
            for key, value in node.inputs.items():
                if isinstance(value, list) and value[0] == target and value[1] == index:
                    if new_value is None:
                        doomed.append((node, key))
                    else:
                        node.inputs[key] = new_value
        # Deletions are deferred so we never mutate a dict mid-iteration.
        for node, key in doomed:
            del node.inputs[key]

    def remove_node(self, id):
        del self.nodes[self.prefix + id]
class Node:
    """A single pending node: its id, class type, and input bindings."""

    def __init__(self, id, class_type, inputs):
        self.id = id
        self.class_type = class_type
        self.inputs = inputs

    def out(self, index):
        # A link is encoded as [source_node_id, output_index].
        return [self.id, index]

    def set_input(self, key, value):
        # Setting an input to None removes the binding instead of storing None.
        if value is not None:
            self.inputs[key] = value
        elif key in self.inputs:
            del self.inputs[key]

    def get_input(self, key):
        return self.inputs.get(key)

    def serialize(self):
        # Matches the wire format the back-end expects for a single node.
        return {"class_type": self.class_type, "inputs": self.inputs}
def add_graph_prefix(graph, outputs, prefix):
    """Return (graph, outputs) with every node id prefixed by `prefix`.

    Internal links (encoded as [node_id, socket]) are rewritten so the
    renamed nodes still point at each other; constant inputs pass through
    unchanged. `outputs` entries that are links get the same treatment.
    """
    renamed = {}
    for node_id, node_info in graph.items():
        # Make sure the added nodes have unique IDs.
        rewired_inputs = {}
        for input_name, input_value in node_info.get("inputs", {}).items():
            if isinstance(input_value, list):
                rewired_inputs[input_name] = [prefix + input_value[0], input_value[1]]
            else:
                rewired_inputs[input_name] = input_value
        renamed[prefix + node_id] = {
            "class_type": node_info["class_type"],
            "inputs": rewired_inputs,
        }
    new_outputs = tuple(
        [prefix + o[0], o[1]] if isinstance(o, list) else o
        for o in outputs
    )
    return renamed, new_outputs

View File

@ -0,0 +1,19 @@
from .nodes import GENERAL_NODE_CLASS_MAPPINGS, GENERAL_NODE_DISPLAY_NAME_MAPPINGS
from .components import setup_js, COMPONENT_NODE_CLASS_MAPPINGS, COMPONENT_NODE_DISPLAY_NAME_MAPPINGS
from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS
# Merge the per-module registration tables into the two dicts ComfyUI reads.
# (dict.update returns None, so chaining it would not work -- build fresh
# dicts via unpacking instead.)
NODE_CLASS_MAPPINGS = {
    **GENERAL_NODE_CLASS_MAPPINGS,
    **COMPONENT_NODE_CLASS_MAPPINGS,
    **FLOW_CONTROL_NODE_CLASS_MAPPINGS,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    **GENERAL_NODE_DISPLAY_NAME_MAPPINGS,
    **COMPONENT_NODE_DISPLAY_NAME_MAPPINGS,
    **FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS,
}
# Install this pack's front-end extension on import.
setup_js()

View File

@ -0,0 +1,208 @@
import os
import shutil
import folder_paths
import json
import copy
# Resolve the ComfyUI install directory from the folder_paths module location.
comfy_path = os.path.dirname(folder_paths.__file__)
# Destination directory for front-end extensions served by ComfyUI.
js_path = os.path.join(comfy_path, "web", "extensions")
# Directory containing this custom node pack.
inversion_demo_path = os.path.dirname(__file__)
def setup_js():
    """Copy this pack's front-end extension into ComfyUI's web extensions dir."""
    js_dest_path = os.path.join(js_path, "inversion-demo-components")
    # exist_ok avoids the race between the existence check and the creation
    # (the original checked os.path.exists first, which can fail if another
    # process creates the directory in between).
    os.makedirs(js_dest_path, exist_ok=True)
    js_src_path = os.path.join(inversion_demo_path, "js", "inversion-demo-components.js")
    shutil.copy(js_src_path, js_dest_path)
class ComponentInput:
    """Marker node declaring one input socket of a component subgraph.

    At execution time it simply forwards its default value; LoadComponent
    reads its declared inputs straight out of the saved graph.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "name": ("STRING", {"multiline": False}),
            "data_type": ("STRING", {"multiline": False, "default": "IMAGE"}),
            "extra_args": ("STRING", {"multiline": False}),
            "explicit_input_order": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
            "optional": ([False, True],),
        }
        # default_value is used when the component caller leaves the socket unwired.
        optional = {"default_value": ("*",)}
        return {"required": required, "optional": optional}

    RETURN_TYPES = ("*",)
    FUNCTION = "component_input"
    CATEGORY = "Component Creation"

    def component_input(self, name, data_type, extra_args, explicit_input_order, optional, default_value = None):
        return (default_value,)
class ComponentOutput:
    """Marker node declaring one output socket of a component subgraph.

    Pass-through at execution time; LoadComponent reads `index`/`data_type`
    from the saved graph to build the generated node's RETURN_TYPES.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        required = {}
        required["index"] = ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1})
        required["data_type"] = ("STRING", {"multiline": False, "default": "IMAGE"})
        required["value"] = ("*",)
        return {"required": required}

    RETURN_TYPES = ("*",)
    FUNCTION = "component_output"
    CATEGORY = "Component Creation"

    def component_output(self, index, data_type, value):
        return (value,)
class ComponentMetadata:
    """Carries component-level metadata (display name, always-output flag).

    Performs no computation at execution time; LoadComponent reads its
    inputs straight out of the saved graph.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "name": ("STRING", {"multiline": False}),
                "always_output": ([False, True],),
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "nop"
    CATEGORY = "Component Creation"

    def nop(self, name, always_output):
        # Bug fix: the executor calls FUNCTION with every required input as a
        # keyword argument; the original signature omitted `always_output`
        # and would raise TypeError when this node executed.
        return {}
# Registration tables; merged into the pack-level mappings by __init__.py.
# LoadComponent() below also appends dynamically-generated component classes.
COMPONENT_NODE_CLASS_MAPPINGS = {
    "ComponentInput": ComponentInput,
    "ComponentOutput": ComponentOutput,
    "ComponentMetadata": ComponentMetadata,
}
COMPONENT_NODE_DISPLAY_NAME_MAPPINGS = {
    "ComponentInput": "Component Input",
    "ComponentOutput": "Component Output",
    "ComponentMetadata": "Component Metadata",
}
# NOTE(review): this table appears unused in this file -- default_extra_data()
# below builds its own per-type defaults with different min/max ranges.
# Consider deleting it or unifying the two.
DEFAULT_EXTRA_DATA = {
    "STRING": {"multiline": False},
    "INT": {"default": 0, "min": 0, "max": 1000, "step": 1},
    "FLOAT": {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1},
}
def default_extra_data(data_type, extra_args):
    """Build a widget config: per-type baseline defaults overlaid with extra_args."""
    per_type = {
        "STRING": {"multiline": False},
        "INT": {"default": 0, "min": -1000000, "max": 1000000, "step": 1},
        "FLOAT": {"default": 0.0, "min": -1000000.0, "max": 1000000.0, "step": 0.1},
    }
    # data_type may be a list (a JSON-decoded combo type), which is
    # unhashable -- guard before the dict lookup.
    if isinstance(data_type, str):
        args = dict(per_type.get(data_type, {}))
    else:
        args = {}
    args.update(extra_args)
    return args
def LoadComponent(component_file):
    # Parse a saved component workflow and register a dynamically-generated
    # node class for it.  On any parse/validation error, log and return None
    # without registering anything.
    try:
        with open(component_file, "r") as f:
            component_data = f.read()
        graph = json.loads(component_data)["output"]
        component_raw_name = os.path.basename(component_file).split(".")[0]
        component_display_name = component_raw_name
        component_inputs = []
        component_outputs = []
        is_output_component = False
        # Scan the graph for the special Component* marker nodes.
        for node_id, data in graph.items():
            if data["class_type"] == "ComponentMetadata":
                component_display_name = data["inputs"].get("name", component_raw_name)
                is_output_component = data["inputs"].get("always_output", False)
            elif data["class_type"] == "ComponentInput":
                data_type = data["inputs"]["data_type"]
                # A leading '[' means a JSON-encoded combo (list) type.
                if len(data_type) > 0 and data_type[0] == "[":
                    try:
                        data_type = json.loads(data_type)
                    except:
                        pass
                try:
                    extra_args = json.loads(data["inputs"]["extra_args"])
                except:
                    extra_args = {}
                component_inputs.append({
                    "node_id": node_id,
                    "name": data["inputs"]["name"],
                    "data_type": data_type,
                    "extra_args": extra_args,
                    "explicit_input_order": data["inputs"]["explicit_input_order"],
                    "optional": data["inputs"]["optional"],
                })
            elif data["class_type"] == "ComponentOutput":
                component_outputs.append({
                    "node_id": node_id,
                    "index": data["inputs"]["index"],
                    "data_type": data["inputs"]["data_type"],
                })
        # Deterministic socket ordering: explicit order first, then name.
        component_inputs.sort(key=lambda x: (x["explicit_input_order"], x["name"]))
        component_outputs.sort(key=lambda x: x["index"])
        # After sorting, duplicates are adjacent -- reject non-unique sockets.
        for i in range(1, len(component_inputs)):
            if component_inputs[i]["name"] == component_inputs[i-1]["name"]:
                raise Exception("Component input name is not unique: {}".format(component_inputs[i]["name"]))
        for i in range(1, len(component_outputs)):
            if component_outputs[i]["index"] == component_outputs[i-1]["index"]:
                raise Exception("Component output index is not unique: {}".format(component_outputs[i]["index"]))
    except Exception as e:
        print("Error loading component file: {}: {}".format(component_file, e))
        return None
    class ComponentNode:
        # Dynamically-generated node class; closes over the parsed component
        # graph and socket lists above.
        def __init__(self):
            pass
        @classmethod
        def INPUT_TYPES(cls):
            return {
                "required": {node["name"]: (node["data_type"], default_extra_data(node["data_type"], node["extra_args"])) for node in component_inputs if not node["optional"]},
                "optional": {node["name"]: (node["data_type"], default_extra_data(node["data_type"], node["extra_args"])) for node in component_inputs if node["optional"]},
            }
        RETURN_TYPES = tuple([node["data_type"] for node in component_outputs])
        FUNCTION = "expand_component"
        CATEGORY = "Custom Components"
        OUTPUT_NODE = is_output_component
        def expand_component(self, **kwargs):
            # Substitute the caller-provided values into the ComponentInput
            # nodes of a fresh copy, then hand the subgraph back for expansion.
            new_graph = copy.deepcopy(graph)
            for input_node in component_inputs:
                if input_node["name"] in kwargs:
                    new_graph[input_node["node_id"]]["inputs"]["default_value"] = kwargs[input_node["name"]]
            return {
                "result": tuple([[node["node_id"], 0] for node in component_outputs]),
                "expand": new_graph,
            }
    ComponentNode.__name__ = component_raw_name
    COMPONENT_NODE_CLASS_MAPPINGS[component_raw_name] = ComponentNode
    COMPONENT_NODE_DISPLAY_NAME_MAPPINGS[component_raw_name] = component_display_name
    print("Loaded component: {}".format(component_display_name))
def load_components():
    """Load every *.json component definition from <comfy>/components."""
    component_dir = os.path.join(comfy_path, "components")
    # A fresh install has no components directory; treat that as "no
    # components" instead of letting os.listdir raise FileNotFoundError.
    if not os.path.isdir(component_dir):
        return
    files = [f for f in os.listdir(component_dir) if os.path.isfile(os.path.join(component_dir, f)) and f.endswith(".json")]
    for f in files:
        print("Loading component file %s" % f)
        LoadComponent(os.path.join(component_dir, f))
load_components()

View File

@ -0,0 +1,131 @@
from comfy.graph_utils import GraphBuilder
# Number of generic pass-through value sockets on the loop nodes.
NUM_FLOW_SOCKETS = 5
class WhileLoopOpen:
    """Opens a while loop; passes its initial values straight through."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        optional = {"initial_value%d" % i: ("*",) for i in range(NUM_FLOW_SOCKETS)}
        return {
            "required": {
                "condition": ("INT", {"default": 1, "min": 0, "max": 1, "step": 1}),
            },
            "optional": optional,
        }

    RETURN_TYPES = tuple(["FLOW_CONTROL"] + ["*"] * NUM_FLOW_SOCKETS)
    FUNCTION = "while_loop_open"
    CATEGORY = "Flow Control"

    def while_loop_open(self, condition, **kwargs):
        # First output is a dummy FLOW_CONTROL token consumed by
        # WhileLoopClose; the rest simply forward the initial values.
        passthrough = [kwargs.get("initial_value%d" % i, None) for i in range(NUM_FLOW_SOCKETS)]
        return tuple(["stub"] + passthrough)
class WhileLoopClose:
    # Closes a while loop.  When `condition` is truthy, every node between
    # the matching WhileLoopOpen and this node (inclusive) is cloned into a
    # new subgraph whose open node is fed this iteration's values -- i.e.
    # looping is implemented as tail-recursive graph expansion.
    def __init__(self):
        pass
    @classmethod
    def INPUT_TYPES(cls):
        inputs = {
            "required": {
                "flow_control": ("FLOW_CONTROL",),
                "condition": ("INT", {"default": 0, "min": 0, "max": 1, "step": 1}),
            },
            "optional": {
            },
            "hidden": {
                # Injected by the executor, not supplied by the user.
                "dynprompt": "DYNPROMPT",
                "unique_id": "UNIQUE_ID",
            }
        }
        for i in range(NUM_FLOW_SOCKETS):
            inputs["optional"]["initial_value%d" % i] = ("*",)
        return inputs
    RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS)
    FUNCTION = "while_loop_close"
    CATEGORY = "Flow Control"
    def explore_dependencies(self, node_id, dynprompt, upstream):
        # Build a reverse-dependency map: upstream[parent_id] = [dependents...].
        node_info = dynprompt.get_node(node_id)
        if "inputs" not in node_info:
            return
        for k, v in node_info["inputs"].items():
            if isinstance(v, list) and len(v) == 2:
                # A two-element list input is a [node_id, socket] link.
                parent_id = v[0]
                if parent_id not in upstream:
                    upstream[parent_id] = []
                    self.explore_dependencies(parent_id, dynprompt, upstream)
                upstream[parent_id].append(node_id)
    def collect_contained(self, node_id, upstream, contained):
        # Mark every node reachable downstream of node_id via `upstream`.
        if node_id not in upstream:
            return
        for child_id in upstream[node_id]:
            if child_id not in contained:
                contained[child_id] = True
                self.collect_contained(child_id, upstream, contained)
    def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=None, **kwargs):
        if not condition:
            # We're done with the loop -- forward this iteration's values.
            values = []
            for i in range(NUM_FLOW_SOCKETS):
                values.append(kwargs.get("initial_value%d" % i, None))
            return tuple(values)
        # We want to loop
        this_node = dynprompt.get_node(unique_id)
        upstream = {}
        # Get the list of all nodes between the open and close nodes
        self.explore_dependencies(unique_id, dynprompt, upstream)
        contained = {}
        open_node = this_node["inputs"]["flow_control"][0]
        self.collect_contained(open_node, upstream, contained)
        contained[unique_id] = True
        contained[open_node] = True
        # First pass creates the clones; second pass wires their inputs
        # (links inside the loop body point at the clones, everything else
        # is copied verbatim).
        graph = GraphBuilder()
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.node(original_node["class_type"], node_id)
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.lookup_node(node_id)
            for k, v in original_node["inputs"].items():
                if isinstance(v, list) and len(v) == 2 and v[0] in contained:
                    parent = graph.lookup_node(v[0])
                    node.set_input(k, parent.out(v[1]))
                else:
                    node.set_input(k, v)
        # Feed this iteration's values into the cloned open node as the next
        # iteration's initial values.
        new_open = graph.lookup_node(open_node)
        for i in range(NUM_FLOW_SOCKETS):
            key = "initial_value%d" % i
            new_open.set_input(key, kwargs.get(key, None))
        # Our outputs are taken from the cloned close node of the next iteration.
        my_clone = graph.lookup_node(unique_id)
        result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS))
        return {
            "result": tuple(result),
            "expand": graph.finalize(),
        }
# Registration tables; merged into the pack-level mappings by __init__.py.
FLOW_CONTROL_NODE_CLASS_MAPPINGS = {
    "WhileLoopOpen": WhileLoopOpen,
    "WhileLoopClose": WhileLoopClose,
}
FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = {
    "WhileLoopOpen": "While Loop Open",
    "WhileLoopClose": "While Loop Close",
}

View File

@ -0,0 +1,64 @@
import { app } from "/scripts/app.js";
import { ComfyDialog, $el } from "/scripts/ui.js";
import {ComfyWidgets} from "../../scripts/widgets.js";
// NOTE(review): these two variables are never used anywhere in this file --
// they look like copy-paste leftovers; consider removing them.
var update_comfyui_button = null;
var fetch_updates_button = null;

// Hidden file input triggered by the "Load Component" button below.
const fileInput = $el("input", {
    id: "component-file-input",
    type: "file",
    accept: ".json,image/png,.latent,.safetensors",
    style: {display: "none"},
    parent: document.body,
    onchange: async () => {
        app.handleFile(fileInput.files[0]);
        // Additionally parse the saved component format and load the
        // embedded workflow.  NOTE(review): for non-JSON selections
        // (png/latent/safetensors) JSON.parse will throw inside onload --
        // presumably handleFile already covered those; confirm intended.
        const reader = new FileReader();
        reader.onload = () => {
            app.loadGraphData(JSON.parse(reader.result)["workflow"]);
        };
        reader.readAsText(fileInput.files[0]);
    },
});
// Adds "Save Component" / "Load Component" buttons to the ComfyUI menu.
app.registerExtension({
    name: "Comfy.InversionDemoComponents",
    async setup() {
        const menu = document.querySelector(".comfy-menu");

        // Visually separate our buttons from the stock menu entries.
        const separator = document.createElement("hr");
        separator.style.margin = "20px 0";
        separator.style.width = "100%";
        menu.append(separator);

        const saveButton = document.createElement("button");
        saveButton.textContent = "Save Component";
        saveButton.onclick = async () => {
            // Serialize the current graph and download it as a JSON file.
            const filename = "component.json";
            const promptData = await app.graphToPrompt();
            const json = JSON.stringify(promptData, null, 2); // convert the data to a JSON string
            const blob = new Blob([json], {type: "application/json"});
            const url = URL.createObjectURL(blob);
            const link = $el("a", {
                href: url,
                download: filename,
                style: {display: "none"},
                parent: document.body,
            });
            link.click();
            // Defer cleanup so the click can start the download first.
            setTimeout(() => {
                link.remove();
                window.URL.revokeObjectURL(url);
            }, 0);
        };

        const loadButton = document.createElement("button");
        loadButton.textContent = "Load Component";
        loadButton.onclick = () => fileInput.click();

        menu.append(saveButton);
        menu.append(loadButton);
    }
});

View File

@ -0,0 +1,206 @@
import re
from comfy.graph_utils import GraphBuilder
class InversionDemoAdvancedPromptNode:
    """Prompt node that understands `<lora:name:strength[:clip_strength]>` tags.

    Each tag is stripped from the prompt text and expanded into a chained
    LoraLoader node; the cleaned prompt becomes a CLIPTextEncode node.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "prompt": ("STRING", {"multiline": True}),
                "model": ("MODEL",),
                "clip": ("CLIP",),
            },
        }

    RETURN_TYPES = ("MODEL", "CLIP", "CONDITIONING")
    FUNCTION = "advanced_prompt"
    CATEGORY = "InversionDemo Nodes"

    def parse_prompt(self, prompt):
        # Matches <lora:NAME:STRENGTH> with an optional :CLIP_STRENGTH suffix;
        # strengths are (possibly negative) decimal numbers.
        pattern = r"<lora:([^:]+):([-0-9.]+)(?::([-0-9.]+))?>"
        loras = re.findall(pattern, prompt)
        if not loras:
            return prompt, loras
        cleaned_prompt = re.sub(pattern, "", prompt).strip()
        print("Cleaned prompt: '%s'" % cleaned_prompt)
        return cleaned_prompt, loras

    def advanced_prompt(self, prompt, clip, model):
        cleaned_prompt, loras = self.parse_prompt(prompt)
        graph = GraphBuilder()
        # Chain one LoraLoader per tag, threading model/clip through each.
        for name, strength, clip_strength in loras:
            model_strength = float(strength)
            # Clip strength defaults to the model strength when omitted.
            final_clip_strength = model_strength if clip_strength == "" else float(clip_strength)
            loader = graph.node("LoraLoader", model=model, clip=clip, lora_name=name, strength_model=model_strength, strength_clip=final_clip_strength)
            model = loader.out(0)
            clip = loader.out(1)
        encoder = graph.node("CLIPTextEncode", clip=clip, text=cleaned_prompt)
        return {
            "result": (model, clip, encoder.out(0)),
            "expand": graph.finalize(),
        }
class InversionDemoFakeAdvancedPromptNode:
    """Baseline variant of the advanced prompt node without graph expansion.

    Encodes the prompt directly with the provided CLIP rather than expanding
    into a subgraph -- useful for comparison against the expanding version.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "prompt": ("STRING", {"multiline": True}),
            "clip": ("CLIP",),
            "model": ("MODEL",),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL", "CLIP", "CONDITIONING")
    FUNCTION = "advanced_prompt"
    CATEGORY = "InversionDemo Nodes"

    def advanced_prompt(self, prompt, clip, model):
        cond, pooled = clip.encode_from_tokens(clip.tokenize(prompt), return_pooled=True)
        # Conditioning format: a list of [cond, options_dict] pairs.
        return (model, clip, [[cond, {"pooled_output": pooled}]])
class InversionDemoLazySwitch:
    """Two-way switch that only evaluates the branch actually selected."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "switch": ([False, True],),
                "on_false": ("*", {"lazy": True}),
                "on_true": ("*", {"lazy": True}),
            },
        }

    RETURN_TYPES = ("*",)
    FUNCTION = "switch"
    CATEGORY = "InversionDemo Nodes"

    def check_lazy_status(self, switch, on_false = None, on_true = None):
        # Request only the input the switch will actually read, and only
        # if it has not been evaluated yet.
        needed_name = "on_true" if switch else "on_false"
        needed_value = on_true if switch else on_false
        if needed_value is None:
            return [needed_name]

    def switch(self, switch, on_false = None, on_true = None):
        return (on_true,) if switch else (on_false,)
class InversionDemoLazyIndexSwitch:
    """Ten-way switch that lazily evaluates only the selected input."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "index": ("INT", {"default": 0, "min": 0, "max": 9, "step": 1}),
                "value0": ("*", {"lazy": True}),
            },
            "optional": {
                "value%d" % i: ("*", {"lazy": True}) for i in range(1, 10)
            }
        }

    RETURN_TYPES = ("*",)
    FUNCTION = "index_switch"
    CATEGORY = "InversionDemo Nodes"

    def check_lazy_status(self, index, **kwargs):
        # Request only the selected input, and only if not yet evaluated.
        key = "value%d" % index
        if key not in kwargs:
            return [key]

    def index_switch(self, index, **kwargs):
        key = "value%d" % index
        # Bug fix: FUNCTION results must be a tuple matching RETURN_TYPES
        # (every other node here returns one); the original returned the
        # bare value, which the result-merging code would mis-index.
        return (kwargs[key],)
class InversionDemoLazyMixImages:
    """Mixes two images by a mask, lazily skipping any input the mask ignores."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image1": ("IMAGE", {"lazy": True}),
                "image2": ("IMAGE", {"lazy": True}),
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "mix"
    CATEGORY = "InversionDemo Nodes"

    def check_lazy_status(self, mask, image1 = None, image2 = None):
        # An all-ones mask never reads image1; an all-zeros mask never
        # reads image2 -- so those inputs need not be evaluated.
        lo, hi = mask.min(), mask.max()
        needed = []
        if image1 is None and not (lo == 1.0 and hi == 1.0):
            needed.append("image1")
        if image2 is None and not (lo == 0.0 and hi == 0.0):
            needed.append("image2")
        return needed

    # Not trying to handle different batch sizes here just to keep the demo simple
    def mix(self, mask, image1 = None, image2 = None):
        lo, hi = mask.min(), mask.max()
        # Degenerate masks short-circuit to the only image that matters.
        if lo == 0.0 and hi == 0.0:
            return (image1,)
        if lo == 1.0 and hi == 1.0:
            return (image2,)
        # Broadcast the mask up to (batch, height, width, channels).
        if len(mask.shape) == 2:
            mask = mask.unsqueeze(0)
        if len(mask.shape) == 3:
            mask = mask.unsqueeze(3)
        if mask.shape[3] < image1.shape[3]:
            mask = mask.repeat(1, 1, 1, image1.shape[3])
        return (image1 * (1. - mask) + image2 * mask,)
# Registration tables; merged into the pack-level mappings by __init__.py.
GENERAL_NODE_CLASS_MAPPINGS = {
    "InversionDemoAdvancedPromptNode": InversionDemoAdvancedPromptNode,
    "InversionDemoFakeAdvancedPromptNode": InversionDemoFakeAdvancedPromptNode,
    "InversionDemoLazySwitch": InversionDemoLazySwitch,
    "InversionDemoLazyIndexSwitch": InversionDemoLazyIndexSwitch,
    "InversionDemoLazyMixImages": InversionDemoLazyMixImages,
}
GENERAL_NODE_DISPLAY_NAME_MAPPINGS = {
    "InversionDemoAdvancedPromptNode": "Advanced Prompt",
    "InversionDemoFakeAdvancedPromptNode": "Fake Advanced Prompt",
    "InversionDemoLazySwitch": "Lazy Switch",
    "InversionDemoLazyIndexSwitch": "Lazy Index Switch",
    "InversionDemoLazyMixImages": "Lazy Mix Images",
}

View File

@ -7,13 +7,150 @@ import heapq
import traceback
import gc
import time
from enum import Enum
import torch
import nodes
import comfy.model_management
import comfy.graph_utils
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
class ExecutionResult(Enum):
    # Outcome of attempting to execute a single staged node.
    SUCCESS = 0
    FAILURE = 1
    # The node deferred (e.g. is waiting on lazy inputs); retry it later.
    SLEEPING = 2
# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
# it can still be returned to the graph after having further dependencies added.
class ExecutionList:
    def __init__(self, dynprompt, outputs):
        self.dynprompt = dynprompt
        self.outputs = outputs
        self.staged_node_id = None
        self.pendingNodes = {}
        self.blockCount = {} # Number of nodes this node is directly blocked by
        self.blocking = {} # Which nodes are blocked by this node

    def get_input_info(self, unique_id, input_name):
        """Return (type, category, extra_info) for one declared input of a node.

        category is "required"/"optional"/"hidden"; all three values are None
        when the input is not declared by the node class at all.
        """
        class_type = self.dynprompt.get_node(unique_id)["class_type"]
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
        valid_inputs = class_def.INPUT_TYPES()
        input_info = None
        input_category = None
        # Bug fix: many node classes declare only "required" -- indexing
        # "optional"/"hidden" unconditionally raised KeyError for them.
        if input_name in valid_inputs.get("required", {}):
            input_category = "required"
            input_info = valid_inputs["required"][input_name]
        elif input_name in valid_inputs.get("optional", {}):
            input_category = "optional"
            input_info = valid_inputs["optional"][input_name]
        elif input_name in valid_inputs.get("hidden", {}):
            input_category = "hidden"
            input_info = valid_inputs["hidden"][input_name]
        if input_info is None:
            return None, None, None
        input_type = input_info[0]
        extra_info = input_info[1] if len(input_info) > 1 else None
        return input_type, input_category, extra_info

    def make_input_strong_link(self, to_node_id, to_input):
        """Force evaluation of the node feeding `to_input` (used for lazy inputs)."""
        inputs = self.dynprompt.get_node(to_node_id)["inputs"]
        if to_input not in inputs:
            raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input))
        value = inputs[to_input]
        if not isinstance(value, list):
            raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input))
        from_node_id, from_socket = value
        self.add_strong_link(from_node_id, from_socket, to_node_id)

    def add_strong_link(self, from_node_id, from_socket, to_node_id):
        if from_node_id in self.outputs:
            # Nothing to do -- the upstream node already has cached outputs.
            return
        self.add_node(from_node_id)
        if to_node_id not in self.blocking[from_node_id]:
            self.blocking[from_node_id][to_node_id] = {}
            self.blockCount[to_node_id] += 1
        self.blocking[from_node_id][to_node_id][from_socket] = True

    def add_node(self, unique_id):
        """Queue a node and, transitively, all of its non-lazy dependencies."""
        if unique_id in self.pendingNodes:
            return
        self.pendingNodes[unique_id] = True
        self.blockCount[unique_id] = 0
        self.blocking[unique_id] = {}
        inputs = self.dynprompt.get_node(unique_id)["inputs"]
        for input_name in inputs:
            value = inputs[input_name]
            if isinstance(value, list):
                from_node_id, from_socket = value
                input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
                # Lazy inputs are linked later (via make_input_strong_link)
                # only if the node actually asks for them.
                if input_info is None or "lazy" not in input_info or not input_info["lazy"]:
                    self.add_strong_link(from_node_id, from_socket, unique_id)

    def stage_node_execution(self):
        """Pick and stage the next ready node; returns None when the list is empty."""
        assert self.staged_node_id is None
        if self.is_empty():
            return None
        available = [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
        if len(available) == 0:
            raise Exception("Dependency cycle detected")
        next_node = available[0]
        # If an output node is available, do that first.
        # Technically this has no effect on the overall length of execution, but it feels better as a user
        # for a PreviewImage to display a result as soon as it can
        # Some other heuristics could probably be used here to improve the UX further.
        for node_id in available:
            class_type = self.dynprompt.get_node(node_id)["class_type"]
            class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
            if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
                next_node = node_id
                break
        self.staged_node_id = next_node
        return self.staged_node_id

    def unstage_node_execution(self):
        # Return the staged node to the pool (e.g. it went to sleep waiting
        # on newly-requested lazy inputs).
        assert self.staged_node_id is not None
        self.staged_node_id = None

    def complete_node_execution(self):
        # Remove the staged node and unblock everything waiting on it.
        node_id = self.staged_node_id
        del self.pendingNodes[node_id]
        for blocked_node_id in self.blocking[node_id]:
            self.blockCount[blocked_node_id] -= 1
        del self.blocking[node_id]
        self.staged_node_id = None

    def is_empty(self):
        return len(self.pendingNodes) == 0
class DynamicPrompt:
    """A prompt graph that can grow during execution via node expansion."""

    def __init__(self, original_prompt):
        # The original prompt provided by the user
        self.original_prompt = original_prompt
        # Any extra pieces of the graph created during execution
        self.ephemeral_prompt = {}
        self.ephemeral_parents = {}

    def get_node(self, node_id):
        # Ephemeral nodes take precedence over originals with the same id;
        # unknown ids yield None.
        for source in (self.ephemeral_prompt, self.original_prompt):
            if node_id in source:
                return source[node_id]
        return None

    def add_ephemeral_node(self, real_parent_id, node_id, node_info):
        self.ephemeral_prompt[node_id] = node_info
        self.ephemeral_parents[node_id] = real_parent_id

    def get_real_node_id(self, node_id):
        # Ephemeral nodes report the original node that spawned them;
        # original nodes are their own real id.
        return self.ephemeral_parents.get(node_id, node_id)
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynprompt=None, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
for x in inputs:
@ -22,7 +159,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
return None
continue # This might be a lazily-evaluated input
obj = outputs[input_unique_id][output_index]
input_data_all[x] = obj
else:
@ -34,6 +171,8 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
for x in h:
if h[x] == "PROMPT":
input_data_all[x] = [prompt]
if h[x] == "DYNPROMPT":
input_data_all[x] = [dynprompt]
if h[x] == "EXTRA_PNGINFO":
if "extra_pnginfo" in extra_data:
input_data_all[x] = [extra_data['extra_pnginfo']]
@ -68,39 +207,54 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
return results
def merge_result_data(results, obj):
    """Merge per-slice execution results into per-output lists.

    Outputs flagged in obj.OUTPUT_IS_LIST are concatenated across slices;
    all other outputs are collected one entry per slice.
    """
    n_outputs = len(results[0])
    # check which outputs need concatenating
    output_is_list = getattr(obj, "OUTPUT_IS_LIST", [False] * n_outputs)
    # merge node execution results
    output = []
    for i, is_list in zip(range(n_outputs), output_is_list):
        if is_list:
            output.append([x for o in results for x in o[i]])
        else:
            output.append([o[i] for o in results])
    return output
def get_output_data(obj, input_data_all):
results = []
uis = []
subgraph_results = []
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
for r in return_values:
has_subgraph = False
for i in range(len(return_values)):
r = return_values[i]
if isinstance(r, dict):
if 'ui' in r:
uis.append(r['ui'])
if 'result' in r:
if 'expand' in r:
# Perform an expansion, but do not append results
has_subgraph = True
new_graph = r['expand']
subgraph_results.append((new_graph, r.get("result", None)))
elif 'result' in r:
results.append(r['result'])
subgraph_results.append((None, r['result']))
else:
results.append(r)
output = []
if len(results) > 0:
# check which outputs need concatenating
output_is_list = [False] * len(results[0])
if hasattr(obj, "OUTPUT_IS_LIST"):
output_is_list = obj.OUTPUT_IS_LIST
# merge node execution results
for i, is_list in zip(range(len(results[0])), output_is_list):
if is_list:
output.append([x for o in results for x in o[i]])
else:
output.append([o[i] for o in results])
if has_subgraph:
output = subgraph_results
elif len(results) > 0:
output = merge_result_data(results, obj)
else:
output = []
ui = dict()
if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui
return output, ui, has_subgraph
def format_value(x):
if x is None:
@ -110,53 +264,102 @@ def format_value(x):
else:
return str(x)
def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage, execution_list, pending_subgraph_results):
    """Execute a single node staged by the topologically-sorted execution list.

    Returns a tuple (ExecutionResult, error_details, exception):
      * SUCCESS  - the node ran (or was already cached); outputs are stored.
      * SLEEPING - the node cannot run yet: it either requested lazy inputs or
                   expanded into a subgraph; it will be re-staged later.
      * FAILURE  - execution raised; error_details describes the failure.
    """
    unique_id = current_item
    # Ephemeral (expanded) nodes report progress/errors against the node the
    # user actually placed in the graph.
    real_node_id = dynprompt.get_real_node_id(unique_id)
    inputs = dynprompt.get_node(unique_id)['inputs']
    class_type = dynprompt.get_node(unique_id)['class_type']
    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
    if unique_id in outputs:
        return (ExecutionResult.SUCCESS, None, None)

    input_data_all = None
    try:
        if unique_id in pending_subgraph_results:
            # This node previously expanded into a subgraph whose dependencies
            # have now executed -- resolve the recorded placeholder outputs.
            cached_results = pending_subgraph_results[unique_id]
            resolved_outputs = []
            for is_subgraph, result in cached_results:
                if not is_subgraph:
                    resolved_outputs.append(result)
                else:
                    resolved_output = []
                    for r in result:
                        if isinstance(r, list) and len(r) == 2:
                            # A [node_id, socket_index] pair -- substitute the
                            # now-available output of the expanded node.
                            source_node, source_output = r[0], r[1]
                            node_output = outputs[source_node][source_output]
                            for o in node_output:
                                resolved_output.append(o)
                        else:
                            resolved_output.append(r)
                    resolved_outputs.append(tuple(resolved_output))
            output_data = merge_result_data(resolved_outputs, class_def)
            output_ui = []
            has_subgraph = False
        else:
            input_data_all = get_input_data(inputs, class_def, unique_id, outputs, dynprompt.original_prompt, dynprompt, extra_data)
            if server.client_id is not None:
                server.last_node_id = real_node_id
                server.send_sync("executing", { "node": real_node_id, "prompt_id": prompt_id }, server.client_id)

            obj = object_storage.get((unique_id, class_type), None)
            if obj is None:
                obj = class_def()
                object_storage[(unique_id, class_type)] = obj

            if hasattr(obj, "check_lazy_status"):
                # Lazy evaluation: ask the node which not-yet-evaluated inputs
                # it actually needs. If any are missing, link them into the
                # execution list and put this node back to sleep.
                required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
                required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
                required_inputs = [x for x in required_inputs if x not in input_data_all]
                if len(required_inputs) > 0:
                    for i in required_inputs:
                        execution_list.make_input_strong_link(unique_id, i)
                    return (ExecutionResult.SLEEPING, None, None)

            output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all)
        if len(output_ui) > 0:
            outputs_ui[unique_id] = output_ui
            if server.client_id is not None:
                server.send_sync("executed", { "node": real_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
        if has_subgraph:
            # The node expanded into a subgraph: splice the new nodes into the
            # dynamic prompt and defer this node until they have executed.
            cached_outputs = []
            for i in range(len(output_data)):
                new_graph, node_outputs = output_data[i]
                if new_graph is None:
                    cached_outputs.append((False, node_outputs))
                else:
                    # Check for conflicts
                    for node_id in new_graph.keys():
                        if dynprompt.get_node(node_id) is not None:
                            new_graph, node_outputs = comfy.graph_utils.add_graph_prefix(new_graph, node_outputs, "%s.%d." % (unique_id, i))
                            break
                    new_output_ids = []
                    for node_id, node_info in new_graph.items():
                        dynprompt.add_ephemeral_node(real_node_id, node_id, node_info)
                        # Figure out if the newly created node is an output node
                        class_type = node_info["class_type"]
                        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
                        if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
                            new_output_ids.append(node_id)
                    for node_id in new_output_ids:
                        execution_list.add_node(node_id)
                    for i in range(len(node_outputs)):
                        if isinstance(node_outputs[i], list) and len(node_outputs[i]) == 2:
                            from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
                            execution_list.add_strong_link(from_node_id, from_socket, unique_id)
                    cached_outputs.append((True, node_outputs))
            pending_subgraph_results[unique_id] = cached_outputs
            return (ExecutionResult.SLEEPING, None, None)
        outputs[unique_id] = output_data
    except comfy.model_management.InterruptProcessingException as iex:
        print("Processing interrupted")
        # skip formatting inputs/outputs
        error_details = {
            "node_id": real_node_id,
        }
        return (ExecutionResult.FAILURE, error_details, iex)
    except Exception as ex:
        typ, _, tb = sys.exc_info()
        exception_type = full_type_name(typ)
        # NOTE(review): the formatting block below was elided in the diff view;
        # reconstructed from the error_details fields it feeds -- confirm
        # against the full file.
        input_data_formatted = {}
        if input_data_all is not None:
            input_data_formatted = {name: [format_value(x) for x in inputs] for name, inputs in input_data_all.items()}
        output_data_formatted = {}
        for node_id, node_outputs in outputs.items():
            output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
        print("!!! Exception during processing !!!")
        print(traceback.format_exc())
        error_details = {
            "node_id": real_node_id,
            "exception_message": str(ex),
            "exception_type": exception_type,
            "traceback": traceback.format_tb(tb),
            "current_inputs": input_data_formatted,
            "current_outputs": output_data_formatted
        }
        return (ExecutionResult.FAILURE, error_details, ex)

    executed.add(unique_id)
    return (ExecutionResult.SUCCESS, None, None)
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
unique_id = current_item
@ -350,28 +536,27 @@ class PromptExecutor:
if self.server.client_id is not None:
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
pending_subgraph_results = {}
dynamic_prompt = DynamicPrompt(prompt)
executed = set()
output_node_id = None
to_execute = []
execution_list = ExecutionList(dynamic_prompt, self.outputs)
for node_id in list(execute_outputs):
to_execute += [(0, node_id)]
execution_list.add_node(node_id)
while len(to_execute) > 0:
#always execute the output that depends on the least amount of unexecuted nodes first
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
output_node_id = to_execute.pop(0)[-1]
# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
if success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
while not execution_list.is_empty():
node_id = execution_list.stage_node_execution()
result, error, ex = non_recursive_execute(self.server, dynamic_prompt, self.outputs, node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage, execution_list, pending_subgraph_results)
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break
elif result == ExecutionResult.SLEEPING:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS
execution_list.complete_node_execution()
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
if x in prompt:
self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None