diff --git a/comfy/graph_utils.py b/comfy/graph_utils.py new file mode 100644 index 000000000..d1a4e7187 --- /dev/null +++ b/comfy/graph_utils.py @@ -0,0 +1,104 @@ +import json +import random + +# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end +class GraphBuilder: + def __init__(self, prefix = True): + if isinstance(prefix, str): + self.prefix = prefix + elif prefix: + self.prefix = "%d.%d." % (random.randint(0, 0xffffffffffffffff), random.randint(0, 0xffffffffffffffff)) + else: + self.prefix = "" + self.nodes = {} + self.id_gen = 1 + + def node(self, class_type, id=None, **kwargs): + if id is None: + id = str(self.id_gen) + self.id_gen += 1 + id = self.prefix + id + if id in self.nodes: + return self.nodes[id] + + node = Node(id, class_type, kwargs) + self.nodes[id] = node + return node + + def lookup_node(self, id): + id = self.prefix + id + return self.nodes.get(id) + + def finalize(self): + output = {} + for node_id, node in self.nodes.items(): + output[node_id] = node.serialize() + return output + + def replace_node_output(self, node_id, index, new_value): + node_id = self.prefix + node_id + to_remove = [] + for node in self.nodes.values(): + for key, value in node.inputs.items(): + if isinstance(value, list) and value[0] == node_id and value[1] == index: + if new_value is None: + to_remove.append((node, key)) + else: + node.inputs[key] = new_value + for node, key in to_remove: + del node.inputs[key] + + def remove_node(self, id): + id = self.prefix + id + del self.nodes[id] + +class Node: + def __init__(self, id, class_type, inputs): + self.id = id + self.class_type = class_type + self.inputs = inputs + + def out(self, index): + return [self.id, index] + + def set_input(self, key, value): + if value is None: + if key in self.inputs: + del self.inputs[key] + else: + self.inputs[key] = value + + def get_input(self, key): + return self.inputs.get(key) + + def serialize(self): + return { + "class_type": self.class_type, + "inputs": self.inputs + } + +def add_graph_prefix(graph, outputs, prefix): + # Change the node IDs and any internal links + new_graph = {} + for node_id, node_info in graph.items(): + # Make sure the added nodes have unique IDs + new_node_id = prefix + node_id + new_node = { "class_type": node_info["class_type"], "inputs": {} } + for input_name, input_value in node_info.get("inputs", {}).items(): + if isinstance(input_value, list): + new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]] + else: + new_node["inputs"][input_name] = input_value + new_graph[new_node_id] = new_node + + # Change the node IDs in the outputs + new_outputs = [] + for n in range(len(outputs)): + output = outputs[n] + if isinstance(output, list): # This is a node link + new_outputs.append([prefix + output[0], output[1]]) + else: + new_outputs.append(output) + + return new_graph, tuple(new_outputs) + diff --git a/custom_nodes/execution-inversion-demo-comfyui/__init__.py b/custom_nodes/execution-inversion-demo-comfyui/__init__.py new file mode 100644 index 000000000..00356fc41 --- /dev/null +++ b/custom_nodes/execution-inversion-demo-comfyui/__init__.py @@ -0,0 +1,19 @@ +from .nodes import GENERAL_NODE_CLASS_MAPPINGS, GENERAL_NODE_DISPLAY_NAME_MAPPINGS +from .components import setup_js, COMPONENT_NODE_CLASS_MAPPINGS, COMPONENT_NODE_DISPLAY_NAME_MAPPINGS +from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS + +# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) +# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) + +NODE_CLASS_MAPPINGS = {} +NODE_CLASS_MAPPINGS.update(GENERAL_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS) + +NODE_DISPLAY_NAME_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS.update(GENERAL_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS) + +setup_js() + diff --git a/custom_nodes/execution-inversion-demo-comfyui/components.py b/custom_nodes/execution-inversion-demo-comfyui/components.py new file mode 100644 index 000000000..1a25fc6ba --- /dev/null +++ b/custom_nodes/execution-inversion-demo-comfyui/components.py @@ -0,0 +1,208 @@ +import os +import shutil +import folder_paths +import json +import copy + +comfy_path = os.path.dirname(folder_paths.__file__) +js_path = os.path.join(comfy_path, "web", "extensions") +inversion_demo_path = os.path.dirname(__file__) + +def setup_js(): + # setup js + js_dest_path = os.path.join(js_path, "inversion-demo-components") + if not os.path.exists(js_dest_path): + os.makedirs(js_dest_path) + js_src_path = os.path.join(inversion_demo_path, "js", "inversion-demo-components.js") + shutil.copy(js_src_path, js_dest_path) + +class ComponentInput: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "name": ("STRING", {"multiline": False}), + "data_type": ("STRING", {"multiline": False, "default": "IMAGE"}), + "extra_args": ("STRING", {"multiline": False}), + "explicit_input_order": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}), + "optional": ([False, True],), + }, + "optional": { + "default_value": ("*",), + }, + } + + RETURN_TYPES = ("*",) + FUNCTION = "component_input" + + CATEGORY = "Component Creation" + + def component_input(self, name, data_type, extra_args, explicit_input_order, optional, default_value = None): + return (default_value,) + +class ComponentOutput: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "index": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}), + "data_type": ("STRING", {"multiline": False, "default": "IMAGE"}), + "value": ("*",), + }, + } + + RETURN_TYPES = ("*",) + FUNCTION = "component_output" + + CATEGORY = "Component Creation" + + def component_output(self, index, data_type, value): + return (value,) + +class ComponentMetadata: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "name": ("STRING", {"multiline": False}), + "always_output": ([False, True],), + }, + } + + RETURN_TYPES = () + FUNCTION = "nop" + + CATEGORY = "Component Creation" + + def nop(self, name): + return {} + +COMPONENT_NODE_CLASS_MAPPINGS = { + "ComponentInput": ComponentInput, + "ComponentOutput": ComponentOutput, + "ComponentMetadata": ComponentMetadata, +} +COMPONENT_NODE_DISPLAY_NAME_MAPPINGS = { + "ComponentInput": "Component Input", + "ComponentOutput": "Component Output", + "ComponentMetadata": "Component Metadata", +} + +DEFAULT_EXTRA_DATA = { + "STRING": {"multiline": False}, + "INT": {"default": 0, "min": 0, "max": 1000, "step": 1}, + "FLOAT": {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}, +} + +def default_extra_data(data_type, extra_args): + if data_type == "STRING": + args = {"multiline": False} + elif data_type == "INT": + args = {"default": 0, "min": -1000000, "max": 1000000, "step": 1} + elif data_type == "FLOAT": + args = {"default": 0.0, "min": -1000000.0, "max": 1000000.0, "step": 0.1} + else: + args = {} + args.update(extra_args) + return args + +def LoadComponent(component_file): + try: + with open(component_file, "r") as f: + component_data = f.read() + graph = json.loads(component_data)["output"] + + component_raw_name = os.path.basename(component_file).split(".")[0] + component_display_name = component_raw_name + component_inputs = [] + component_outputs = [] + is_output_component = False + for node_id, data in graph.items(): + if data["class_type"] == "ComponentMetadata": + component_display_name = data["inputs"].get("name", component_raw_name) + is_output_component = data["inputs"].get("always_output", False) + elif data["class_type"] == "ComponentInput": + data_type = data["inputs"]["data_type"] + if len(data_type) > 0 and data_type[0] == "[": + try: + data_type = json.loads(data_type) + except: + pass + try: + extra_args = json.loads(data["inputs"]["extra_args"]) + except: + extra_args = {} + component_inputs.append({ + "node_id": node_id, + "name": data["inputs"]["name"], + "data_type": data_type, + "extra_args": extra_args, + "explicit_input_order": data["inputs"]["explicit_input_order"], + "optional": data["inputs"]["optional"], + }) + elif data["class_type"] == "ComponentOutput": + component_outputs.append({ + "node_id": node_id, + "index": data["inputs"]["index"], + "data_type": data["inputs"]["data_type"], + }) + component_inputs.sort(key=lambda x: (x["explicit_input_order"], x["name"])) + component_outputs.sort(key=lambda x: x["index"]) + for i in range(1, len(component_inputs)): + if component_inputs[i]["name"] == component_inputs[i-1]["name"]: + raise Exception("Component input name is not unique: {}".format(component_inputs[i]["name"])) + for i in range(1, len(component_outputs)): + if component_outputs[i]["index"] == component_outputs[i-1]["index"]: + raise Exception("Component output index is not unique: {}".format(component_outputs[i]["index"])) + except Exception as e: + print("Error loading component file: {}: {}".format(component_file, e)) + return None + + class ComponentNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": {node["name"]: (node["data_type"], default_extra_data(node["data_type"], node["extra_args"])) for node in component_inputs if not node["optional"]}, + "optional": {node["name"]: (node["data_type"], default_extra_data(node["data_type"], node["extra_args"])) for node in component_inputs if node["optional"]}, + } + + RETURN_TYPES = tuple([node["data_type"] for node in component_outputs]) + FUNCTION = "expand_component" + + CATEGORY = "Custom Components" + OUTPUT_NODE = is_output_component + + def expand_component(self, **kwargs): + new_graph = copy.deepcopy(graph) + for input_node in component_inputs: + if input_node["name"] in kwargs: + new_graph[input_node["node_id"]]["inputs"]["default_value"] = kwargs[input_node["name"]] + return { + "result": tuple([[node["node_id"], 0] for node in component_outputs]), + "expand": new_graph, + } + ComponentNode.__name__ = component_raw_name + COMPONENT_NODE_CLASS_MAPPINGS[component_raw_name] = ComponentNode + COMPONENT_NODE_DISPLAY_NAME_MAPPINGS[component_raw_name] = component_display_name + print("Loaded component: {}".format(component_display_name)) + +def load_components(): + component_dir = os.path.join(comfy_path, "components") + files = [f for f in os.listdir(component_dir) if os.path.isfile(os.path.join(component_dir, f)) and f.endswith(".json")] + for f in files: + print("Loading component file %s" % f) + LoadComponent(os.path.join(component_dir, f)) + +load_components() diff --git a/custom_nodes/execution-inversion-demo-comfyui/flow_control.py b/custom_nodes/execution-inversion-demo-comfyui/flow_control.py new file mode 100644 index 000000000..7876ae488 --- /dev/null +++ b/custom_nodes/execution-inversion-demo-comfyui/flow_control.py @@ -0,0 +1,131 @@ +from comfy.graph_utils import GraphBuilder + +NUM_FLOW_SOCKETS = 5 +class WhileLoopOpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "condition": ("INT", {"default": 1, "min": 0, "max": 1, "step": 1}), + }, + "optional": { + }, + } + for i in range(NUM_FLOW_SOCKETS): + inputs["optional"]["initial_value%d" % i] = ("*",) + return inputs + + RETURN_TYPES = tuple(["FLOW_CONTROL"] + ["*"] * NUM_FLOW_SOCKETS) + FUNCTION = "while_loop_open" + + CATEGORY = "Flow Control" + + def while_loop_open(self, condition, **kwargs): + values = [] + for i in range(NUM_FLOW_SOCKETS): + values.append(kwargs.get("initial_value%d" % i, None)) + return tuple(["stub"] + values) + +class WhileLoopClose: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "flow_control": ("FLOW_CONTROL",), + "condition": ("INT", {"default": 0, "min": 0, "max": 1, "step": 1}), + }, + "optional": { + }, + "hidden": { + "dynprompt": "DYNPROMPT", + "unique_id": "UNIQUE_ID", + } + } + for i in range(NUM_FLOW_SOCKETS): + inputs["optional"]["initial_value%d" % i] = ("*",) + return inputs + + RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS) + FUNCTION = "while_loop_close" + + CATEGORY = "Flow Control" + + def explore_dependencies(self, node_id, dynprompt, upstream): + node_info = dynprompt.get_node(node_id) + if "inputs" not in node_info: + return + for k, v in node_info["inputs"].items(): + if isinstance(v, list) and len(v) == 2: + parent_id = v[0] + if parent_id not in upstream: + upstream[parent_id] = [] + self.explore_dependencies(parent_id, dynprompt, upstream) + upstream[parent_id].append(node_id) + + def collect_contained(self, node_id, upstream, contained): + if node_id not in upstream: + return + for child_id in upstream[node_id]: + if child_id not in contained: + contained[child_id] = True + self.collect_contained(child_id, upstream, contained) + + + def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=None, **kwargs): + if not condition: + # We're done with the loop + values = [] + for i in range(NUM_FLOW_SOCKETS): + values.append(kwargs.get("initial_value%d" % i, None)) + return tuple(values) + + # We want to loop + this_node = dynprompt.get_node(unique_id) + upstream = {} + # Get the list of all nodes between the open and close nodes + self.explore_dependencies(unique_id, dynprompt, upstream) + + contained = {} + open_node = this_node["inputs"]["flow_control"][0] + self.collect_contained(open_node, upstream, contained) + contained[unique_id] = True + contained[open_node] = True + + graph = GraphBuilder() + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.node(original_node["class_type"], node_id) + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.lookup_node(node_id) + for k, v in original_node["inputs"].items(): + if isinstance(v, list) and len(v) == 2 and v[0] in contained: + parent = graph.lookup_node(v[0]) + node.set_input(k, parent.out(v[1])) + else: + node.set_input(k, v) + new_open = graph.lookup_node(open_node) + for i in range(NUM_FLOW_SOCKETS): + key = "initial_value%d" % i + new_open.set_input(key, kwargs.get(key, None)) + my_clone = graph.lookup_node(unique_id) + result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS)) + return { + "result": tuple(result), + "expand": graph.finalize(), + } + +FLOW_CONTROL_NODE_CLASS_MAPPINGS = { + "WhileLoopOpen": WhileLoopOpen, + "WhileLoopClose": WhileLoopClose, +} +FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = { + "WhileLoopOpen": "While Loop Open", + "WhileLoopClose": "While Loop Close", +} diff --git a/custom_nodes/execution-inversion-demo-comfyui/js/inversion-demo-components.js b/custom_nodes/execution-inversion-demo-comfyui/js/inversion-demo-components.js new file mode 100644 index 000000000..cf72e670c --- /dev/null +++ b/custom_nodes/execution-inversion-demo-comfyui/js/inversion-demo-components.js @@ -0,0 +1,64 @@ +import { app } from "/scripts/app.js"; +import { ComfyDialog, $el } from "/scripts/ui.js"; +import {ComfyWidgets} from "../../scripts/widgets.js"; + +var update_comfyui_button = null; +var fetch_updates_button = null; + +const fileInput = $el("input", { + id: "component-file-input", + type: "file", + accept: ".json,image/png,.latent,.safetensors", + style: {display: "none"}, + parent: document.body, + onchange: async () => { + app.handleFile(fileInput.files[0]); + const reader = new FileReader(); + reader.onload = () => { + app.loadGraphData(JSON.parse(reader.result)["workflow"]); + }; + reader.readAsText(fileInput.files[0]); + }, +}); + +app.registerExtension({ + name: "Comfy.InversionDemoComponents", + + async setup() { + const menu = document.querySelector(".comfy-menu"); + const separator = document.createElement("hr"); + + separator.style.margin = "20px 0"; + separator.style.width = "100%"; + menu.append(separator); + + const saveButton = document.createElement("button"); + saveButton.textContent = "Save Component"; + saveButton.onclick = async () => { + let filename = "component.json"; + const p = await app.graphToPrompt(); + const json = JSON.stringify(p, null, 2); // convert the data to a JSON string + const blob = new Blob([json], {type: "application/json"}); + const url = URL.createObjectURL(blob); + const a = $el("a", { + href: url, + download: filename, + style: {display: "none"}, + parent: document.body, + }); + a.click(); + setTimeout(function () { + a.remove(); + window.URL.revokeObjectURL(url); + }, 0); + }; + + const loadButton = document.createElement("button"); + loadButton.textContent = "Load Component"; + loadButton.onclick = () => { + fileInput.click(); + }; + menu.append(saveButton); + menu.append(loadButton); + } +}); diff --git a/custom_nodes/execution-inversion-demo-comfyui/nodes.py b/custom_nodes/execution-inversion-demo-comfyui/nodes.py new file mode 100644 index 000000000..771a5f954 --- /dev/null +++ b/custom_nodes/execution-inversion-demo-comfyui/nodes.py @@ -0,0 +1,206 @@ +import re + +from comfy.graph_utils import GraphBuilder + +class InversionDemoAdvancedPromptNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "prompt": ("STRING", {"multiline": True}), + "model": ("MODEL",), + "clip": ("CLIP",), + }, + } + + RETURN_TYPES = ("MODEL", "CLIP", "CONDITIONING") + FUNCTION = "advanced_prompt" + + CATEGORY = "InversionDemo Nodes" + + def parse_prompt(self, prompt): + # Get all string pieces matching the pattern "" + # where name is a string and strength is a float + # and clip_strength is an optional float + pattern = r"" + loras = re.findall(pattern, prompt) + if len(loras) == 0: + return prompt, loras + cleaned_prompt = re.sub(pattern, "", prompt).strip() + print("Cleaned prompt: '%s'" % cleaned_prompt) + return cleaned_prompt, loras + + + def advanced_prompt(self, prompt, clip, model): + cleaned_prompt, loras = self.parse_prompt(prompt) + graph = GraphBuilder() + for lora in loras: + lora_name = lora[0] + lora_model_strength = float(lora[1]) + lora_clip_strength = lora_model_strength if lora[2] == "" else float(lora[2]) + + loader = graph.node("LoraLoader", model=model, clip=clip, lora_name = lora_name, strength_model = lora_model_strength, strength_clip = lora_clip_strength) + model = loader.out(0) + clip = loader.out(1) + encoder = graph.node("CLIPTextEncode", clip=clip, text=cleaned_prompt) + return { + "result": (model, clip, encoder.out(0)), + "expand": graph.finalize(), + } + +class InversionDemoFakeAdvancedPromptNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "prompt": ("STRING", {"multiline": True}), + "clip": ("CLIP",), + "model": ("MODEL",), + }, + } + + RETURN_TYPES = ("MODEL", "CLIP", "CONDITIONING") + FUNCTION = "advanced_prompt" + + CATEGORY = "InversionDemo Nodes" + + def advanced_prompt(self, prompt, clip, model): + tokens = clip.tokenize(prompt) + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + return (model, clip, [[cond, {"pooled_output": pooled}]]) + +class InversionDemoLazySwitch: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "switch": ([False, True],), + "on_false": ("*", {"lazy": True}), + "on_true": ("*", {"lazy": True}), + }, + } + + RETURN_TYPES = ("*",) + FUNCTION = "switch" + + CATEGORY = "InversionDemo Nodes" + + def check_lazy_status(self, switch, on_false = None, on_true = None): + if switch and on_true is None: + return ["on_true"] + if not switch and on_false is None: + return ["on_false"] + + def switch(self, switch, on_false = None, on_true = None): + value = on_true if switch else on_false + return (value,) + +class InversionDemoLazyIndexSwitch: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "index": ("INT", {"default": 0, "min": 0, "max": 9, "step": 1}), + "value0": ("*", {"lazy": True}), + }, + "optional": { + "value1": ("*", {"lazy": True}), + "value2": ("*", {"lazy": True}), + "value3": ("*", {"lazy": True}), + "value4": ("*", {"lazy": True}), + "value5": ("*", {"lazy": True}), + "value6": ("*", {"lazy": True}), + "value7": ("*", {"lazy": True}), + "value8": ("*", {"lazy": True}), + "value9": ("*", {"lazy": True}), + } + } + + RETURN_TYPES = ("*",) + FUNCTION = "index_switch" + + CATEGORY = "InversionDemo Nodes" + + def check_lazy_status(self, index, **kwargs): + key = "value%d" % index + if key not in kwargs: + return [key] + + def index_switch(self, index, **kwargs): + key = "value%d" % index + return kwargs[key] + +class InversionDemoLazyMixImages: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image1": ("IMAGE",{"lazy": True}), + "image2": ("IMAGE",{"lazy": True}), + "mask": ("MASK",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "mix" + + CATEGORY = "InversionDemo Nodes" + + def check_lazy_status(self, mask, image1 = None, image2 = None): + mask_min = mask.min() + mask_max = mask.max() + needed = [] + if image1 is None and (mask_min != 1.0 or mask_max != 1.0): + needed.append("image1") + if image2 is None and (mask_min != 0.0 or mask_max != 0.0): + needed.append("image2") + return needed + + # Not trying to handle different batch sizes here just to keep the demo simple + def mix(self, mask, image1 = None, image2 = None): + mask_min = mask.min() + mask_max = mask.max() + if mask_min == 0.0 and mask_max == 0.0: + return (image1,) + elif mask_min == 1.0 and mask_max == 1.0: + return (image2,) + + if len(mask.shape) == 2: + mask = mask.unsqueeze(0) + if len(mask.shape) == 3: + mask = mask.unsqueeze(3) + if mask.shape[3] < image1.shape[3]: + mask = mask.repeat(1, 1, 1, image1.shape[3]) + + return (image1 * (1. - mask) + image2 * mask,) + +GENERAL_NODE_CLASS_MAPPINGS = { + "InversionDemoAdvancedPromptNode": InversionDemoAdvancedPromptNode, + "InversionDemoFakeAdvancedPromptNode": InversionDemoFakeAdvancedPromptNode, + "InversionDemoLazySwitch": InversionDemoLazySwitch, + "InversionDemoLazyIndexSwitch": InversionDemoLazyIndexSwitch, + "InversionDemoLazyMixImages": InversionDemoLazyMixImages, +} + +GENERAL_NODE_DISPLAY_NAME_MAPPINGS = { + "InversionDemoAdvancedPromptNode": "Advanced Prompt", + "InversionDemoFakeAdvancedPromptNode": "Fake Advanced Prompt", + "InversionDemoLazySwitch": "Lazy Switch", + "InversionDemoLazyIndexSwitch": "Lazy Index Switch", + "InversionDemoLazyMixImages": "Lazy Mix Images", +} diff --git a/execution.py b/execution.py index 94bcc30bc..35b542b28 100644 --- a/execution.py +++ b/execution.py @@ -7,13 +7,150 @@ import heapq import traceback import gc import time +from enum import Enum import torch import nodes import comfy.model_management +import comfy.graph_utils -def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): +class ExecutionResult(Enum): + SUCCESS = 0 + FAILURE = 1 + SLEEPING = 2 + +# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution, +# it can still be returned to the graph after having further dependencies added. +class ExecutionList: + def __init__(self, dynprompt, outputs): + self.dynprompt = dynprompt + self.outputs = outputs + self.staged_node_id = None + self.pendingNodes = {} + self.blockCount = {} # Number of nodes this node is directly blocked by + self.blocking = {} # Which nodes are blocked by this node + + def get_input_info(self, unique_id, input_name): + class_type = self.dynprompt.get_node(unique_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + valid_inputs = class_def.INPUT_TYPES() + input_info = None + input_category = None + if input_name in valid_inputs["required"]: + input_category = "required" + input_info = valid_inputs["required"][input_name] + elif input_name in valid_inputs["optional"]: + input_category = "optional" + input_info = valid_inputs["optional"][input_name] + elif input_name in valid_inputs["hidden"]: + input_category = "hidden" + input_info = valid_inputs["hidden"][input_name] + if input_info is None: + return None, None, None + input_type = input_info[0] + extra_info = None + if len(input_info) > 1: + extra_info = input_info[1] + return input_type, input_category, extra_info + + def make_input_strong_link(self, to_node_id, to_input): + inputs = self.dynprompt.get_node(to_node_id)["inputs"] + if to_input not in inputs: + raise Exception("Node %s says it needs input %s, but there is no input to that node at all" % (to_node_id, to_input)) + value = inputs[to_input] + if not isinstance(value, list): + raise Exception("Node %s says it needs input %s, but that value is a constant" % (to_node_id, to_input)) + from_node_id, from_socket = value + self.add_strong_link(from_node_id, from_socket, to_node_id) + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + if from_node_id in self.outputs: + # Nothing to do + return + self.add_node(from_node_id) + if to_node_id not in self.blocking[from_node_id]: + self.blocking[from_node_id][to_node_id] = {} + self.blockCount[to_node_id] += 1 + self.blocking[from_node_id][to_node_id][from_socket] = True + + def add_node(self, unique_id): + if unique_id in self.pendingNodes: + return + self.pendingNodes[unique_id] = True + self.blockCount[unique_id] = 0 + self.blocking[unique_id] = {} + + inputs = self.dynprompt.get_node(unique_id)["inputs"] + for input_name in inputs: + value = inputs[input_name] + if isinstance(value, list): + from_node_id, from_socket = value + input_type, input_category, input_info = self.get_input_info(unique_id, input_name) + if input_info is None or "lazy" not in input_info or not input_info["lazy"]: + self.add_strong_link(from_node_id, from_socket, unique_id) + + def stage_node_execution(self): + assert self.staged_node_id is None + if self.is_empty(): + return None + available = [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0] + if len(available) == 0: + raise Exception("Dependency cycle detected") + next_node = available[0] + # If an output node is available, do that first. + # Technically this has no effect on the overall length of execution, but it feels better as a user + # for a PreviewImage to display a result as soon as it can + # Some other heuristics could probably be used here to improve the UX further. + for node_id in available: + class_type = self.dynprompt.get_node(node_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + next_node = node_id + break + self.staged_node_id = next_node + return self.staged_node_id + + def unstage_node_execution(self): + assert self.staged_node_id is not None + self.staged_node_id = None + + def complete_node_execution(self): + node_id = self.staged_node_id + del self.pendingNodes[node_id] + for blocked_node_id in self.blocking[node_id]: + self.blockCount[blocked_node_id] -= 1 + del self.blocking[node_id] + self.staged_node_id = None + + def is_empty(self): + return len(self.pendingNodes) == 0 + +class DynamicPrompt: + def __init__(self, original_prompt): + # The original prompt provided by the user + self.original_prompt = original_prompt + # Any extra pieces of the graph created during execution + self.ephemeral_prompt = {} + self.ephemeral_parents = {} + + def get_node(self, node_id): + if node_id in self.ephemeral_prompt: + return self.ephemeral_prompt[node_id] + if node_id in self.original_prompt: + return self.original_prompt[node_id] + return None + + def add_ephemeral_node(self, real_parent_id, node_id, node_info): + self.ephemeral_prompt[node_id] = node_info + self.ephemeral_parents[node_id] = real_parent_id + + def get_real_node_id(self, node_id): + if node_id in self.ephemeral_parents: + return self.ephemeral_parents[node_id] + return node_id + +def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynprompt=None, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} for x in inputs: @@ -22,7 +159,7 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - return None + continue # This might be a lazily-evaluated input obj = outputs[input_unique_id][output_index] input_data_all[x] = obj else: @@ -34,6 +171,8 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da for x in h: if h[x] == "PROMPT": input_data_all[x] = [prompt] + if h[x] == "DYNPROMPT": + input_data_all[x] = [dynprompt] if h[x] == "EXTRA_PNGINFO": if "extra_pnginfo" in extra_data: input_data_all[x] = [extra_data['extra_pnginfo']] @@ -68,39 +207,54 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) return results +def merge_result_data(results, obj): + # check which outputs need concatenating + output = [] + output_is_list = [False] * len(results[0]) + if hasattr(obj, "OUTPUT_IS_LIST"): + output_is_list = obj.OUTPUT_IS_LIST + + # merge node execution results + for i, is_list in zip(range(len(results[0])), output_is_list): + if is_list: + output.append([x for o in results for x in o[i]]) + else: + output.append([o[i] for o in results]) + return output + def get_output_data(obj, input_data_all): results = [] uis = [] + subgraph_results = [] return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) - - for r in return_values: + has_subgraph = False + for i in range(len(return_values)): + r = return_values[i] if isinstance(r, dict): if 'ui' in r: uis.append(r['ui']) - if 'result' in r: + if 'expand' in r: + # Perform an expansion, but do not append results + has_subgraph = True + new_graph = r['expand'] + subgraph_results.append((new_graph, r.get("result", None))) + elif 'result' in r: results.append(r['result']) + subgraph_results.append((None, r['result'])) else: results.append(r) - output = [] - if len(results) > 0: - # check which outputs need concatenating - output_is_list = [False] * len(results[0]) - if hasattr(obj, "OUTPUT_IS_LIST"): - output_is_list = obj.OUTPUT_IS_LIST - - # merge node execution results - for i, is_list in zip(range(len(results[0])), output_is_list): - if is_list: - output.append([x for o in results for x in o[i]]) - else: - output.append([o[i] for o in results]) - + if has_subgraph: + output = subgraph_results + elif len(results) > 0: + output = merge_result_data(results, obj) + else: + output = [] ui = dict() if len(uis) > 0: ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} - return output, ui + return output, ui, has_subgraph def format_value(x): if x is None: @@ -110,53 +264,102 @@ def format_value(x): else: return str(x) -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage): +def non_recursive_execute(server, dynprompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, object_storage, execution_list, pending_subgraph_results): unique_id = current_item - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] + real_node_id = dynprompt.get_real_node_id(unique_id) + inputs = dynprompt.get_node(unique_id)['inputs'] + class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: - return (True, None, None) - - for x in inputs: - input_data = inputs[x] - - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, object_storage) - if result[0] is not True: - # Another node failed further upstream - return result + return (ExecutionResult.SUCCESS, None, None) input_data_all = None try: - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) - if server.client_id is not None: - server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) + if unique_id in pending_subgraph_results: + cached_results = pending_subgraph_results[unique_id] + resolved_outputs = [] + for is_subgraph, result in cached_results: + if not is_subgraph: + resolved_outputs.append(result) + else: + resolved_output = [] + for r in result: + if isinstance(r, list) and len(r) == 2: + source_node, source_output = r[0], r[1] + node_output = outputs[source_node][source_output] + for o in node_output: + resolved_output.append(o) - obj = object_storage.get((unique_id, class_type), None) - if obj is None: - obj = class_def() - object_storage[(unique_id, class_type)] = obj + else: + resolved_output.append(r) + resolved_outputs.append(tuple(resolved_output)) + output_data = merge_result_data(resolved_outputs, class_def) + output_ui = [] + has_subgraph = False + else: + input_data_all = get_input_data(inputs, class_def, unique_id, outputs, dynprompt.original_prompt, dynprompt, extra_data) + if server.client_id is not None: + server.last_node_id = real_node_id + server.send_sync("executing", { "node": real_node_id, "prompt_id": prompt_id }, server.client_id) - output_data, output_ui = get_output_data(obj, input_data_all) - outputs[unique_id] = output_data + obj = object_storage.get((unique_id, class_type), None) + if obj is None: + obj = class_def() + object_storage[(unique_id, class_type)] = obj + + if hasattr(obj, "check_lazy_status"): + required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True) + required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) + required_inputs = [x for x in required_inputs if x not in input_data_all] + if len(required_inputs) > 0: + for i in required_inputs: + execution_list.make_input_strong_link(unique_id, i) + return (ExecutionResult.SLEEPING, None, None) + + output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all) if len(output_ui) > 0: outputs_ui[unique_id] = output_ui if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executed", { "node": real_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + if has_subgraph: + cached_outputs = [] + for i in range(len(output_data)): + new_graph, node_outputs = output_data[i] + if new_graph is None: + cached_outputs.append((False, node_outputs)) + else: + # Check for conflicts + for node_id in new_graph.keys(): + if dynprompt.get_node(node_id) is not None: + new_graph, node_outputs = comfy.graph_utils.add_graph_prefix(new_graph, node_outputs, "%s.%d." % (unique_id, i)) + break + new_output_ids = [] + for node_id, node_info in new_graph.items(): + dynprompt.add_ephemeral_node(real_node_id, node_id, node_info) + # Figure out if the newly created node is an output node + class_type = node_info["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + new_output_ids.append(node_id) + for node_id in new_output_ids: + execution_list.add_node(node_id) + for i in range(len(node_outputs)): + if isinstance(node_outputs[i], list) and len(node_outputs[i]) == 2: + from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1] + execution_list.add_strong_link(from_node_id, from_socket, unique_id) + cached_outputs.append((True, node_outputs)) + pending_subgraph_results[unique_id] = cached_outputs + return (ExecutionResult.SLEEPING, None, None) + outputs[unique_id] = output_data except comfy.model_management.InterruptProcessingException as iex: print("Processing interrupted") # skip formatting inputs/outputs error_details = { - "node_id": unique_id, + "node_id": real_node_id, } - return (False, error_details, iex) + return (ExecutionResult.FAILURE, error_details, iex) except Exception as ex: typ, _, tb = sys.exc_info() exception_type = full_type_name(typ) @@ -174,35 +377,18 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute print(traceback.format_exc()) error_details = { - "node_id": unique_id, + "node_id": real_node_id, "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), "current_inputs": input_data_formatted, "current_outputs": output_data_formatted } - return (False, error_details, ex) + return (ExecutionResult.FAILURE, error_details, ex) executed.add(unique_id) - return (True, None, None) - -def recursive_will_execute(prompt, outputs, current_item): - unique_id = current_item - inputs = prompt[unique_id]['inputs'] - will_execute = [] - if unique_id in outputs: - return [] - - for x in inputs: - input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - will_execute += recursive_will_execute(prompt, outputs, input_unique_id) - - return will_execute + [unique_id] + return (ExecutionResult.SUCCESS, None, None) def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item): unique_id = current_item @@ -350,28 +536,27 @@ class PromptExecutor: if self.server.client_id is not None: self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) + pending_subgraph_results = {} + dynamic_prompt = DynamicPrompt(prompt) executed = set() - output_node_id = None - to_execute = [] - + execution_list = ExecutionList(dynamic_prompt, self.outputs) for node_id in list(execute_outputs): - to_execute += [(0, node_id)] + execution_list.add_node(node_id) - while len(to_execute) > 0: - #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) - output_node_id = to_execute.pop(0)[-1] - - # This call shouldn't raise anything if there's an error deep in - # the actual SD code, instead it will report the node where the - # error was raised - success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) - if success is not True: - self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) + while not execution_list.is_empty(): + node_id = execution_list.stage_node_execution() + result, error, ex = non_recursive_execute(self.server, dynamic_prompt, self.outputs, node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage, execution_list, pending_subgraph_results) + if result == ExecutionResult.FAILURE: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) break + elif result == ExecutionResult.SLEEPING: + execution_list.unstage_node_execution() + else: # result == ExecutionResult.SUCCESS + execution_list.complete_node_execution() for x in executed: - self.old_prompt[x] = copy.deepcopy(prompt[x]) + if x in prompt: + self.old_prompt[x] = copy.deepcopy(prompt[x]) self.server.last_node_id = None