Calculate grid from combinatorial inputs

space-nuko 2023-06-09 11:44:16 -05:00
parent 2fa87f1779
commit c9f4eb3fad
3 changed files with 169 additions and 22 deletions

View File

@@ -8,6 +8,9 @@ import traceback
import gc
import time
import itertools
from typing import List, Dict
import dataclasses
from dataclasses import dataclass
import torch
import nodes
@@ -15,6 +18,15 @@ import nodes
import comfy.model_management
@dataclass
class CombinatorialBatches:
batches: List
input_to_index: Dict
index_to_values: Dict
indices: List
combinations: List
def get_input_data_batches(input_data_all):
"""Given input data that can contain combinatorial input values, returns all
the possible batches that can be made by combining the different input
@@ -22,6 +34,7 @@ def get_input_data_batches(input_data_all):
input_to_index = {}
index_to_values = []
index_to_coords = []
# Sort by input name first so the order in which batch inputs are applied can
# be easily calculated (node execution order first, then alphabetical input
@@ -34,15 +47,18 @@ def get_input_data_batches(input_data_all):
if isinstance(value, dict) and "combinatorial" in value:
input_to_index[input_name] = i
index_to_values.append(value["values"])
index_to_coords.append(list(range(len(value["values"]))))
i += 1
if len(index_to_values) == 0:
# No combinatorial options.
return [input_data_all]
return CombinatorialBatches([input_data_all], input_to_index, index_to_values, None, None)
batches = []
for combination in list(itertools.product(*index_to_values)):
indices = list(itertools.product(*index_to_coords))
combinations = list(itertools.product(*index_to_values))
for combination in combinations:
batch = {}
for input_name, value in input_data_all.items():
if isinstance(value, dict) and "combinatorial" in value:
@@ -53,7 +69,7 @@ def get_input_data_batches(input_data_all):
batch[input_name] = value
batches.append(batch)
return batches
return CombinatorialBatches(batches, input_to_index, index_to_values, indices, combinations)
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
"""Given input data from the prompt, returns a list of input data dicts for
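For illustration, a minimal standalone sketch of the expansion get_input_data_batches performs, with made-up input names and values; only the "combinatorial"/"values" marker keys and the itertools.product expansion come from the code above, and the bookkeeping fields (input_to_index, index_to_values, indices, combinations) are omitted here.

import itertools

# Hypothetical node inputs; "cfg" and "steps" are marked combinatorial.
input_data_all = {
    "seed": [1234],                                                  # plain input
    "cfg": {"combinatorial": True, "values": [[7.0], [9.0]]},        # 2 choices
    "steps": {"combinatorial": True, "values": [[20], [30], [40]]},  # 3 choices
}

# Combinatorial inputs are taken in sorted name order, as in the code above.
names = sorted(n for n, v in input_data_all.items()
               if isinstance(v, dict) and "combinatorial" in v)
value_lists = [input_data_all[n]["values"] for n in names]

batches = []
for combination in itertools.product(*value_lists):
    batch = dict(input_data_all)
    for name, value in zip(names, combination):
        batch[name] = value
    batches.append(batch)

print(len(batches))  # 6 batches; the first sorted input ("cfg") varies slowest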
@@ -157,9 +173,9 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
all_outputs = []
all_outputs_ui = []
total_batches = len(input_data_all_batches)
total_batches = len(input_data_all_batches.batches)
for batch_num, batch in enumerate(input_data_all_batches):
for batch_num, batch in enumerate(input_data_all_batches.batches):
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True)
uis = []
@@ -208,6 +224,9 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
"batch_num": batch_num,
"total_batches": total_batches
}
if input_data_all_batches.indices:
message["indices"] = input_data_all_batches.indices[batch_num]
message["combination"] = input_data_all_batches.combinations[batch_num]
server.send_sync("executed", message, server.client_id)
return all_outputs, all_outputs_ui
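Continuing the hypothetical 2x3 example above, the per-batch "executed" message would pick up the new fields roughly as below (values are made up, and fields outside this hunk are omitted).

# Second batch (batch_num 1) of a 2x3 combinatorial run.
message = {
    "batch_num": 1,
    "total_batches": 6,
    "indices": (0, 1),             # this batch's coordinate along each combinatorial axis
    "combination": ([7.0], [30]),  # the concrete values chosen for this batch
}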
@@ -240,12 +259,25 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
# Another node failed further upstream
return result
input_data_all = None
input_data_all_batches = None
try:
input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None:
server.last_node_id = unique_id
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id, "total_batches": len(input_data_all_batches) }, server.client_id)
combinations = None
if input_data_all_batches.indices:
combinations = {
"input_to_index": input_data_all_batches.input_to_index,
"index_to_values": input_data_all_batches.index_to_values,
"indices": input_data_all_batches.indices
}
mes = {
"node": unique_id,
"prompt_id": prompt_id,
"combinations": combinations
}
server.send_sync("executing", mes, server.client_id)
obj = class_def()
output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id)
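For the same hypothetical node, the new "executing" message would carry the combination metadata along these lines (node and prompt ids are made up; "combinations" stays None when the node has no combinatorial inputs).

# Sketch of the payload built above, for two combinatorial inputs with 2 and 3 values.
mes = {
    "node": "12",
    "prompt_id": "example-prompt-id",
    "combinations": {
        "input_to_index": {"cfg": 0, "steps": 1},
        "index_to_values": [[[7.0], [9.0]], [[20], [30], [40]]],
        "indices": [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)],
    },
}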
@@ -266,15 +298,20 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
except Exception as ex:
typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ)
input_data_formatted = {}
if input_data_all is not None:
input_data_formatted = {}
for name, inputs in input_data_all.items():
input_data_formatted[name] = [format_value(x) for x in inputs]
input_data_formatted = []
if input_data_all_batches is not None:
d = {}
for batch in input_data_all_batches.batches:
for name, inputs in batch.items():
d[name] = [format_value(x) for x in inputs]
input_data_formatted.append(d)
output_data_formatted = {}
output_data_formatted = []
for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
d = {}
for batch_outputs in node_outputs:
d[node_id] = [[format_value(x) for x in l] for l in batch_outputs]
output_data_formatted.append(d)
print("!!! Exception during processing !!!")
print(traceback.format_exc())
@@ -327,7 +364,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
if input_data_all_batches is not None:
try:
#is_changed = class_def.IS_CHANGED(**input_data_all)
for batch in input_data_all_batches:
for batch in input_data_all_batches.batches:
if map_node_over_list(class_def, batch, "IS_CHANGED"):
is_changed = True
break
@@ -668,8 +705,7 @@ def validate_inputs(prompt, item, validated):
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all_batches = get_input_data(inputs, obj_class, unique_id)
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for batch in input_data_all_batches:
for batch in input_data_all_batches.batches:
ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS")
for r in ret:
if r != True:

View File

@@ -109,7 +109,7 @@ class ComfyApi extends EventTarget {
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
break;
case "executing":
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data }));
break;
case "executed":
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));

View File

@@ -44,6 +44,12 @@ export class ComfyApp {
*/
this.nodeOutputs = {};
/**
* Stores the grid data for each node
* @type {Record<string, any>}
*/
this.nodeGrids = {};
/**
* Stores the preview image data for each node
* @type {Record<string, Image>}
@@ -949,13 +955,21 @@ export class ComfyApp {
api.addEventListener("executing", ({ detail }) => {
this.progress = null;
this.runningNodeId = detail;
this.runningNodeId = detail.node;
this.graph.setDirtyCanvas(true, false);
delete this.nodePreviewImages[this.runningNodeId]
if (detail.node != null) {
delete this.nodePreviewImages[this.runningNodeId]
}
else {
this.runningPrompt = null;
}
});
api.addEventListener("executed", ({ detail }) => {
this.nodeOutputs[detail.node] = detail.output;
if (detail.output != null) {
this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt)
}
const node = this.graph.getNodeById(detail.node);
if (node) {
if (node.onExecuted)
@@ -964,6 +978,7 @@
});
api.addEventListener("execution_start", ({ detail }) => {
this.nodeGrids = {}
this.runningNodeId = null;
this.lastExecutionError = null
});
@@ -988,6 +1003,93 @@ export class ComfyApp {
api.init();
}
/*
* Based on inputs in the prompt marked as combinatorial,
* construct a grid from the results.
*/
#resolveGrid(outputNode, output, runningPrompt) {
let axes = []
const allImages = output.filter(batch => Array.isArray(batch.images))
.flatMap(batch => batch.images)
if (allImages.length === 0)
return null;
console.error("PROMPT", runningPrompt);
console.error("OUTPUT", output);
const isInputLink = (input) => {
return Array.isArray(input)
&& input.length === 2
&& typeof input[0] === "string"
&& typeof input[1] === "number";
}
// Axes closer to the output (executed later) are discovered first
const queue = [outputNode]
while (queue.length > 0) {
const nodeID = queue.pop();
const promptInput = runningPrompt.output[nodeID];
// Ensure input keys are sorted alphanumerically.
// This matters so the plot ends up in the same order the inputs
// were applied on the backend.
let sortedKeys = Object.keys(promptInput.inputs);
sortedKeys.sort((a, b) => a.localeCompare(b));
// Then reverse the order since we're traversing the graph upstream,
// so execution order comes out backwards
sortedKeys = sortedKeys.reverse();
for (const inputName of sortedKeys) {
const input = promptInput.inputs[inputName];
if (typeof input === "object" && "__inputType__" in input) {
axes.push({
nodeID,
inputName,
values: input.values
})
}
else if (isInputLink(input)) {
const inputNodeID = input[0]
queue.push(inputNodeID)
}
}
}
axes = axes.reverse();
// Now divide up the image outputs.
// Each axis divides the full array of images by N, where N is the
// number of combinatorial choices for that axis, and this happens
// recursively for each axis.
console.error("AXES", axes)
// Grid position
const currentCoords = Array.from(Array(0))
let images = allImages.map(i => { return {
image: i,
coords: []
}})
let factor = 1
for (const axis of axes) {
factor *= axis.values.length;
for (const [index, image] of images.entries()) {
image.coords.push(Math.floor((index / factor) * axis.values.length) % axis.values.length);
}
}
const grid = { axes, images };
console.error("GRID", grid);
return null;
}
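The coordinate bookkeeping above mirrors how the backend orders batches with itertools.product, where the first combinatorial axis varies slowest. A small Python sketch of that flat-index-to-coordinates relationship (an illustration of the ordering only, not the exact arithmetic of the loop above):

# Map a flat batch index to per-axis grid coordinates, assuming the backend's
# product ordering (first axis varies slowest). Axis lengths are made up.
def coords_for_batch(index, axis_lengths):
    coords = []
    remaining = index
    for length in reversed(axis_lengths):
        coords.append(remaining % length)
        remaining //= length
    return list(reversed(coords))

# A 2x3 grid yields [0,0], [0,1], [0,2], [1,0], [1,1], [1,2] in backend order.
assert coords_for_batch(4, [2, 3]) == [1, 1]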
#addKeyboardHandler() {
window.addEventListener("keydown", (e) => {
this.shiftDown = e.shiftKey;
@@ -1292,7 +1394,11 @@ export class ComfyApp {
const workflow = this.graph.serialize();
const output = {};
// Process nodes in order of execution
for (const node of this.graph.computeExecutionOrder(false)) {
const executionOrder = Array.from(this.graph.computeExecutionOrder(false));
const executionOrderIds = executionOrder.map(n => n.id);
for (const node of executionOrder) {
const n = workflow.nodes.find((n) => n.id === node.id);
if (node.isVirtualNode) {
@@ -1359,7 +1465,7 @@ export class ComfyApp {
}
}
return { workflow, output };
return { workflow, output, executionOrder: executionOrderIds };
}
#formatPromptError(error) {
@@ -1409,6 +1515,7 @@ export class ComfyApp {
this.#processingQueue = true;
this.lastPromptError = null;
this.runningPrompt = null;
try {
while (this.#queueItems.length) {
@@ -1418,8 +1525,10 @@ export class ComfyApp {
const p = await this.graphToPrompt();
try {
this.runningPrompt = p;
await api.queuePrompt(number, p);
} catch (error) {
this.runningPrompt = null;
const formattedError = this.#formatPromptError(error)
this.ui.dialog.show(formattedError);
if (error.response) {
@@ -1528,10 +1637,12 @@ export class ComfyApp {
*/
clean() {
this.nodeOutputs = {};
this.nodeGrids = {};
this.nodePreviewImages = {}
this.lastPromptError = null;
this.lastExecutionError = null;
this.runningNodeId = null;
this.runningPrompt = null;
}
}