Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-22 20:30:25 +08:00)
Calculate grid from combinatorial inputs
Commit: c9f4eb3fad
Parent: 2fa87f1779

execution.py | 70
@@ -8,6 +8,9 @@ import traceback
 import gc
 import time
 import itertools
+from typing import List, Dict
+import dataclasses
+from dataclasses import dataclass
 
 import torch
 import nodes
@@ -15,6 +18,15 @@ import nodes
 import comfy.model_management
 
 
+@dataclass
+class CombinatorialBatches:
+    batches: List
+    input_to_index: Dict
+    index_to_values: Dict
+    indices: List
+    combinations: List
+
+
 def get_input_data_batches(input_data_all):
     """Given input data that can contain combinatorial input values, returns all
     the possible batches that can be made by combining the different input
@@ -22,6 +34,7 @@ def get_input_data_batches(input_data_all):
 
     input_to_index = {}
     index_to_values = []
+    index_to_coords = []
 
     # Sort by input name first so the order which batch inputs are applied can
     # be easily calculated (node execution order first, then alphabetical input
@@ -34,15 +47,18 @@ def get_input_data_batches(input_data_all):
         if isinstance(value, dict) and "combinatorial" in value:
             input_to_index[input_name] = i
             index_to_values.append(value["values"])
+            index_to_coords.append(list(range(len(value["values"]))))
             i += 1
 
     if len(index_to_values) == 0:
         # No combinatorial options.
-        return [input_data_all]
+        return CombinatorialBatches([input_data_all], input_to_index, index_to_values, None, None)
 
     batches = []
 
-    for combination in list(itertools.product(*index_to_values)):
+    indices = list(itertools.product(*index_to_coords))
+    combinations = list(itertools.product(*index_to_values))
+    for combination in combinations:
         batch = {}
         for input_name, value in input_data_all.items():
             if isinstance(value, dict) and "combinatorial" in value:
@@ -53,7 +69,7 @@ def get_input_data_batches(input_data_all):
                 batch[input_name] = value
         batches.append(batch)
 
-    return batches
+    return CombinatorialBatches(batches, input_to_index, index_to_values, indices, combinations)
 
 def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
     """Given input data from the prompt, returns a list of input data dicts for
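The two itertools.product calls above iterate over parallel lists, so indices and combinations stay aligned element for element: indices[k] holds the grid coordinates of the value tuple combinations[k] that batch k substitutes into the combinatorial inputs. A minimal standalone sketch of that alignment, with hypothetical input names and values not tied to any real node:

    import itertools

    # Two combinatorial inputs, sorted by name: "cfg" becomes axis 0, "steps" axis 1.
    index_to_values = [[5.0, 7.5, 9.0], [10, 20]]
    index_to_coords = [[0, 1, 2], [0, 1]]

    indices = list(itertools.product(*index_to_coords))        # (0, 0), (0, 1), (1, 0), ...
    combinations = list(itertools.product(*index_to_values))   # (5.0, 10), (5.0, 20), (7.5, 10), ...

    # Both products enumerate in the same order, so element k of one describes element k of the other.
    assert len(indices) == len(combinations) == 3 * 2
    assert indices[4] == (2, 0) and combinations[4] == (9.0, 10)

Every non-combinatorial input is copied unchanged into each of the resulting batches.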
@@ -157,9 +173,9 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
 def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
     all_outputs = []
     all_outputs_ui = []
-    total_batches = len(input_data_all_batches)
+    total_batches = len(input_data_all_batches.batches)
 
-    for batch_num, batch in enumerate(input_data_all_batches):
+    for batch_num, batch in enumerate(input_data_all_batches.batches):
         return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True)
 
         uis = []
@@ -208,6 +224,9 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
                     "batch_num": batch_num,
                     "total_batches": total_batches
                 }
+                if input_data_all_batches.indices:
+                    message["indices"] = input_data_all_batches.indices[batch_num]
+                    message["combination"] = input_data_all_batches.combinations[batch_num]
                 server.send_sync("executed", message, server.client_id)
 
     return all_outputs, all_outputs_ui
@@ -240,12 +259,25 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
             # Another node failed further upstream
             return result
 
-    input_data_all = None
+    input_data_all_batches = None
     try:
         input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
         if server.client_id is not None:
             server.last_node_id = unique_id
-            server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id, "total_batches": len(input_data_all_batches) }, server.client_id)
+            combinations = None
+            if input_data_all_batches.indices:
+                combinations = {
+                    "input_to_index": input_data_all_batches.input_to_index,
+                    "index_to_values": input_data_all_batches.index_to_values,
+                    "indices": input_data_all_batches.indices
+                }
+            mes = {
+                "node": unique_id,
+                "prompt_id": prompt_id,
+                "combinations": combinations
+            }
+            server.send_sync("executing", mes, server.client_id)
 
         obj = class_def()
 
         output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id)
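Together with the get_output_data change above, the two websocket messages now carry the combination metadata the frontend needs to lay out a grid. A rough sketch of the resulting payload shapes, written as Python literals: the combinations, indices, combination, batch_num and total_batches fields come from this diff, while the node id, prompt id, values and the usual per-batch output field are assumptions for illustration only.

    # "executing" is sent once per node, before its batches run.
    executing_message = {
        "node": "12",
        "prompt_id": "c0ffee",
        "combinations": {                     # None when no input is combinatorial
            "input_to_index": {"cfg": 0, "steps": 1},
            "index_to_values": [[5.0, 7.5, 9.0], [10, 20]],
            "indices": [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)],
        },
    }

    # "executed" is sent once per batch, with that batch's grid position.
    executed_message = {
        "node": "12",
        "prompt_id": "c0ffee",
        "output": {"images": []},             # UI output for this batch, as before
        "batch_num": 4,
        "total_batches": 6,
        "indices": (2, 0),                    # grid coordinates of this batch
        "combination": (9.0, 10),             # combinatorial input values it used
    }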
@@ -266,15 +298,20 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
     except Exception as ex:
         typ, _, tb = sys.exc_info()
         exception_type = full_type_name(typ)
-        input_data_formatted = {}
-        if input_data_all is not None:
-            input_data_formatted = {}
-            for name, inputs in input_data_all.items():
-                input_data_formatted[name] = [format_value(x) for x in inputs]
+        input_data_formatted = []
+        if input_data_all_batches is not None:
+            d = {}
+            for batch in input_data_all_batches.batches:
+                for name, inputs in batch.items():
+                    d[name] = [format_value(x) for x in inputs]
+                input_data_formatted.append(d)
 
-        output_data_formatted = {}
+        output_data_formatted = []
         for node_id, node_outputs in outputs.items():
-            output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
+            d = {}
+            for batch_outputs in node_outputs:
+                d[node_id] = [[format_value(x) for x in l] for l in batch_outputs]
+            output_data_formatted.append(d)
 
         print("!!! Exception during processing !!!")
         print(traceback.format_exc())
@@ -327,7 +364,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
     if input_data_all_batches is not None:
         try:
             #is_changed = class_def.IS_CHANGED(**input_data_all)
-            for batch in input_data_all_batches:
+            for batch in input_data_all_batches.batches:
                 if map_node_over_list(class_def, batch, "IS_CHANGED"):
                     is_changed = True
                     break
@@ -668,8 +705,7 @@ def validate_inputs(prompt, item, validated):
     if hasattr(obj_class, "VALIDATE_INPUTS"):
         input_data_all_batches = get_input_data(inputs, obj_class, unique_id)
         #ret = obj_class.VALIDATE_INPUTS(**input_data_all)
-        ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
-        for batch in input_data_all_batches:
+        for batch in input_data_all_batches.batches:
             ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS")
             for r in ret:
                 if r != True:
@@ -109,7 +109,7 @@ class ComfyApi extends EventTarget
             this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
             break;
         case "executing":
-            this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
+            this.dispatchEvent(new CustomEvent("executing", { detail: msg.data }));
             break;
         case "executed":
             this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
@@ -44,6 +44,12 @@ export class ComfyApp
         */
        this.nodeOutputs = {};
 
+       /**
+        * Stores the grid data for each node
+        * @type {Record<string, any>}
+        */
+       this.nodeGrids = {};
+
        /**
         * Stores the preview image data for each node
         * @type {Record<string, Image>}
@@ -949,13 +955,21 @@ export class ComfyApp
 
        api.addEventListener("executing", ({ detail }) => {
            this.progress = null;
-           this.runningNodeId = detail;
+           this.runningNodeId = detail.node;
            this.graph.setDirtyCanvas(true, false);
-           delete this.nodePreviewImages[this.runningNodeId]
+           if (detail.node != null) {
+               delete this.nodePreviewImages[this.runningNodeId]
+           }
+           else {
+               this.runningPrompt = null;
+           }
        });
 
        api.addEventListener("executed", ({ detail }) => {
            this.nodeOutputs[detail.node] = detail.output;
+           if (detail.output != null) {
+               this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt)
+           }
            const node = this.graph.getNodeById(detail.node);
            if (node) {
                if (node.onExecuted)
@@ -964,6 +978,7 @@ export class ComfyApp
        });
 
        api.addEventListener("execution_start", ({ detail }) => {
+           this.nodeGrids = {}
            this.runningNodeId = null;
            this.lastExecutionError = null
        });
@@ -988,6 +1003,93 @@ export class ComfyApp
        api.init();
    }
 
+   /*
+    * Based on inputs in the prompt marked as combinatorial,
+    * construct a grid from the results;
+    */
+   #resolveGrid(outputNode, output, runningPrompt) {
+       let axes = []
+
+       const allImages = output.filter(batch => Array.isArray(batch.images))
+           .flatMap(batch => batch.images)
+
+       if (allImages.length === 0)
+           return null;
+
+       console.error("PROMPT", runningPrompt);
+       console.error("OUTPUT", output);
+
+       const isInputLink = (input) => {
+           return Array.isArray(input)
+               && input.length === 2
+               && typeof input[0] === "string"
+               && typeof input[1] === "number";
+       }
+
+       // Axes closer to the output (executed later) are discovered first
+       const queue = [outputNode]
+       while (queue.length > 0) {
+           const nodeID = queue.pop();
+           const promptInput = runningPrompt.output[nodeID];
+
+           // Ensure input keys are sorted alphanumerically
+           // This is important for the plot to have the same order as
+           // it was executed on the backend
+           let sortedKeys = Object.keys(promptInput.inputs);
+           sortedKeys.sort((a, b) => a.localeCompare(b));
+
+           // Then reverse the order since we're traversing the graph upstream,
+           // so execution order comes out backwards
+           sortedKeys = sortedKeys.reverse();
+
+           for (const inputName of sortedKeys) {
+               const input = promptInput.inputs[inputName];
+               if (typeof input === "object" && "__inputType__" in input) {
+                   axes.push({
+                       nodeID,
+                       inputName,
+                       values: input.values
+                   })
+               }
+               else if (isInputLink(input)) {
+                   const inputNodeID = input[0]
+                   queue.push(inputNodeID)
+               }
+           }
+       }
+
+       axes = axes.reverse();
+
+       // Now divide up the image outputs
+       // Each axis will divide the full array of images by N, where N was the
+       // number of combinatorial choices for that axis, and this happens
+       // recursively for each axis
+
+       console.error("AXES", axes)
+
+       // Grid position
+       const currentCoords = Array.from(Array(0))
+
+       let images = allImages.map(i => { return {
+           image: i,
+           coords: []
+       }})
+
+       let factor = 1
+
+       for (const axis of axes) {
+           factor *= axis.values.length;
+           for (const [index, image] of images.entries()) {
+               image.coords.push(Math.floor((index / factor) * axis.values.length) % axis.values.length);
+           }
+       }
+
+       const grid = { axes, images };
+       console.error("GRID", grid);
+
+       return null;
+   }
+
    #addKeyboardHandler() {
        window.addEventListener("keydown", (e) => {
            this.shiftDown = e.shiftKey;
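The factor loop at the end of #resolveGrid is what turns a flat batch index into per-axis grid coordinates: each axis contributes one digit of a mixed-radix number, with the first axis in the array varying fastest. A small Python sketch of the same arithmetic, assuming a hypothetical 3x2 grid (axis 0 has 3 values, axis 1 has 2):

    import math

    axis_sizes = [3, 2]

    for index in range(3 * 2):
        coords = []
        factor = 1
        for n in axis_sizes:
            factor *= n
            # Same expression as the JS: Math.floor((index / factor) * n) % n
            coords.append(math.floor((index / factor) * n) % n)
        # Equivalent integer reading: the digits of index in the mixed radix (3, 2)
        assert coords == [index % 3, (index // 3) % 2]
        print(index, coords)

For this grid the six images map to coordinates (0,0), (1,0), (2,0), (0,1), (1,1), (2,1) in output order. Note that, as committed, the method still logs the computed grid and returns null rather than the grid object.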
@@ -1292,7 +1394,11 @@ export class ComfyApp
        const workflow = this.graph.serialize();
        const output = {};
        // Process nodes in order of execution
-       for (const node of this.graph.computeExecutionOrder(false)) {
+       const executionOrder = Array.from(this.graph.computeExecutionOrder(false));
+       const executionOrderIds = executionOrder.map(n => n.id);
+
+       for (const node of executionOrder) {
            const n = workflow.nodes.find((n) => n.id === node.id);
 
            if (node.isVirtualNode) {
@@ -1359,7 +1465,7 @@ export class ComfyApp
            }
        }
 
-       return { workflow, output };
+       return { workflow, output, executionOrder: executionOrderIds };
    }
 
    #formatPromptError(error) {
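With this change, the object returned by graphToPrompt (and kept in this.runningPrompt while the prompt runs) carries the graph's computed execution order next to the usual workflow/output pair, which is what lets #resolveGrid walk the prompt upstream from an output node. A rough sketch of the structure it walks, written as a Python literal for brevity: the node ids, class type, input values and the marker value stored under __inputType__ are invented; only the keys the code actually reads (inputs, values, the [node id, slot] link form, executionOrder) come from this diff.

    prompt = {
        "workflow": {},                        # serialized graph, omitted here
        "output": {
            "3": {
                "class_type": "KSampler",
                "inputs": {
                    "model": ["4", 0],         # link input: [upstream node id, output slot]
                    "cfg": {                   # combinatorial input -> one grid axis
                        "__inputType__": "combinatorial",
                        "values": [5.0, 7.5, 9.0],
                    },
                    "seed": 42,                # plain literal input
                },
            },
        },
        "executionOrder": ["4", "3"],          # node ids in computed execution order
    }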
@@ -1409,6 +1515,7 @@ export class ComfyApp
 
        this.#processingQueue = true;
        this.lastPromptError = null;
+       this.runningPrompt = null;
 
        try {
            while (this.#queueItems.length) {
@@ -1418,8 +1525,10 @@ export class ComfyApp
                const p = await this.graphToPrompt();
 
                try {
+                   this.runningPrompt = p;
                    await api.queuePrompt(number, p);
                } catch (error) {
+                   this.runningPrompt = null;
                    const formattedError = this.#formatPromptError(error)
                    this.ui.dialog.show(formattedError);
                    if (error.response) {
@@ -1528,10 +1637,12 @@ export class ComfyApp
     */
    clean() {
        this.nodeOutputs = {};
+       this.nodeGrids = {};
        this.nodePreviewImages = {}
        this.lastPromptError = null;
        this.lastExecutionError = null;
        this.runningNodeId = null;
+       this.runningPrompt = null;
    }
 }