From c9f4eb3fada5afdf620459b7402ae8b550790246 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 9 Jun 2023 11:44:16 -0500 Subject: [PATCH] Calculate grid from combinatorial inputs --- execution.py | 70 +++++++++++++++++++------- web/scripts/api.js | 2 +- web/scripts/app.js | 119 +++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 169 insertions(+), 22 deletions(-) diff --git a/execution.py b/execution.py index 5f3f011d5..6893188e5 100644 --- a/execution.py +++ b/execution.py @@ -8,6 +8,9 @@ import traceback import gc import time import itertools +from typing import List, Dict +import dataclasses +from dataclasses import dataclass import torch import nodes @@ -15,6 +18,15 @@ import nodes import comfy.model_management +@dataclass +class CombinatorialBatches: + batches: List + input_to_index: Dict + index_to_values: Dict + indices: List + combinations: List + + def get_input_data_batches(input_data_all): """Given input data that can contain combinatorial input values, returns all the possible batches that can be made by combining the different input @@ -22,6 +34,7 @@ def get_input_data_batches(input_data_all): input_to_index = {} index_to_values = [] + index_to_coords = [] # Sort by input name first so the order which batch inputs are applied can # be easily calculated (node execution order first, then alphabetical input @@ -34,15 +47,18 @@ def get_input_data_batches(input_data_all): if isinstance(value, dict) and "combinatorial" in value: input_to_index[input_name] = i index_to_values.append(value["values"]) + index_to_coords.append(list(range(len(value["values"])))) i += 1 if len(index_to_values) == 0: # No combinatorial options. - return [input_data_all] + return CombinatorialBatches([input_data_all], input_to_index, index_to_values, None, None) batches = [] - for combination in list(itertools.product(*index_to_values)): + indices = list(itertools.product(*index_to_coords)) + combinations = list(itertools.product(*index_to_values)) + for combination in combinations: batch = {} for input_name, value in input_data_all.items(): if isinstance(value, dict) and "combinatorial" in value: @@ -53,7 +69,7 @@ def get_input_data_batches(input_data_all): batch[input_name] = value batches.append(batch) - return batches + return CombinatorialBatches(batches, input_to_index, index_to_values, indices, combinations) def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): """Given input data from the prompt, returns a list of input data dicts for @@ -157,9 +173,9 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): all_outputs = [] all_outputs_ui = [] - total_batches = len(input_data_all_batches) + total_batches = len(input_data_all_batches.batches) - for batch_num, batch in enumerate(input_data_all_batches): + for batch_num, batch in enumerate(input_data_all_batches.batches): return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True) uis = [] @@ -208,6 +224,9 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): "batch_num": batch_num, "total_batches": total_batches } + if input_data_all_batches.indices: + message["indices"] = input_data_all_batches.indices[batch_num] + message["combination"] = input_data_all_batches.combinations[batch_num] server.send_sync("executed", message, server.client_id) return all_outputs, all_outputs_ui @@ -240,12 +259,25 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute # Another node failed further upstream return result - input_data_all = None + input_data_all_batches = None try: input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id, "total_batches": len(input_data_all_batches) }, server.client_id) + combinations = None + if input_data_all_batches.indices: + combinations = { + "input_to_index": input_data_all_batches.input_to_index, + "index_to_values": input_data_all_batches.index_to_values, + "indices": input_data_all_batches.indices + } + mes = { + "node": unique_id, + "prompt_id": prompt_id, + "combinations": combinations + } + server.send_sync("executing", mes, server.client_id) + obj = class_def() output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id) @@ -266,15 +298,20 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute except Exception as ex: typ, _, tb = sys.exc_info() exception_type = full_type_name(typ) - input_data_formatted = {} - if input_data_all is not None: - input_data_formatted = {} - for name, inputs in input_data_all.items(): - input_data_formatted[name] = [format_value(x) for x in inputs] + input_data_formatted = [] + if input_data_all_batches is not None: + d = {} + for batch in input_data_all_batches.batches: + for name, inputs in batch.items(): + d[name] = [format_value(x) for x in inputs] + input_data_formatted.append(d) - output_data_formatted = {} + output_data_formatted = [] for node_id, node_outputs in outputs.items(): - output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] + d = {} + for batch_outputs in node_outputs: + d[node_id] = [[format_value(x) for x in l] for l in batch_outputs] + output_data_formatted.append(d) print("!!! Exception during processing !!!") print(traceback.format_exc()) @@ -327,7 +364,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item if input_data_all_batches is not None: try: #is_changed = class_def.IS_CHANGED(**input_data_all) - for batch in input_data_all_batches: + for batch in input_data_all_batches.batches: if map_node_over_list(class_def, batch, "IS_CHANGED"): is_changed = True break @@ -668,8 +705,7 @@ def validate_inputs(prompt, item, validated): if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all_batches = get_input_data(inputs, obj_class, unique_id) #ret = obj_class.VALIDATE_INPUTS(**input_data_all) - ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for batch in input_data_all_batches: + for batch in input_data_all_batches.batches: ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS") for r in ret: if r != True: diff --git a/web/scripts/api.js b/web/scripts/api.js index 8313f1abe..7897df5bb 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -109,7 +109,7 @@ class ComfyApi extends EventTarget { this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); break; case "executing": - this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node })); + this.dispatchEvent(new CustomEvent("executing", { detail: msg.data })); break; case "executed": this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); diff --git a/web/scripts/app.js b/web/scripts/app.js index c1ecb7d82..edd8fa617 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -44,6 +44,12 @@ export class ComfyApp { */ this.nodeOutputs = {}; + /** + * Stores the grid data for each node + * @type {Record} + */ + this.nodeGrids = {}; + /** * Stores the preview image data for each node * @type {Record} @@ -949,13 +955,21 @@ export class ComfyApp { api.addEventListener("executing", ({ detail }) => { this.progress = null; - this.runningNodeId = detail; + this.runningNodeId = detail.node; this.graph.setDirtyCanvas(true, false); - delete this.nodePreviewImages[this.runningNodeId] + if (detail.node != null) { + delete this.nodePreviewImages[this.runningNodeId] + } + else { + this.runningPrompt = null; + } }); api.addEventListener("executed", ({ detail }) => { this.nodeOutputs[detail.node] = detail.output; + if (detail.output != null) { + this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt) + } const node = this.graph.getNodeById(detail.node); if (node) { if (node.onExecuted) @@ -964,6 +978,7 @@ export class ComfyApp { }); api.addEventListener("execution_start", ({ detail }) => { + this.nodeGrids = {} this.runningNodeId = null; this.lastExecutionError = null }); @@ -988,6 +1003,93 @@ export class ComfyApp { api.init(); } + /* + * Based on inputs in the prompt marked as combinatorial, + * construct a grid from the results; + */ + #resolveGrid(outputNode, output, runningPrompt) { + let axes = [] + + const allImages = output.filter(batch => Array.isArray(batch.images)) + .flatMap(batch => batch.images) + + if (allImages.length === 0) + return null; + + console.error("PROMPT", runningPrompt); + console.error("OUTPUT", output); + + const isInputLink = (input) => { + return Array.isArray(input) + && input.length === 2 + && typeof input[0] === "string" + && typeof input[1] === "number"; + } + + // Axes closer to the output (executed later) are discovered first + const queue = [outputNode] + while (queue.length > 0) { + const nodeID = queue.pop(); + const promptInput = runningPrompt.output[nodeID]; + + // Ensure input keys are sorted alphanumerically + // This is important for the plot to have the same order as + // it was executed on the backend + let sortedKeys = Object.keys(promptInput.inputs); + sortedKeys.sort((a, b) => a.localeCompare(b)); + + // Then reverse the order since we're traversing the graph upstream, + // so execution order comes out backwards + sortedKeys = sortedKeys.reverse(); + + for (const inputName of sortedKeys) { + const input = promptInput.inputs[inputName]; + if (typeof input === "object" && "__inputType__" in input) { + axes.push({ + nodeID, + inputName, + values: input.values + }) + } + else if (isInputLink(input)) { + const inputNodeID = input[0] + queue.push(inputNodeID) + } + } + } + + axes = axes.reverse(); + + // Now divide up the image outputs + // Each axis will divide the full array of images by N, where N was the + // number of combinatorial choices for that axis, and this happens + // recursively for each axis + + console.error("AXES", axes) + + // Grid position + const currentCoords = Array.from(Array(0)) + + let images = allImages.map(i => { return { + image: i, + coords: [] + }}) + + let factor = 1 + + for (const axis of axes) { + factor *= axis.values.length; + for (const [index, image] of images.entries()) { + image.coords.push(Math.floor((index / factor) * axis.values.length) % axis.values.length); + } + } + + const grid = { axes, images }; + console.error("GRID", grid); + + return null; + } + #addKeyboardHandler() { window.addEventListener("keydown", (e) => { this.shiftDown = e.shiftKey; @@ -1292,7 +1394,11 @@ export class ComfyApp { const workflow = this.graph.serialize(); const output = {}; // Process nodes in order of execution - for (const node of this.graph.computeExecutionOrder(false)) { + + const executionOrder = Array.from(this.graph.computeExecutionOrder(false)); + const executionOrderIds = executionOrder.map(n => n.id); + + for (const node of executionOrder) { const n = workflow.nodes.find((n) => n.id === node.id); if (node.isVirtualNode) { @@ -1359,7 +1465,7 @@ export class ComfyApp { } } - return { workflow, output }; + return { workflow, output, executionOrder: executionOrderIds }; } #formatPromptError(error) { @@ -1409,6 +1515,7 @@ export class ComfyApp { this.#processingQueue = true; this.lastPromptError = null; + this.runningPrompt = null; try { while (this.#queueItems.length) { @@ -1418,8 +1525,10 @@ export class ComfyApp { const p = await this.graphToPrompt(); try { + this.runningPrompt = p; await api.queuePrompt(number, p); } catch (error) { + this.runningPrompt = null; const formattedError = this.#formatPromptError(error) this.ui.dialog.show(formattedError); if (error.response) { @@ -1528,10 +1637,12 @@ export class ComfyApp { */ clean() { this.nodeOutputs = {}; + this.nodeGrids = {}; this.nodePreviewImages = {} this.lastPromptError = null; this.lastExecutionError = null; this.runningNodeId = null; + this.runningPrompt = null; } }