Calculate grid from combinatorial inputs

This commit is contained in:
space-nuko 2023-06-09 11:44:16 -05:00
parent 2fa87f1779
commit c9f4eb3fad
3 changed files with 169 additions and 22 deletions

View File

@ -8,6 +8,9 @@ import traceback
import gc import gc
import time import time
import itertools import itertools
from typing import List, Dict
import dataclasses
from dataclasses import dataclass
import torch import torch
import nodes import nodes
@ -15,6 +18,15 @@ import nodes
import comfy.model_management import comfy.model_management
@dataclass
class CombinatorialBatches:
batches: List
input_to_index: Dict
index_to_values: Dict
indices: List
combinations: List
def get_input_data_batches(input_data_all): def get_input_data_batches(input_data_all):
"""Given input data that can contain combinatorial input values, returns all """Given input data that can contain combinatorial input values, returns all
the possible batches that can be made by combining the different input the possible batches that can be made by combining the different input
@ -22,6 +34,7 @@ def get_input_data_batches(input_data_all):
input_to_index = {} input_to_index = {}
index_to_values = [] index_to_values = []
index_to_coords = []
# Sort by input name first so the order which batch inputs are applied can # Sort by input name first so the order which batch inputs are applied can
# be easily calculated (node execution order first, then alphabetical input # be easily calculated (node execution order first, then alphabetical input
@ -34,15 +47,18 @@ def get_input_data_batches(input_data_all):
if isinstance(value, dict) and "combinatorial" in value: if isinstance(value, dict) and "combinatorial" in value:
input_to_index[input_name] = i input_to_index[input_name] = i
index_to_values.append(value["values"]) index_to_values.append(value["values"])
index_to_coords.append(list(range(len(value["values"]))))
i += 1 i += 1
if len(index_to_values) == 0: if len(index_to_values) == 0:
# No combinatorial options. # No combinatorial options.
return [input_data_all] return CombinatorialBatches([input_data_all], input_to_index, index_to_values, None, None)
batches = [] batches = []
for combination in list(itertools.product(*index_to_values)): indices = list(itertools.product(*index_to_coords))
combinations = list(itertools.product(*index_to_values))
for combination in combinations:
batch = {} batch = {}
for input_name, value in input_data_all.items(): for input_name, value in input_data_all.items():
if isinstance(value, dict) and "combinatorial" in value: if isinstance(value, dict) and "combinatorial" in value:
@ -53,7 +69,7 @@ def get_input_data_batches(input_data_all):
batch[input_name] = value batch[input_name] = value
batches.append(batch) batches.append(batch)
return batches return CombinatorialBatches(batches, input_to_index, index_to_values, indices, combinations)
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
"""Given input data from the prompt, returns a list of input data dicts for """Given input data from the prompt, returns a list of input data dicts for
@ -157,9 +173,9 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
all_outputs = [] all_outputs = []
all_outputs_ui = [] all_outputs_ui = []
total_batches = len(input_data_all_batches) total_batches = len(input_data_all_batches.batches)
for batch_num, batch in enumerate(input_data_all_batches): for batch_num, batch in enumerate(input_data_all_batches.batches):
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True) return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True)
uis = [] uis = []
@ -208,6 +224,9 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
"batch_num": batch_num, "batch_num": batch_num,
"total_batches": total_batches "total_batches": total_batches
} }
if input_data_all_batches.indices:
message["indices"] = input_data_all_batches.indices[batch_num]
message["combination"] = input_data_all_batches.combinations[batch_num]
server.send_sync("executed", message, server.client_id) server.send_sync("executed", message, server.client_id)
return all_outputs, all_outputs_ui return all_outputs, all_outputs_ui
@ -240,12 +259,25 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
# Another node failed further upstream # Another node failed further upstream
return result return result
input_data_all = None input_data_all_batches = None
try: try:
input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None: if server.client_id is not None:
server.last_node_id = unique_id server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id, "total_batches": len(input_data_all_batches) }, server.client_id) combinations = None
if input_data_all_batches.indices:
combinations = {
"input_to_index": input_data_all_batches.input_to_index,
"index_to_values": input_data_all_batches.index_to_values,
"indices": input_data_all_batches.indices
}
mes = {
"node": unique_id,
"prompt_id": prompt_id,
"combinations": combinations
}
server.send_sync("executing", mes, server.client_id)
obj = class_def() obj = class_def()
output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id) output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id)
@ -266,15 +298,20 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
except Exception as ex: except Exception as ex:
typ, _, tb = sys.exc_info() typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ) exception_type = full_type_name(typ)
input_data_formatted = {} input_data_formatted = []
if input_data_all is not None: if input_data_all_batches is not None:
input_data_formatted = {} d = {}
for name, inputs in input_data_all.items(): for batch in input_data_all_batches.batches:
input_data_formatted[name] = [format_value(x) for x in inputs] for name, inputs in batch.items():
d[name] = [format_value(x) for x in inputs]
input_data_formatted.append(d)
output_data_formatted = {} output_data_formatted = []
for node_id, node_outputs in outputs.items(): for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] d = {}
for batch_outputs in node_outputs:
d[node_id] = [[format_value(x) for x in l] for l in batch_outputs]
output_data_formatted.append(d)
print("!!! Exception during processing !!!") print("!!! Exception during processing !!!")
print(traceback.format_exc()) print(traceback.format_exc())
@ -327,7 +364,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
if input_data_all_batches is not None: if input_data_all_batches is not None:
try: try:
#is_changed = class_def.IS_CHANGED(**input_data_all) #is_changed = class_def.IS_CHANGED(**input_data_all)
for batch in input_data_all_batches: for batch in input_data_all_batches.batches:
if map_node_over_list(class_def, batch, "IS_CHANGED"): if map_node_over_list(class_def, batch, "IS_CHANGED"):
is_changed = True is_changed = True
break break
@ -668,8 +705,7 @@ def validate_inputs(prompt, item, validated):
if hasattr(obj_class, "VALIDATE_INPUTS"): if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all_batches = get_input_data(inputs, obj_class, unique_id) input_data_all_batches = get_input_data(inputs, obj_class, unique_id)
#ret = obj_class.VALIDATE_INPUTS(**input_data_all) #ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") for batch in input_data_all_batches.batches:
for batch in input_data_all_batches:
ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS") ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS")
for r in ret: for r in ret:
if r != True: if r != True:

