mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-11 05:52:33 +08:00
Add lazy evaluation and dynamic node expansion
This PR inverts the execution model -- from recursively calling nodes to using a topological sort of the nodes. This change allows for modification of the node graph during execution. This allows for two major advantages: 1. The implementation of lazy evaluation in nodes. For example, if a "Mix Images" node has a mix factor of exactly 0.0, the second image input doesn't even need to be evaluated (and visa-versa if the mix factor is 1.0). 2. Dynamic expansion of nodes. This allows for the creation of dynamic "node groups". Specifically, custom nodes can return subgraphs that replace the original node in the graph. This is an *incredibly* powerful concept. Using this functionality, it was easy to implement: a. Components (a.k.a. node groups) b. Flow control (i.e. while loops) via tail recursion c. All-in-one nodes that replicate the WebUI functionality d. and more All of those were able to be implemented entirely via custom nodes without hooking or replacing any core functionality. Within this PR, I've included all of these proof-of-concepts within a custom node pack. In reality, I would expect some number of them to be merged into the core node set (with the rest left to be implemented by custom nodes). I made very few changes to the front-end, so there are probably some easy UX wins for someone who is more willing to wade into .js land. The user experience is a lot better than I expected though -- progress shows correctly in the UI over the nodes that are being expanded.
This commit is contained in:
parent
7e4bc4451b
commit
b234baee2c
104
comfy/graph_utils.py
Normal file
104
comfy/graph_utils.py
Normal file
@ -0,0 +1,104 @@
|
||||
import json
|
||||
import random
|
||||
|
||||
# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
|
||||
class GraphBuilder:
|
||||
def __init__(self, prefix = True):
|
||||
if isinstance(prefix, str):
|
||||
self.prefix = prefix
|
||||
elif prefix:
|
||||
self.prefix = "%d.%d." % (random.randint(0, 0xffffffffffffffff), random.randint(0, 0xffffffffffffffff))
|
||||
else:
|
||||
self.prefix = ""
|
||||
self.nodes = {}
|
||||
self.id_gen = 1
|
||||
|
||||
def node(self, class_type, id=None, **kwargs):
|
||||
if id is None:
|
||||
id = str(self.id_gen)
|
||||
self.id_gen += 1
|
||||
id = self.prefix + id
|
||||
if id in self.nodes:
|
||||
return self.nodes[id]
|
||||
|
||||
node = Node(id, class_type, kwargs)
|
||||
self.nodes[id] = node
|
||||
return node
|
||||
|
||||
def lookup_node(self, id):
|
||||
id = self.prefix + id
|
||||
return self.nodes.get(id)
|
||||
|
||||
def finalize(self):
|
||||
output = {}
|
||||
for node_id, node in self.nodes.items():
|
||||
output[node_id] = node.serialize()
|
||||
return output
|
||||
|
||||
def replace_node_output(self, node_id, index, new_value):
|
||||
node_id = self.prefix + node_id
|
||||
to_remove = []
|
||||
for node in self.nodes.values():
|
||||
for key, value in node.inputs.items():
|
||||
if isinstance(value, list) and value[0] == node_id and value[1] == index:
|
||||
if new_value is None:
|
||||
to_remove.append((node, key))
|
||||
else:
|
||||
node.inputs[key] = new_value
|
||||
for node, key in to_remove:
|
||||
del node.inputs[key]
|
||||
|
||||
def remove_node(self, id):
|
||||
id = self.prefix + id
|
||||
del self.nodes[id]
|
||||
|
||||
class Node:
|
||||
def __init__(self, id, class_type, inputs):
|
||||
self.id = id
|
||||
self.class_type = class_type
|
||||
self.inputs = inputs
|
||||
|
||||
def out(self, index):
|
||||
return [self.id, index]
|
||||
|
||||
def set_input(self, key, value):
|
||||
if value is None:
|
||||
if key in self.inputs:
|
||||
del self.inputs[key]
|
||||
else:
|
||||
self.inputs[key] = value
|
||||
|
||||
def get_input(self, key):
|
||||
return self.inputs.get(key)
|
||||
|
||||
def serialize(self):
|
||||
return {
|
||||
"class_type": self.class_type,
|
||||
"inputs": self.inputs
|
||||
}
|
||||
|
||||
def add_graph_prefix(graph, outputs, prefix):
|
||||
# Change the node IDs and any internal links
|
||||
new_graph = {}
|
||||
for node_id, node_info in graph.items():
|
||||
# Make sure the added nodes have unique IDs
|
||||
new_node_id = prefix + node_id
|
||||
new_node = { "class_type": node_info["class_type"], "inputs": {} }
|
||||
for input_name, input_value in node_info.get("inputs", {}).items():
|
||||
if isinstance(input_value, list):
|
||||
new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
|
||||
else:
|
||||
new_node["inputs"][input_name] = input_value
|
||||
new_graph[new_node_id] = new_node
|
||||
|
||||
# Change the node IDs in the outputs
|
||||
new_outputs = []
|
||||
for n in range(len(outputs)):
|
||||
output = outputs[n]
|
||||
if isinstance(output, list): # This is a node link
|
||||
new_outputs.append([prefix + output[0], output[1]])
|
||||
else:
|
||||
new_outputs.append(output)
|
||||
|
||||
return new_graph, tuple(new_outputs)
|
||||
|
||||
19
custom_nodes/execution-inversion-demo-comfyui/__init__.py
Normal file
19
custom_nodes/execution-inversion-demo-comfyui/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
from .nodes import GENERAL_NODE_CLASS_MAPPINGS, GENERAL_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .components import setup_js, COMPONENT_NODE_CLASS_MAPPINGS, COMPONENT_NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS
|
||||
|
||||
# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
|
||||
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
NODE_CLASS_MAPPINGS.update(GENERAL_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
|
||||
NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(GENERAL_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
|
||||
setup_js()
|
||||
|
||||
208
custom_nodes/execution-inversion-demo-comfyui/components.py
Normal file
208
custom_nodes/execution-inversion-demo-comfyui/components.py
Normal file
@ -0,0 +1,208 @@
|
||||
import os
|
||||
import shutil
|
||||
import folder_paths
|
||||
import json
|
||||
import copy
|
||||
|
||||
comfy_path = os.path.dirname(folder_paths.__file__)
|
||||
js_path = os.path.join(comfy_path, "web", "extensions")
|
||||
inversion_demo_path = os.path.dirname(__file__)
|
||||
|
||||
def setup_js():
|
||||
# setup js
|
||||
js_dest_path = os.path.join(js_path, "inversion-demo-components")
|
||||
if not os.path.exists(js_dest_path):
|
||||
os.makedirs(js_dest_path)
|
||||
js_src_path = os.path.join(inversion_demo_path, "js", "inversion-demo-components.js")
|
||||
shutil.copy(js_src_path, js_dest_path)
|
||||
|
||||
class ComponentInput:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"name": ("STRING", {"multiline": False}),
|
||||
"data_type": ("STRING", {"multiline": False, "default": "IMAGE"}),
|
||||
"extra_args": ("STRING", {"multiline": False}),
|
||||
"explicit_input_order": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
|
||||
"optional": ([False, True],),
|
||||
},
|
||||
"optional": {
|
||||
"default_value": ("*",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("*",)
|
||||
FUNCTION = "component_input"
|
||||
|
||||
CATEGORY = "Component Creation"
|
||||
|
||||
def component_input(self, name, data_type, extra_args, explicit_input_order, optional, default_value = None):
|
||||
return (default_value,)
|
||||
|
||||
class ComponentOutput:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"index": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
|
||||
"data_type": ("STRING", {"multiline": False, "default": "IMAGE"}),
|
||||
"value": ("*",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("*",)
|
||||
FUNCTION = "component_output"
|
||||
|
||||
CATEGORY = "Component Creation"
|
||||
|
||||
def component_output(self, index, data_type, value):
|
||||
return (value,)
|
||||
|
||||
class ComponentMetadata:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"name": ("STRING", {"multiline": False}),
|
||||
"always_output": ([False, True],),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "nop"
|
||||
|
||||
CATEGORY = "Component Creation"
|
||||
|
||||
def nop(self, name):
|
||||
return {}
|
||||
|
||||
COMPONENT_NODE_CLASS_MAPPINGS = {
|
||||
"ComponentInput": ComponentInput,
|
||||
"ComponentOutput": ComponentOutput,
|
||||
"ComponentMetadata": ComponentMetadata,
|
||||
}
|
||||
COMPONENT_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ComponentInput": "Component Input",
|
||||
"ComponentOutput": "Component Output",
|
||||
"ComponentMetadata": "Component Metadata",
|
||||
}
|
||||
|
||||
DEFAULT_EXTRA_DATA = {
|
||||
"STRING": {"multiline": False},
|
||||
"INT": {"default": 0, "min": 0, "max": 1000, "step": 1},
|
||||
"FLOAT": {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1},
|
||||
}
|
||||
|
||||
def default_extra_data(data_type, extra_args):
|
||||
if data_type == "STRING":
|
||||
args = {"multiline": False}
|
||||
elif data_type == "INT":
|
||||
args = {"default": 0, "min": -1000000, "max": 1000000, "step": 1}
|
||||
elif data_type == "FLOAT":
|
||||
args = {"default": 0.0, "min": -1000000.0, "max": 1000000.0, "step": 0.1}
|
||||
else:
|
||||
args = {}
|
||||
args.update(extra_args)
|
||||
return args
|
||||
|
||||
def LoadComponent(component_file):
|
||||
try:
|
||||
with open(component_file, "r") as f:
|
||||
component_data = f.read()
|
||||
graph = json.loads(component_data)["output"]
|
||||
|
||||
component_raw_name = os.path.basename(component_file).split(".")[0]
|
||||
component_display_name = component_raw_name
|
||||
component_inputs = []
|
||||
component_outputs = []
|
||||
is_output_component = False
|
||||
for node_id, data in graph.items():
|
||||
if data["class_type"] == "ComponentMetadata":
|
||||
component_display_name = data["inputs"].get("name", component_raw_name)
|
||||
is_output_component = data["inputs"].get("always_output", False)
|
||||
elif data["class_type"] == "ComponentInput":
|
||||
data_type = data["inputs"]["data_type"]
|
||||
if len(data_type) > 0 and data_type[0] == "[":
|
||||
try:
|
||||
data_type = json.loads(data_type)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
extra_args = json.loads(data["inputs"]["extra_args"])
|
||||
except:
|
||||
extra_args = {}
|
||||
component_inputs.append({
|
||||
"node_id": node_id,
|
||||
"name": data["inputs"]["name"],
|
||||
"data_type": data_type,
|
||||
"extra_args": extra_args,
|
||||
"explicit_input_order": data["inputs"]["explicit_input_order"],
|
||||
"optional": data["inputs"]["optional"],
|
||||
})
|
||||
elif data["class_type"] == "ComponentOutput":
|
||||
component_outputs.append({
|
||||
"node_id": node_id,
|
||||
"index": data["inputs"]["index"],
|
||||
"data_type": data["inputs"]["data_type"],
|
||||
})
|
||||
component_inputs.sort(key=lambda x: (x["explicit_input_order"], x["name"]))
|
||||
component_outputs.sort(key=lambda x: x["index"])
|
||||
for i in range(1, len(component_inputs)):
|
||||
if component_inputs[i]["name"] == component_inputs[i-1]["name"]:
|
||||
raise Exception("Component input name is not unique: {}".format(component_inputs[i]["name"]))
|
||||
for i in range(1, len(component_outputs)):
|
||||
if component_outputs[i]["index"] == component_outputs[i-1]["index"]:
|
||||
raise Exception("Component output index is not unique: {}".format(component_outputs[i]["index"]))
|
||||
except Exception as e:
|
||||
print("Error loading component file: {}: {}".format(component_file, e))
|
||||
return None
|
||||
|
||||
class ComponentNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {node["name"]: (node["data_type"], default_extra_data(node["data_type"], node["extra_args"])) for node in component_inputs if not node["optional"]},
|
||||
"optional": {node["name"]: (node["data_type"], default_extra_data(node["data_type"], node["extra_args"])) for node in component_inputs if node["optional"]},
|
||||
}
|
||||
|
||||
RETURN_TYPES = tuple([node["data_type"] for node in component_outputs])
|
||||
FUNCTION = "expand_component"
|
||||
|
||||
CATEGORY = "Custom Components"
|
||||
OUTPUT_NODE = is_output_component
|
||||
|
||||
def expand_component(self, **kwargs):
|
||||
new_graph = copy.deepcopy(graph)
|
||||
for input_node in component_inputs:
|
||||
if input_node["name"] in kwargs:
|
||||
new_graph[input_node["node_id"]]["inputs"]["default_value"] = kwargs[input_node["name"]]
|
||||
return {
|
||||
"result": tuple([[node["node_id"], 0] for node in component_outputs]),
|
||||
"expand": new_graph,
|
||||
}
|
||||
ComponentNode.__name__ = component_raw_name
|
||||
COMPONENT_NODE_CLASS_MAPPINGS[component_raw_name] = ComponentNode
|
||||
COMPONENT_NODE_DISPLAY_NAME_MAPPINGS[component_raw_name] = component_display_name
|
||||
print("Loaded component: {}".format(component_display_name))
|
||||
|
||||
def load_components():
|
||||
component_dir = os.path.join(comfy_path, "components")
|
||||
files = [f for f in os.listdir(component_dir) if os.path.isfile(os.path.join(component_dir, f)) and f.endswith(".json")]
|
||||
for f in files:
|
||||
print("Loading component file %s" % f)
|
||||
LoadComponent(os.path.join(component_dir, f))
|
||||
|
||||
load_components()
|
||||
131
custom_nodes/execution-inversion-demo-comfyui/flow_control.py
Normal file
131
custom_nodes/execution-inversion-demo-comfyui/flow_control.py
Normal file
@ -0,0 +1,131 @@
|
||||
from comfy.graph_utils import GraphBuilder
|
||||
|
||||
NUM_FLOW_SOCKETS = 5
|
||||
class WhileLoopOpen:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
inputs = {
|
||||
"required": {
|
||||
"condition": ("INT", {"default": 1, "min": 0, "max": 1, "step": 1}),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
}
|
||||
for i in range(NUM_FLOW_SOCKETS):
|
||||
inputs["optional"]["initial_value%d" % i] = ("*",)
|
||||
return inputs
|
||||
|
||||
RETURN_TYPES = tuple(["FLOW_CONTROL"] + ["*"] * NUM_FLOW_SOCKETS)
|
||||
FUNCTION = "while_loop_open"
|
||||
|
||||
CATEGORY = "Flow Control"
|
||||
|
||||
def while_loop_open(self, condition, **kwargs):
|
||||
values = []
|
||||
for i in range(NUM_FLOW_SOCKETS):
|
||||
values.append(kwargs.get("initial_value%d" % i, None))
|
||||
return tuple(["stub"] + values)
|
||||
|
||||
class WhileLoopClose:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
inputs = {
|
||||
"required": {
|
||||
"flow_control": ("FLOW_CONTROL",),
|
||||
"condition": ("INT", {"default": 0, "min": 0, "max": 1, "step": 1}),
|
||||
},
|
||||
"optional": {
|
||||
},
|
||||
"hidden": {
|
||||
"dynprompt": "DYNPROMPT",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
}
|
||||
}
|
||||
for i in range(NUM_FLOW_SOCKETS):
|
||||
inputs["optional"]["initial_value%d" % i] = ("*",)
|
||||
return inputs
|
||||
|
||||
RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS)
|
||||
FUNCTION = "while_loop_close"
|
||||
|
||||
CATEGORY = "Flow Control"
|
||||
|
||||
def explore_dependencies(self, node_id, dynprompt, upstream):
|
||||
node_info = dynprompt.get_node(node_id)
|
||||
if "inputs" not in node_info:
|
||||
return
|
||||
for k, v in node_info["inputs"].items():
|
||||
if isinstance(v, list) and len(v) == 2:
|
||||
parent_id = v[0]
|
||||
if parent_id not in upstream:
|
||||
upstream[parent_id] = []
|
||||
self.explore_dependencies(parent_id, dynprompt, upstream)
|
||||
upstream[parent_id].append(node_id)
|
||||
|
||||
def collect_contained(self, node_id, upstream, contained):
|
||||
if node_id not in upstream:
|
||||
return
|
||||
for child_id in upstream[node_id]:
|
||||
if child_id not in contained:
|
||||
contained[child_id] = True
|
||||
self.collect_contained(child_id, upstream, contained)
|
||||
|
||||
|
||||
def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=None, **kwargs):
|
||||
if not condition:
|
||||
# We're done with the loop
|
||||
values = []
|
||||
for i in range(NUM_FLOW_SOCKETS):
|
||||
values.append(kwargs.get("initial_value%d" % i, None))
|
||||
return tuple(values)
|
||||
|
||||
# We want to loop
|
||||
this_node = dynprompt.get_node(unique_id)
|
||||
upstream = {}
|
||||
# Get the list of all nodes between the open and close nodes
|
||||
self.explore_dependencies(unique_id, dynprompt, upstream)
|
||||
|
||||
contained = {}
|
||||
open_node = this_node["inputs"]["flow_control"][0]
|
||||
self.collect_contained(open_node, upstream, contained)
|
||||
contained[unique_id] = True
|
||||
contained[open_node] = True
|
||||
|
||||
graph = GraphBuilder()
|
||||
for node_id in contained:
|
||||
original_node = dynprompt.get_node(node_id)
|
||||
node = graph.node(original_node["class_type"], node_id)
|
||||
for node_id in contained:
|
||||
original_node = dynprompt.get_node(node_id)
|
||||
node = graph.lookup_node(node_id)
|
||||
for k, v in original_node["inputs"].items():
|
||||
if isinstance(v, list) and len(v) == 2 and v[0] in contained:
|
||||
parent = graph.lookup_node(v[0])
|
||||
node.set_input(k, parent.out(v[1]))
|
||||
else:
|
||||
node.set_input(k, v)
|
||||
new_open = graph.lookup_node(open_node)
|
||||
for i in range(NUM_FLOW_SOCKETS):
|
||||
key = "initial_value%d" % i
|
||||
new_open.set_input(key, kwargs.get(key, None))
|
||||
my_clone = graph.lookup_node(unique_id)
|
||||
result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS))
|
||||
return {
|
||||
"result": tuple(result),
|
||||
"expand": graph.finalize(),
|
||||
}
|
||||
|
||||
FLOW_CONTROL_NODE_CLASS_MAPPINGS = {
|
||||
"WhileLoopOpen": WhileLoopOpen,
|
||||
"WhileLoopClose": WhileLoopClose,
|
||||
}
|
||||
FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"WhileLoopOpen": "While Loop Open",
|
||||
"WhileLoopClose": "While Loop Close",
|
||||
}
|
||||
@ -0,0 +1,64 @@
|
||||
import { app } from "/scripts/app.js";
|
||||
import { ComfyDialog, $el } from "/scripts/ui.js";
|
||||
import {ComfyWidgets} from "../../scripts/widgets.js";
|
||||
|
||||
var update_comfyui_button = null;
|
||||
var fetch_updates_button = null;
|
||||
|
||||
const fileInput = $el("input", {
|
||||
id: "component-file-input",
|
||||
type: "file",
|
||||
accept: ".json,image/png,.latent,.safetensors",
|
||||
style: {display: "none"},
|
||||
parent: document.body,
|
||||
onchange: async () => {
|
||||
app.handleFile(fileInput.files[0]);
|
||||
const reader = new FileReader();
|
||||
reader.onload = () => {
|
||||
app.loadGraphData(JSON.parse(reader.result)["workflow"]);
|
||||
};
|
||||
reader.readAsText(fileInput.files[0]);
|
||||
},
|
||||
});
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.InversionDemoComponents",
|
||||
|
||||
async setup() {
|
||||
const menu = document.querySelector(".comfy-menu");
|
||||
const separator = document.createElement("hr");
|
||||
|
||||
separator.style.margin = "20px 0";
|
||||
separator.style.width = "100%";
|
||||
menu.append(separator);
|
||||
|
||||
const saveButton = document.createElement("button");
|
||||
saveButton.textContent = "Save Component";
|
||||
saveButton.onclick = async () => {
|
||||
let filename = "component.json";
|
||||
const p = await app.graphToPrompt();
|
||||
const json = JSON.stringify(p, null, 2); // convert the data to a JSON string
|
||||
const blob = new Blob([json], {type: "application/json"});
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = $el("a", {
|
||||
href: url,
|
||||
download: filename,
|
||||
style: {display: "none"},
|
||||
parent: document.body,
|
||||
});
|
||||
a.click();
|
||||
setTimeout(function () {
|
||||
a.remove();
|
||||
window.URL.revokeObjectURL(url);
|
||||
}, 0);
|
||||
};
|
||||
|
||||
const loadButton = document.createElement("button");
|
||||
loadButton.textContent = "Load Component";
|
||||
loadButton.onclick = () => {
|
||||
fileInput.click();
|
||||
};
|
||||
menu.append(saveButton);
|
||||
menu.append(loadButton);
|
||||
}
|
||||
});
|
||||
206
custom_nodes/execution-inversion-demo-comfyui/nodes.py
Normal file
206
custom_nodes/execution-inversion-demo-comfyui/nodes.py
Normal file
@ -0,0 +1,206 @@
|
||||
import re
|
||||
|
||||
from comfy.graph_utils import GraphBuilder
|
||||
|
||||
class InversionDemoAdvancedPromptNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": ("STRING", {"multiline": True}),
|
||||
"model": ("MODEL",),
|
||||
"clip": ("CLIP",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "CONDITIONING")
|
||||
FUNCTION = "advanced_prompt"
|
||||
|
||||
CATEGORY = "InversionDemo Nodes"
|
||||
|
||||
def parse_prompt(self, prompt):
|
||||
# Get all string pieces matching the pattern "<lora:(name):(strength)(:(clip_strength))?>"
|
||||
# where name is a string and strength is a float
|
||||
# and clip_strength is an optional float
|
||||
pattern = r"<lora:([^:]+):([-0-9.]+)(?::([-0-9.]+))?>"
|
||||
loras = re.findall(pattern, prompt)
|
||||
if len(loras) == 0:
|
||||
return prompt, loras
|
||||
cleaned_prompt = re.sub(pattern, "", prompt).strip()
|
||||
print("Cleaned prompt: '%s'" % cleaned_prompt)
|
||||
return cleaned_prompt, loras
|
||||
|
||||
|
||||
def advanced_prompt(self, prompt, clip, model):
|
||||
cleaned_prompt, loras = self.parse_prompt(prompt)
|
||||
graph = GraphBuilder()
|
||||
for lora in loras:
|
||||
lora_name = lora[0]
|
||||
lora_model_strength = float(lora[1])
|
||||
lora_clip_strength = lora_model_strength if lora[2] == "" else float(lora[2])
|
||||
|
||||
loader = graph.node("LoraLoader", model=model, clip=clip, lora_name = lora_name, strength_model = lora_model_strength, strength_clip = lora_clip_strength)
|
||||
model = loader.out(0)
|
||||
clip = loader.out(1)
|
||||
encoder = graph.node("CLIPTextEncode", clip=clip, text=cleaned_prompt)
|
||||
return {
|
||||
"result": (model, clip, encoder.out(0)),
|
||||
"expand": graph.finalize(),
|
||||
}
|
||||
|
||||
class InversionDemoFakeAdvancedPromptNode:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"prompt": ("STRING", {"multiline": True}),
|
||||
"clip": ("CLIP",),
|
||||
"model": ("MODEL",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "CONDITIONING")
|
||||
FUNCTION = "advanced_prompt"
|
||||
|
||||
CATEGORY = "InversionDemo Nodes"
|
||||
|
||||
def advanced_prompt(self, prompt, clip, model):
|
||||
tokens = clip.tokenize(prompt)
|
||||
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
|
||||
return (model, clip, [[cond, {"pooled_output": pooled}]])
|
||||
|
||||
class InversionDemoLazySwitch:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"switch": ([False, True],),
|
||||
"on_false": ("*", {"lazy": True}),
|
||||
"on_true": ("*", {"lazy": True}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("*",)
|
||||
FUNCTION = "switch"
|
||||
|
||||
CATEGORY = "InversionDemo Nodes"
|
||||
|
||||
def check_lazy_status(self, switch, on_false = None, on_true = None):
|
||||
if switch and on_true is None:
|
||||
return ["on_true"]
|
||||
if not switch and on_false is None:
|
||||
return ["on_false"]
|
||||
|
||||
def switch(self, switch, on_false = None, on_true = None):
|
||||
value = on_true if switch else on_false
|
||||
return (value,)
|
||||
|
||||
class InversionDemoLazyIndexSwitch:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"index": ("INT", {"default": 0, "min": 0, "max": 9, "step": 1}),
|
||||
"value0": ("*", {"lazy": True}),
|
||||
},
|
||||
"optional": {
|
||||
"value1": ("*", {"lazy": True}),
|
||||
"value2": ("*", {"lazy": True}),
|
||||
"value3": ("*", {"lazy": True}),
|
||||
"value4": ("*", {"lazy": True}),
|
||||
"value5": ("*", {"lazy": True}),
|
||||
"value6": ("*", {"lazy": True}),
|
||||
"value7": ("*", {"lazy": True}),
|
||||
"value8": ("*", {"lazy": True}),
|
||||
"value9": ("*", {"lazy": True}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("*",)
|
||||
FUNCTION = "index_switch"
|
||||
|
||||
CATEGORY = "InversionDemo Nodes"
|
||||
|
||||
def check_lazy_status(self, index, **kwargs):
|
||||
key = "value%d" % index
|
||||
if key not in kwargs:
|
||||
return [key]
|
||||
|
||||
def index_switch(self, index, **kwargs):
|
||||
key = "value%d" % index
|
||||
return kwargs[key]
|
||||
|
||||
class InversionDemoLazyMixImages:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image1": ("IMAGE",{"lazy": True}),
|
||||
"image2": ("IMAGE",{"lazy": True}),
|
||||
"mask": ("MASK",),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "mix"
|
||||
|
||||
CATEGORY = "InversionDemo Nodes"
|
||||
|
||||
def check_lazy_status(self, mask, image1 = None, image2 = None):
|
||||
mask_min = mask.min()
|
||||
mask_max = mask.max()
|
||||
needed = []
|
||||
if image1 is None and (mask_min != 1.0 or mask_max != 1.0):
|
||||
needed.append("image1")
|
||||
if image2 is None and (mask_min != 0.0 or mask_max != 0.0):
|
||||
needed.append("image2")
|
||||
return needed
|
||||
|
||||
# Not trying to handle different batch sizes here just to keep the demo simple
|
||||
def mix(self, mask, image1 = None, image2 = None):
|
||||
mask_min = mask.min()
|
||||
mask_max = mask.max()
|
||||
if mask_min == 0.0 and mask_max == 0.0:
|
||||
return (image1,)
|
||||
elif mask_min == 1.0 and mask_max == 1.0:
|
||||
return (image2,)
|
||||
|
||||
if len(mask.shape) == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
if len(mask.shape) == 3:
|
||||
mask = mask.unsqueeze(3)
|
||||
if mask.shape[3] < image1.shape[3]:
|
||||
mask = mask.repeat(1, 1, 1, image1.shape[3])
|
||||
|
||||
return (image1 * (1. - mask) + image2 * mask,)
|
||||
|
||||
GENERAL_NODE_CLASS_MAPPINGS = {
|
||||
"InversionDemoAdvancedPromptNode": InversionDemoAdvancedPromptNode,
|
||||
"InversionDemoFakeAdvancedPromptNode": InversionDemoFakeAdvancedPromptNode,
|
||||
"InversionDemoLazySwitch": InversionDemoLazySwitch,
|
||||
"InversionDemoLazyIndexSwitch": InversionDemoLazyIndexSwitch,
|
||||
"InversionDemoLazyMixImages": InversionDemoLazyMixImages,
|
||||
}
|
||||
|
||||
GENERAL_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"InversionDemoAdvancedPromptNode": "Advanced Prompt",
|
||||
"InversionDemoFakeAdvancedPromptNode": "Fake Advanced Prompt",
|
||||
"InversionDemoLazySwitch": "Lazy Switch",
|
||||
"InversionDemoLazyIndexSwitch": "Lazy Index Switch",
|
||||
"InversionDemoLazyMixImages": "Lazy Mix Images",
|
||||
}
|
||||
355
execution.py
355
execution.py
@ -7,13 +7,150 @@ import heapq
|
||||
import traceback
|
||||
import gc
|
||||
import time
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.graph_utils
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
||||
class ExecutionResult(Enum):
|
||||
SUCCESS = 0
|
||||
FAILURE = 1
|
||||
SLEEPING = 2
|
||||
|
||||
# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
|
||||
# it can still be returned to the graph after having further dependencies added.
|
||||
class ExecutionList:
|
||||
def __init__(self, dynprompt, outputs):
|
||||
self.dynprompt = dynprompt
|
||||
self.outputs = outputs
|
||||
self.staged_node_id = None
|
||||
self.pendingNodes = {}
|
||||
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||
self.blocking = {} # Which nodes are blocked by this node
|
||||
|
||||
def get_input_info(self, unique_id, input_name):
|
||||
class_type = self.dynprompt.get_node(unique_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_info = None
|
||||
input_category = None
|
||||
if input_name in valid_inputs["required"]:
|
||||
input_category = "required"
|
||||
input_info = valid_inputs["required"][input_name]
|
||||
elif input_name in valid_inputs["optional"]:
|
||||
input_category = "optional"
|
||||
input_info = valid_inputs["optional"][input_name]
|
||||
elif input_name in valid_inputs["hidden"]:
|
||||
input_category = "hidden"
|
||||
input_info = valid_inputs["hidden"][input_name]
|
||||
if input_info is None:
|
||||
return None, None, None
|
||||
input_type = input_info[0]
|
||||
extra_info = None
|
||||
if len(input_info) > 1:
|
||||
extra_info = input_info[1]
|
||||
return input_type, input_category, extra_info
|
||||
|
||||
def make_input_strong_link(self, to_node_id, to_input):
|
||||
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
|
||||
if to_input not in inputs:
|
||||
raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input))
|
||||
value = inputs[to_input]
|
||||
if not isinstance(value, list):
|
||||
raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input))
|
||||
from_node_id, from_socket = value
|
||||
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
if from_node_id in self.outputs:
|
||||
# Nothing to do
|
||||
return
|
||||
self.add_node(from_node_id)
|
||||
if to_node_id not in self.blocking[from_node_id]:
|
||||
self.blocking[from_node_id][to_node_id] = {}
|
||||
self.blockCount[to_node_id] += 1
|
||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||
|
||||
def add_node(self, unique_id):
|
||||
if unique_id in self.pendingNodes:
|
||||
return
|
||||
self.pendingNodes[unique_id] = True
|
||||
self.blockCount[unique_id] = 0
|
||||
self.blocking[unique_id] = {}
|
||||
|
||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||
for input_name in inputs:
|
||||
value = inputs[input_name]
|
||||
if isinstance(value, list):
|
||||
from_node_id, from_socket = value
|
||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
||||
if input_info is None or "lazy" not in input_info or not input_info["lazy"]:
|
||||
self.add_strong_link(from_node_id, from_socket, unique_id)
|
||||
|
||||
def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
if self.is_empty():
|
||||
return None
|
||||
available = [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
||||
if len(available) == 0:
|
||||
raise Exception("Dependency cycle detected")
|
||||
next_node = available[0]
|
||||
# If an output node is available, do that first.
|
||||
# Technically this has no effect on the overall length of execution, but it feels better as a user
|
||||
# for a PreviewImage to display a result as soon as it can
|
||||
# Some other heuristics could probably be used here to improve the UX further.
|
||||
for node_id in available:
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||
next_node = node_id
|
||||
break
|
||||
self.staged_node_id = next_node
|
||||
return self.staged_node_id
|
||||
|
||||
def unstage_node_execution(self):
|
||||
assert self.staged_node_id is not None
|
||||
self.staged_node_id = None
|
||||
|
||||
def complete_node_execution(self):
|
||||
node_id = self.staged_node_id
|
||||
del self.pendingNodes[node_id]
|
||||
for blocked_node_id in self.blocking[node_id]:
|
||||
self.blockCount[blocked_node_id] -= 1
|
||||
del self.blocking[node_id]
|
||||
self.staged_node_id = None
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.pendingNodes) == 0
|
||||
|
||||
class DynamicPrompt:
|
||||
def __init__(self, original_prompt):
|
||||
# The original prompt provided by the user
|
||||
self.original_prompt = original_prompt
|
||||
# Any extra pieces of the graph created during execution
|
||||
self.ephemeral_prompt = {}
|
||||
self.ephemeral_parents = {}
|
||||
|
||||
def get_node(self, node_id):
|
||||
if node_id in self.ephemeral_prompt:
|
||||
return self.ephemeral_prompt[node_id]
|
||||
if node_id in self.original_prompt:
|
||||
return self.original_prompt[node_id]
|
||||
return None
|
||||
|
||||
def add_ephemeral_node(self, real_parent_id, node_id, node_info):
|
||||
self.ephemeral_prompt[node_id] = node_info
|
||||
self.ephemeral_parents[node_id] = real_parent_id
|
||||
|
||||
def get_real_node_id(self, node_id):
|
||||
if node_id in self.ephemeral_parents:
|
||||
return self.ephemeral_parents[node_id]
|
||||
return node_id
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynprompt=None, extra_data={}):
|
||||
valid_inputs = class_def.INPUT_TYPES()
|
||||
input_data_all = {}
|
||||
for x in inputs:
|
||||
@ -22,7 +159,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
return None
|
||||
continue # This might be a lazily-evaluated input
|
||||
obj = outputs[input_unique_id][output_index]
|
||||
input_data_all[x] = obj
|
||||
else:
|
||||
@ -34,6 +171,8 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
for x in h:
|
||||
if h[x] == "PROMPT":
|
||||
input_data_all[x] = [prompt]
|
||||
if h[x] == "DYNPROMPT":
|
||||
input_data_all[x] = [dynprompt]
|
||||
if h[x] == "EXTRA_PNGINFO":
|
||||
if "extra_pnginfo" in extra_data:
|
||||
input_data_all[x] = [extra_data['extra_pnginfo']]
|
||||
@ -68,39 +207,54 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||
return results
|
||||
|
||||
def merge_result_data(results, obj):
|
||||
# check which outputs need concatenating
|
||||
output = []
|
||||
output_is_list = [False] * len(results[0])
|
||||
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||
output_is_list = obj.OUTPUT_IS_LIST
|
||||
|
||||
# merge node execution results
|
||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||
if is_list:
|
||||
output.append([x for o in results for x in o[i]])
|
||||
else:
|
||||
output.append([o[i] for o in results])
|
||||
return output
|
||||
|
||||
def get_output_data(obj, input_data_all):
|
||||
|
||||
results = []
|
||||
uis = []
|
||||
subgraph_results = []
|
||||
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
||||
|
||||
for r in return_values:
|
||||
has_subgraph = False
|
||||
for i in range(len(return_values)):
|
||||
r = return_values[i]
|
||||
if isinstance(r, dict):
|
||||
if 'ui' in r:
|
||||
uis.append(r['ui'])
|
||||
if 'result' in r:
|
||||
if 'expand' in r:
|
||||
# Perform an expansion, but do not append results
|
||||
has_subgraph = True
|
||||
new_graph = r['expand']
|
||||
subgraph_results.append((new_graph, r.get("result", None)))
|
||||
elif 'result' in r:
|
||||
results.append(r['result'])
|
||||
subgraph_results.append((None, r['result']))
|
||||
else:
|
||||
results.append(r)
|
||||
|
||||
output = []
|
||||
if len(results) > 0:
|
||||
# check which outputs need concatenating
|
||||
output_is_list = [False] * len(results[0])
|
||||
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||
output_is_list = obj.OUTPUT_IS_LIST
|
||||
|
||||
# merge node execution results
|
||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||
if is_list:
|
||||
output.append([x for o in results for x in o[i]])
|
||||
else:
|
||||
output.append([o[i] for o in results])
|
||||
|
||||
if has_subgraph:
|
||||
output = subgraph_results
|
||||
elif len(results) > 0:
|
||||
output = merge_result_data(results, obj)
|
||||
else:
|
||||
output = []
|
||||
ui = dict()
|
||||
if len(uis) > 0:
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui
|
||||
return output, ui, has_subgraph
|
||||
|
||||
def format_value(x):
|
||||
if x is None:
|
||||
@ -110,53 +264,102 @@ def format_value(x):
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage):
|
||||
def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage, execution_list, pending_subgraph_results):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if unique_id in outputs:
|
||||
return (True, None, None)
|
||||
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
|
||||
if isinstance(input_data, list):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage)
|
||||
if result[0] is not True:
|
||||
# Another node failed further upstream
|
||||
return result
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
input_data_all = None
|
||||
try:
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = unique_id
|
||||
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
|
||||
if unique_id in pending_subgraph_results:
|
||||
cached_results = pending_subgraph_results[unique_id]
|
||||
resolved_outputs = []
|
||||
for is_subgraph, result in cached_results:
|
||||
if not is_subgraph:
|
||||
resolved_outputs.append(result)
|
||||
else:
|
||||
resolved_output = []
|
||||
for r in result:
|
||||
if isinstance(r, list) and len(r) == 2:
|
||||
source_node, source_output = r[0], r[1]
|
||||
node_output = outputs[source_node][source_output]
|
||||
for o in node_output:
|
||||
resolved_output.append(o)
|
||||
|
||||
obj = object_storage.get((unique_id, class_type), None)
|
||||
if obj is None:
|
||||
obj = class_def()
|
||||
object_storage[(unique_id, class_type)] = obj
|
||||
else:
|
||||
resolved_output.append(r)
|
||||
resolved_outputs.append(tuple(resolved_output))
|
||||
output_data = merge_result_data(resolved_outputs, class_def)
|
||||
output_ui = []
|
||||
has_subgraph = False
|
||||
else:
|
||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, dynprompt.original_prompt, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = real_node_id
|
||||
server.send_sync("executing", { "node": real_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
|
||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||
outputs[unique_id] = output_data
|
||||
obj = object_storage.get((unique_id, class_type), None)
|
||||
if obj is None:
|
||||
obj = class_def()
|
||||
object_storage[(unique_id, class_type)] = obj
|
||||
|
||||
if hasattr(obj, "check_lazy_status"):
|
||||
required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||
required_inputs = [x for x in required_inputs if x not in input_data_all]
|
||||
if len(required_inputs) > 0:
|
||||
for i in required_inputs:
|
||||
execution_list.make_input_strong_link(unique_id, i)
|
||||
return (ExecutionResult.SLEEPING, None, None)
|
||||
|
||||
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all)
|
||||
if len(output_ui) > 0:
|
||||
outputs_ui[unique_id] = output_ui
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
server.send_sync("executed", { "node": real_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
if has_subgraph:
|
||||
cached_outputs = []
|
||||
for i in range(len(output_data)):
|
||||
new_graph, node_outputs = output_data[i]
|
||||
if new_graph is None:
|
||||
cached_outputs.append((False, node_outputs))
|
||||
else:
|
||||
# Check for conflicts
|
||||
for node_id in new_graph.keys():
|
||||
if dynprompt.get_node(node_id) is not None:
|
||||
new_graph, node_outputs = comfy.graph_utils.add_graph_prefix(new_graph, node_outputs, "%s.%d." % (unique_id, i))
|
||||
break
|
||||
new_output_ids = []
|
||||
for node_id, node_info in new_graph.items():
|
||||
dynprompt.add_ephemeral_node(real_node_id, node_id, node_info)
|
||||
# Figure out if the newly created node is an output node
|
||||
class_type = node_info["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||
new_output_ids.append(node_id)
|
||||
for node_id in new_output_ids:
|
||||
execution_list.add_node(node_id)
|
||||
for i in range(len(node_outputs)):
|
||||
if isinstance(node_outputs[i], list) and len(node_outputs[i]) == 2:
|
||||
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
|
||||
execution_list.add_strong_link(from_node_id, from_socket, unique_id)
|
||||
cached_outputs.append((True, node_outputs))
|
||||
pending_subgraph_results[unique_id] = cached_outputs
|
||||
return (ExecutionResult.SLEEPING, None, None)
|
||||
outputs[unique_id] = output_data
|
||||
except comfy.model_management.InterruptProcessingException as iex:
|
||||
print("Processing interrupted")
|
||||
|
||||
# skip formatting inputs/outputs
|
||||
error_details = {
|
||||
"node_id": unique_id,
|
||||
"node_id": real_node_id,
|
||||
}
|
||||
|
||||
return (False, error_details, iex)
|
||||
return (ExecutionResult.FAILURE, error_details, iex)
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
exception_type = full_type_name(typ)
|
||||
@ -174,35 +377,18 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
print(traceback.format_exc())
|
||||
|
||||
error_details = {
|
||||
"node_id": unique_id,
|
||||
"node_id": real_node_id,
|
||||
"exception_message": str(ex),
|
||||
"exception_type": exception_type,
|
||||
"traceback": traceback.format_tb(tb),
|
||||
"current_inputs": input_data_formatted,
|
||||
"current_outputs": output_data_formatted
|
||||
}
|
||||
return (False, error_details, ex)
|
||||
return (ExecutionResult.FAILURE, error_details, ex)
|
||||
|
||||
executed.add(unique_id)
|
||||
|
||||
return (True, None, None)
|
||||
|
||||
def recursive_will_execute(prompt, outputs, current_item):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
will_execute = []
|
||||
if unique_id in outputs:
|
||||
return []
|
||||
|
||||
for x in inputs:
|
||||
input_data = inputs[x]
|
||||
if isinstance(input_data, list):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
will_execute += recursive_will_execute(prompt, outputs, input_unique_id)
|
||||
|
||||
return will_execute + [unique_id]
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
|
||||
unique_id = current_item
|
||||
@ -350,28 +536,27 @@ class PromptExecutor:
|
||||
|
||||
if self.server.client_id is not None:
|
||||
self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
|
||||
pending_subgraph_results = {}
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
executed = set()
|
||||
output_node_id = None
|
||||
to_execute = []
|
||||
|
||||
execution_list = ExecutionList(dynamic_prompt, self.outputs)
|
||||
for node_id in list(execute_outputs):
|
||||
to_execute += [(0, node_id)]
|
||||
execution_list.add_node(node_id)
|
||||
|
||||
while len(to_execute) > 0:
|
||||
#always execute the output that depends on the least amount of unexecuted nodes first
|
||||
to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute)))
|
||||
output_node_id = to_execute.pop(0)[-1]
|
||||
|
||||
# This call shouldn't raise anything if there's an error deep in
|
||||
# the actual SD code, instead it will report the node where the
|
||||
# error was raised
|
||||
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
|
||||
if success is not True:
|
||||
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
||||
while not execution_list.is_empty():
|
||||
node_id = execution_list.stage_node_execution()
|
||||
result, error, ex = non_recursive_execute(self.server, dynamic_prompt, self.outputs, node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage, execution_list, pending_subgraph_results)
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
elif result == ExecutionResult.SLEEPING:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS
|
||||
execution_list.complete_node_execution()
|
||||
|
||||
for x in executed:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
if x in prompt:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
self.server.last_node_id = None
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user