mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 07:10:52 +08:00
Calculate grid from combinatorial inputs
This commit is contained in:
parent
2fa87f1779
commit
c9f4eb3fad
70
execution.py
70
execution.py
@ -8,6 +8,9 @@ import traceback
|
||||
import gc
|
||||
import time
|
||||
import itertools
|
||||
from typing import List, Dict
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
@ -15,6 +18,15 @@ import nodes
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
@dataclass
|
||||
class CombinatorialBatches:
|
||||
batches: List
|
||||
input_to_index: Dict
|
||||
index_to_values: Dict
|
||||
indices: List
|
||||
combinations: List
|
||||
|
||||
|
||||
def get_input_data_batches(input_data_all):
|
||||
"""Given input data that can contain combinatorial input values, returns all
|
||||
the possible batches that can be made by combining the different input
|
||||
@ -22,6 +34,7 @@ def get_input_data_batches(input_data_all):
|
||||
|
||||
input_to_index = {}
|
||||
index_to_values = []
|
||||
index_to_coords = []
|
||||
|
||||
# Sort by input name first so the order which batch inputs are applied can
|
||||
# be easily calculated (node execution order first, then alphabetical input
|
||||
@ -34,15 +47,18 @@ def get_input_data_batches(input_data_all):
|
||||
if isinstance(value, dict) and "combinatorial" in value:
|
||||
input_to_index[input_name] = i
|
||||
index_to_values.append(value["values"])
|
||||
index_to_coords.append(list(range(len(value["values"]))))
|
||||
i += 1
|
||||
|
||||
if len(index_to_values) == 0:
|
||||
# No combinatorial options.
|
||||
return [input_data_all]
|
||||
return CombinatorialBatches([input_data_all], input_to_index, index_to_values, None, None)
|
||||
|
||||
batches = []
|
||||
|
||||
for combination in list(itertools.product(*index_to_values)):
|
||||
indices = list(itertools.product(*index_to_coords))
|
||||
combinations = list(itertools.product(*index_to_values))
|
||||
for combination in combinations:
|
||||
batch = {}
|
||||
for input_name, value in input_data_all.items():
|
||||
if isinstance(value, dict) and "combinatorial" in value:
|
||||
@ -53,7 +69,7 @@ def get_input_data_batches(input_data_all):
|
||||
batch[input_name] = value
|
||||
batches.append(batch)
|
||||
|
||||
return batches
|
||||
return CombinatorialBatches(batches, input_to_index, index_to_values, indices, combinations)
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
|
||||
"""Given input data from the prompt, returns a list of input data dicts for
|
||||
@ -157,9 +173,9 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
|
||||
all_outputs = []
|
||||
all_outputs_ui = []
|
||||
total_batches = len(input_data_all_batches)
|
||||
total_batches = len(input_data_all_batches.batches)
|
||||
|
||||
for batch_num, batch in enumerate(input_data_all_batches):
|
||||
for batch_num, batch in enumerate(input_data_all_batches.batches):
|
||||
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True)
|
||||
|
||||
uis = []
|
||||
@ -208,6 +224,9 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
|
||||
"batch_num": batch_num,
|
||||
"total_batches": total_batches
|
||||
}
|
||||
if input_data_all_batches.indices:
|
||||
message["indices"] = input_data_all_batches.indices[batch_num]
|
||||
message["combination"] = input_data_all_batches.combinations[batch_num]
|
||||
server.send_sync("executed", message, server.client_id)
|
||||
|
||||
return all_outputs, all_outputs_ui
|
||||
@ -240,12 +259,25 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
# Another node failed further upstream
|
||||
return result
|
||||
|
||||
input_data_all = None
|
||||
input_data_all_batches = None
|
||||
try:
|
||||
input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = unique_id
|
||||
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id, "total_batches": len(input_data_all_batches) }, server.client_id)
|
||||
combinations = None
|
||||
if input_data_all_batches.indices:
|
||||
combinations = {
|
||||
"input_to_index": input_data_all_batches.input_to_index,
|
||||
"index_to_values": input_data_all_batches.index_to_values,
|
||||
"indices": input_data_all_batches.indices
|
||||
}
|
||||
mes = {
|
||||
"node": unique_id,
|
||||
"prompt_id": prompt_id,
|
||||
"combinations": combinations
|
||||
}
|
||||
server.send_sync("executing", mes, server.client_id)
|
||||
|
||||
obj = class_def()
|
||||
|
||||
output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id)
|
||||
@ -266,15 +298,20 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
exception_type = full_type_name(typ)
|
||||
input_data_formatted = {}
|
||||
if input_data_all is not None:
|
||||
input_data_formatted = {}
|
||||
for name, inputs in input_data_all.items():
|
||||
input_data_formatted[name] = [format_value(x) for x in inputs]
|
||||
input_data_formatted = []
|
||||
if input_data_all_batches is not None:
|
||||
d = {}
|
||||
for batch in input_data_all_batches.batches:
|
||||
for name, inputs in batch.items():
|
||||
d[name] = [format_value(x) for x in inputs]
|
||||
input_data_formatted.append(d)
|
||||
|
||||
output_data_formatted = {}
|
||||
output_data_formatted = []
|
||||
for node_id, node_outputs in outputs.items():
|
||||
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
|
||||
d = {}
|
||||
for batch_outputs in node_outputs:
|
||||
d[node_id] = [[format_value(x) for x in l] for l in batch_outputs]
|
||||
output_data_formatted.append(d)
|
||||
|
||||
print("!!! Exception during processing !!!")
|
||||
print(traceback.format_exc())
|
||||
@ -327,7 +364,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
if input_data_all_batches is not None:
|
||||
try:
|
||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
for batch in input_data_all_batches:
|
||||
for batch in input_data_all_batches.batches:
|
||||
if map_node_over_list(class_def, batch, "IS_CHANGED"):
|
||||
is_changed = True
|
||||
break
|
||||
@ -668,8 +705,7 @@ def validate_inputs(prompt, item, validated):
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
input_data_all_batches = get_input_data(inputs, obj_class, unique_id)
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||
for batch in input_data_all_batches:
|
||||
for batch in input_data_all_batches.batches:
|
||||
ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS")
|
||||
for r in ret:
|
||||
if r != True:
|
||||
|
||||
@ -109,7 +109,7 @@ class ComfyApi extends EventTarget {
|
||||
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
|
||||
break;
|
||||
case "executing":
|
||||
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
|
||||
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data }));
|
||||
break;
|
||||
case "executed":
|
||||
this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
|
||||
|
||||
@ -44,6 +44,12 @@ export class ComfyApp {
|
||||
*/
|
||||
this.nodeOutputs = {};
|
||||
|
||||
/**
|
||||
* Stores the grid data for each node
|
||||
* @type {Record<string, any>}
|
||||
*/
|
||||
this.nodeGrids = {};
|
||||
|
||||
/**
|
||||
* Stores the preview image data for each node
|
||||
* @type {Record<string, Image>}
|
||||
@ -949,13 +955,21 @@ export class ComfyApp {
|
||||
|
||||
api.addEventListener("executing", ({ detail }) => {
|
||||
this.progress = null;
|
||||
this.runningNodeId = detail;
|
||||
this.runningNodeId = detail.node;
|
||||
this.graph.setDirtyCanvas(true, false);
|
||||
delete this.nodePreviewImages[this.runningNodeId]
|
||||
if (detail.node != null) {
|
||||
delete this.nodePreviewImages[this.runningNodeId]
|
||||
}
|
||||
else {
|
||||
this.runningPrompt = null;
|
||||
}
|
||||
});
|
||||
|
||||
api.addEventListener("executed", ({ detail }) => {
|
||||
this.nodeOutputs[detail.node] = detail.output;
|
||||
if (detail.output != null) {
|
||||
this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt)
|
||||
}
|
||||
const node = this.graph.getNodeById(detail.node);
|
||||
if (node) {
|
||||
if (node.onExecuted)
|
||||
@ -964,6 +978,7 @@ export class ComfyApp {
|
||||
});
|
||||
|
||||
api.addEventListener("execution_start", ({ detail }) => {
|
||||
this.nodeGrids = {}
|
||||
this.runningNodeId = null;
|
||||
this.lastExecutionError = null
|
||||
});
|
||||
@ -988,6 +1003,93 @@ export class ComfyApp {
|
||||
api.init();
|
||||
}
|
||||
|
||||
/*
|
||||
* Based on inputs in the prompt marked as combinatorial,
|
||||
* construct a grid from the results;
|
||||
*/
|
||||
#resolveGrid(outputNode, output, runningPrompt) {
|
||||
let axes = []
|
||||
|
||||
const allImages = output.filter(batch => Array.isArray(batch.images))
|
||||
.flatMap(batch => batch.images)
|
||||
|
||||
if (allImages.length === 0)
|
||||
return null;
|
||||
|
||||
console.error("PROMPT", runningPrompt);
|
||||
console.error("OUTPUT", output);
|
||||
|
||||
const isInputLink = (input) => {
|
||||
return Array.isArray(input)
|
||||
&& input.length === 2
|
||||
&& typeof input[0] === "string"
|
||||
&& typeof input[1] === "number";
|
||||
}
|
||||
|
||||
// Axes closer to the output (executed later) are discovered first
|
||||
const queue = [outputNode]
|
||||
while (queue.length > 0) {
|
||||
const nodeID = queue.pop();
|
||||
const promptInput = runningPrompt.output[nodeID];
|
||||
|
||||
// Ensure input keys are sorted alphanumerically
|
||||
// This is important for the plot to have the same order as
|
||||
// it was executed on the backend
|
||||
let sortedKeys = Object.keys(promptInput.inputs);
|
||||
sortedKeys.sort((a, b) => a.localeCompare(b));
|
||||
|
||||
// Then reverse the order since we're traversing the graph upstream,
|
||||
// so execution order comes out backwards
|
||||
sortedKeys = sortedKeys.reverse();
|
||||
|
||||
for (const inputName of sortedKeys) {
|
||||
const input = promptInput.inputs[inputName];
|
||||
if (typeof input === "object" && "__inputType__" in input) {
|
||||
axes.push({
|
||||
nodeID,
|
||||
inputName,
|
||||
values: input.values
|
||||
})
|
||||
}
|
||||
else if (isInputLink(input)) {
|
||||
const inputNodeID = input[0]
|
||||
queue.push(inputNodeID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
axes = axes.reverse();
|
||||
|
||||
// Now divide up the image outputs
|
||||
// Each axis will divide the full array of images by N, where N was the
|
||||
// number of combinatorial choices for that axis, and this happens
|
||||
// recursively for each axis
|
||||
|
||||
console.error("AXES", axes)
|
||||
|
||||
// Grid position
|
||||
const currentCoords = Array.from(Array(0))
|
||||
|
||||
let images = allImages.map(i => { return {
|
||||
image: i,
|
||||
coords: []
|
||||
}})
|
||||
|
||||
let factor = 1
|
||||
|
||||
for (const axis of axes) {
|
||||
factor *= axis.values.length;
|
||||
for (const [index, image] of images.entries()) {
|
||||
image.coords.push(Math.floor((index / factor) * axis.values.length) % axis.values.length);
|
||||
}
|
||||
}
|
||||
|
||||
const grid = { axes, images };
|
||||
console.error("GRID", grid);
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
#addKeyboardHandler() {
|
||||
window.addEventListener("keydown", (e) => {
|
||||
this.shiftDown = e.shiftKey;
|
||||
@ -1292,7 +1394,11 @@ export class ComfyApp {
|
||||
const workflow = this.graph.serialize();
|
||||
const output = {};
|
||||
// Process nodes in order of execution
|
||||
for (const node of this.graph.computeExecutionOrder(false)) {
|
||||
|
||||
const executionOrder = Array.from(this.graph.computeExecutionOrder(false));
|
||||
const executionOrderIds = executionOrder.map(n => n.id);
|
||||
|
||||
for (const node of executionOrder) {
|
||||
const n = workflow.nodes.find((n) => n.id === node.id);
|
||||
|
||||
if (node.isVirtualNode) {
|
||||
@ -1359,7 +1465,7 @@ export class ComfyApp {
|
||||
}
|
||||
}
|
||||
|
||||
return { workflow, output };
|
||||
return { workflow, output, executionOrder: executionOrderIds };
|
||||
}
|
||||
|
||||
#formatPromptError(error) {
|
||||
@ -1409,6 +1515,7 @@ export class ComfyApp {
|
||||
|
||||
this.#processingQueue = true;
|
||||
this.lastPromptError = null;
|
||||
this.runningPrompt = null;
|
||||
|
||||
try {
|
||||
while (this.#queueItems.length) {
|
||||
@ -1418,8 +1525,10 @@ export class ComfyApp {
|
||||
const p = await this.graphToPrompt();
|
||||
|
||||
try {
|
||||
this.runningPrompt = p;
|
||||
await api.queuePrompt(number, p);
|
||||
} catch (error) {
|
||||
this.runningPrompt = null;
|
||||
const formattedError = this.#formatPromptError(error)
|
||||
this.ui.dialog.show(formattedError);
|
||||
if (error.response) {
|
||||
@ -1528,10 +1637,12 @@ export class ComfyApp {
|
||||
*/
|
||||
clean() {
|
||||
this.nodeOutputs = {};
|
||||
this.nodeGrids = {};
|
||||
this.nodePreviewImages = {}
|
||||
this.lastPromptError = null;
|
||||
this.lastExecutionError = null;
|
||||
this.runningNodeId = null;
|
||||
this.runningPrompt = null;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user