View File

@ -109,7 +109,7 @@ class ComfyApi extends EventTarget {
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
break; break;
case "executing": case "executing":
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node })); this.dispatchEvent(new CustomEvent("executing", { detail: msg.data }));
break; break;
case "executed": case "executed":
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));

View File

@ -44,6 +44,12 @@ export class ComfyApp {
*/ */
this.nodeOutputs = {}; this.nodeOutputs = {};
/**
* Stores the grid data for each node
* @type {Record<string, any>}
*/
this.nodeGrids = {};
/** /**
* Stores the preview image data for each node * Stores the preview image data for each node
* @type {Record<string, Image>} * @type {Record<string, Image>}
@ -949,13 +955,21 @@ export class ComfyApp {
api.addEventListener("executing", ({ detail }) => { api.addEventListener("executing", ({ detail }) => {
this.progress = null; this.progress = null;
this.runningNodeId = detail; this.runningNodeId = detail.node;
this.graph.setDirtyCanvas(true, false); this.graph.setDirtyCanvas(true, false);
delete this.nodePreviewImages[this.runningNodeId] if (detail.node != null) {
delete this.nodePreviewImages[this.runningNodeId]
}
else {
this.runningPrompt = null;
}
}); });
api.addEventListener("executed", ({ detail }) => { api.addEventListener("executed", ({ detail }) => {
this.nodeOutputs[detail.node] = detail.output; this.nodeOutputs[detail.node] = detail.output;
if (detail.output != null) {
this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt)
}
const node = this.graph.getNodeById(detail.node); const node = this.graph.getNodeById(detail.node);
if (node) { if (node) {
if (node.onExecuted) if (node.onExecuted)
@ -964,6 +978,7 @@ export class ComfyApp {
}); });
api.addEventListener("execution_start", ({ detail }) => { api.addEventListener("execution_start", ({ detail }) => {
this.nodeGrids = {}
this.runningNodeId = null; this.runningNodeId = null;
this.lastExecutionError = null this.lastExecutionError = null
}); });
@ -988,6 +1003,93 @@ export class ComfyApp {
api.init(); api.init();
} }
/*
* Based on inputs in the prompt marked as combinatorial,
* construct a grid from the results;
*/
#resolveGrid(outputNode, output, runningPrompt) {
let axes = []
const allImages = output.filter(batch => Array.isArray(batch.images))
.flatMap(batch => batch.images)
if (allImages.length === 0)
return null;
console.error("PROMPT", runningPrompt);
console.error("OUTPUT", output);
const isInputLink = (input) => {
return Array.isArray(input)
&& input.length === 2
&& typeof input[0] === "string"
&& typeof input[1] === "number";
}
// Axes closer to the output (executed later) are discovered first
const queue = [outputNode]
while (queue.length > 0) {
const nodeID = queue.pop();
const promptInput = runningPrompt.output[nodeID];
// Ensure input keys are sorted alphanumerically
// This is important for the plot to have the same order as
// it was executed on the backend
let sortedKeys = Object.keys(promptInput.inputs);
sortedKeys.sort((a, b) => a.localeCompare(b));
// Then reverse the order since we're traversing the graph upstream,
// so execution order comes out backwards
sortedKeys = sortedKeys.reverse();
for (const inputName of sortedKeys) {
const input = promptInput.inputs[inputName];
if (typeof input === "object" && "__inputType__" in input) {
axes.push({
nodeID,
inputName,
values: input.values
})
}
else if (isInputLink(input)) {
const inputNodeID = input[0]
queue.push(inputNodeID)
}
}
}
axes = axes.reverse();
// Now divide up the image outputs
// Each axis will divide the full array of images by N, where N was the
// number of combinatorial choices for that axis, and this happens
// recursively for each axis
console.error("AXES", axes)
// Grid position
const currentCoords = Array.from(Array(0))
let images = allImages.map(i => { return {
image: i,
coords: []
}})
let factor = 1
for (const axis of axes) {
factor *= axis.values.length;
for (const [index, image] of images.entries()) {
image.coords.push(Math.floor((index / factor) * axis.values.length) % axis.values.length);
}
}
const grid = { axes, images };
console.error("GRID", grid);
return null;
}
#addKeyboardHandler() { #addKeyboardHandler() {
window.addEventListener("keydown", (e) => { window.addEventListener("keydown", (e) => {
this.shiftDown = e.shiftKey; this.shiftDown = e.shiftKey;
@ -1292,7 +1394,11 @@ export class ComfyApp {
const workflow = this.graph.serialize(); const workflow = this.graph.serialize();
const output = {}; const output = {};
// Process nodes in order of execution // Process nodes in order of execution
for (const node of this.graph.computeExecutionOrder(false)) {
const executionOrder = Array.from(this.graph.computeExecutionOrder(false));
const executionOrderIds = executionOrder.map(n => n.id);
for (const node of executionOrder) {
const n = workflow.nodes.find((n) => n.id === node.id); const n = workflow.nodes.find((n) => n.id === node.id);
if (node.isVirtualNode) { if (node.isVirtualNode) {
@ -1359,7 +1465,7 @@ export class ComfyApp {
} }
} }
return { workflow, output }; return { workflow, output, executionOrder: executionOrderIds };
} }
#formatPromptError(error) { #formatPromptError(error) {
@ -1409,6 +1515,7 @@ export class ComfyApp {
this.#processingQueue = true; this.#processingQueue = true;
this.lastPromptError = null; this.lastPromptError = null;
this.runningPrompt = null;
try { try {
while (this.#queueItems.length) { while (this.#queueItems.length) {
@ -1418,8 +1525,10 @@ export class ComfyApp {
const p = await this.graphToPrompt(); const p = await this.graphToPrompt();
try { try {
this.runningPrompt = p;
await api.queuePrompt(number, p); await api.queuePrompt(number, p);
} catch (error) { } catch (error) {
this.runningPrompt = null;
const formattedError = this.#formatPromptError(error) const formattedError = this.#formatPromptError(error)
this.ui.dialog.show(formattedError); this.ui.dialog.show(formattedError);
if (error.response) { if (error.response) {
@ -1528,10 +1637,12 @@ export class ComfyApp {
*/ */
clean() { clean() {
this.nodeOutputs = {}; this.nodeOutputs = {};
this.nodeGrids = {};
this.nodePreviewImages = {} this.nodePreviewImages = {}
this.lastPromptError = null; this.lastPromptError = null;
this.lastExecutionError = null; this.lastExecutionError = null;
this.runningNodeId = null; this.runningNodeId = null;
this.runningPrompt = null;
} }
} }