Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-24 13:20:19 +08:00)
Show batch progress

commit a225674fc0, parent e360f4b05b

execution.py: 52 lines changed
execution.py

@@ -133,33 +133,16 @@ def slice_lists_into_dict(d, i):
         d_new[k] = v[i if len(v) > i else -1]
     return d_new

-def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
+def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, callback=None):
     # check if node wants the lists
-    intput_is_list = False
+    input_is_list = False
     if hasattr(obj, "INPUT_IS_LIST"):
-        intput_is_list = obj.INPUT_IS_LIST
+        input_is_list = obj.INPUT_IS_LIST

-    max_len_input = max([len(x) for x in input_data_all.values()])
-
-    def format_dict(d):
-        s = []
-        for k,v in d.items():
-            st = f"{k}: "
-            if isinstance(v, list):
-                st += f"list[len: {len(v)}]["
-                i = []
-                for v2 in v:
-                    i.append(v2.__class__.__name__)
-                st += ",".join(i) + "]"
-            else:
-                st += str(type(v))
-            s.append(st)
-        return "( " + ", ".join(s) + " )"
-
     max_len_input = max(len(x) for x in input_data_all.values())

     results = []
-    if intput_is_list:
+    if input_is_list:
         if allow_interrupt:
             nodes.before_node_execution()
         results.append(getattr(obj, func)(**input_data_all))
@@ -168,6 +151,8 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
             if allow_interrupt:
                 nodes.before_node_execution()
             results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i)))
+            if callback is not None:
+                callback(i, max_len_input)
     return results

 def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
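The new callback parameter gives callers of map_node_over_list a per-item progress hook: after item i of max_len_input items has run, callback(i, max_len_input) fires. A minimal standalone sketch of that contract, with toy names that are not part of ComfyUI:

def run_items(items, func, callback=None):
    # Toy stand-in for map_node_over_list's per-item loop: the callback receives
    # the zero-based index of the item that just finished and the total count.
    results = []
    total = len(items)
    for i, item in enumerate(items):
        results.append(func(item))
        if callback is not None:
            callback(i, total)
    return results

# Prints "1/3", "2/3", "3/3" as each item completes.
run_items([1, 2, 3], lambda x: x * x, callback=lambda i, n: print(f"{i + 1}/{n}"))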
@@ -175,8 +160,31 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
     all_outputs_ui = []
     total_batches = len(input_data_all_batches.batches)

+    total_inner_batches = 0
+    for batch in input_data_all_batches.batches:
+        total_inner_batches += max(len(x) for x in batch.values())
+
+    inner_totals = 0
+
+    def send_batch_progress(inner_num):
+        if server.client_id is not None:
+            message = {
+                "node": unique_id,
+                "prompt_id": prompt_id,
+                "batch_num": inner_totals + inner_num,
+                "total_batches": total_inner_batches
+            }
+            server.send_sync("batch_progress", message, server.client_id)
+
+    send_batch_progress(0)
+
     for batch_num, batch in enumerate(input_data_all_batches.batches):
-        return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True)
+        def cb(inner_num, inner_total):
+            send_batch_progress(inner_num + 1)
+
+        return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb)
+
+        inner_totals += max(len(x) for x in batch.values())

         uis = []
         results = []
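The accounting in get_output_data works in two layers: total_inner_batches sums the longest input list of every batch up front, and inner_totals advances as batches finish, so batch_num counts items globally rather than per batch. A toy sketch of the sequence of (batch_num, total_batches) values that send_batch_progress ends up emitting, assuming (as above) that each batch is a dict of input lists:

def batch_progress_events(batches):
    # Mirrors the counting logic above with plain data instead of node execution.
    total_inner_batches = sum(max(len(v) for v in batch.values()) for batch in batches)
    inner_totals = 0
    yield (0, total_inner_batches)                            # initial send_batch_progress(0)
    for batch in batches:
        inner_total = max(len(v) for v in batch.values())
        for i in range(inner_total):
            yield (inner_totals + i + 1, total_inner_batches) # cb -> send_batch_progress(i + 1)
        inner_totals += inner_total

print(list(batch_progress_events([{"seed": [1, 2, 3]}, {"seed": [4, 5]}])))
# [(0, 5), (1, 5), (2, 5), (3, 5), (4, 5), (5, 5)]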
web frontend, class ComfyApi:

@@ -108,6 +108,9 @@ class ComfyApi extends EventTarget {
 				case "progress":
 					this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
 					break;
+				case "batch_progress":
+					this.dispatchEvent(new CustomEvent("batch_progress", { detail: msg.data }));
+					break;
 				case "executing":
 					this.dispatchEvent(new CustomEvent("executing", { detail: msg.data }));
 					break;
web frontend, class ComfyApp:

@@ -891,8 +891,22 @@ export class ComfyApp {
 				}

 				if (self.progress && node.id === +self.runningNodeId) {
+					let offset = 0
+					let height = 6;
+
+					if (self.batchProgress) {
+						offset = 4;
+						height = 4;
+					}
+
 					ctx.fillStyle = "green";
-					ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
+					ctx.fillRect(0, offset, size[0] * (self.progress.value / self.progress.max), height);
+
+					if (self.batchProgress) {
+						ctx.fillStyle = "#3ca2c3";
+						ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), height);
+					}
+
 					ctx.fillStyle = bgcolor;
 				}
@@ -953,8 +967,22 @@ export class ComfyApp {
 			this.graph.setDirtyCanvas(true, false);
 		});

+		api.addEventListener("batch_progress", ({ detail }) => {
+			if (detail.total_batches <= 1) {
+				this.batchProgress = null;
+			}
+			else {
+				this.batchProgress = {
+					value: detail.batch_num,
+					max: detail.total_batches
+				}
+			}
+			this.graph.setDirtyCanvas(true, false);
+		});
+
 		api.addEventListener("executing", ({ detail }) => {
 			this.progress = null;
+			this.batchProgress = null;
 			this.runningNodeId = detail.node;
 			this.graph.setDirtyCanvas(true, false);
 			if (detail.node != null) {
@@ -975,6 +1003,10 @@ export class ComfyApp {
 				if (node.onExecuted)
 					node.onExecuted(detail.output);
 			}
+			if (this.batchProgress != null) {
+				this.batchProgress.value = detail.batch_num
+				this.batchProgress.max = detail.total_batches
+			}
 		});

 		api.addEventListener("execution_start", ({ detail }) => {
@@ -1005,7 +1037,7 @@ export class ComfyApp {

 	/*
 	 * Based on inputs in the prompt marked as combinatorial,
-	 * construct a grid from the results;
+	 * construct a grid from the results
 	 */
 	#resolveGrid(outputNode, output, runningPrompt) {
 		let axes = []
@@ -1016,8 +1048,8 @@ export class ComfyApp {
 		if (allImages.length === 0)
 			return null;

-		console.error("PROMPT", runningPrompt);
-		console.error("OUTPUT", output);
+		console.log("PROMPT", runningPrompt);
+		console.log("OUTPUT", output);

 		const isInputLink = (input) => {
 			return Array.isArray(input)
@@ -1028,8 +1060,12 @@ export class ComfyApp {

 		// Axes closer to the output (executed later) are discovered first
 		const queue = [outputNode]
+		const seen = new Set();
 		while (queue.length > 0) {
 			const nodeID = queue.pop();
+			if (seen.has(nodeID))
+				continue;
+			seen.add(nodeID);
 			const promptInput = runningPrompt.output[nodeID];
 			const nodeClass = promptInput.class_type

@@ -1069,7 +1105,7 @@ export class ComfyApp {
 		// number of combinatorial choices for that axis, and this happens
 		// recursively for each axis

-		console.error("AXES", axes)
+		console.log("AXES", axes)

 		// Grid position
 		const currentCoords = Array.from(Array(0))
@@ -1093,7 +1129,7 @@ export class ComfyApp {
 		}

 		const grid = { axes, images };
-		console.error("GRID", grid);
+		console.log("GRID", grid);

 		return grid;
 	}
@@ -1401,8 +1437,12 @@ export class ComfyApp {
 	async graphToPrompt() {
 		const workflow = this.graph.serialize();
 		const output = {};
-		// Process nodes in order of execution

+		let totalExecuted = 0;
+		let totalCombinatorialNodes = 0;
+		let executionFactor = 1;
+
+		// Process nodes in order of execution
 		const executionOrder = Array.from(this.graph.computeExecutionOrder(false));
 		const executionOrderIds = executionOrder.map(n => n.id);

@@ -1430,7 +1470,13 @@ export class ComfyApp {
 				for (const i in widgets) {
 					const widget = widgets[i];
 					if (!widget.options || widget.options.serialize !== false) {
-						inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
+						const widgetValue = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
+						inputs[widget.name] = widgetValue;
+						if (typeof widgetValue === "object" && widgetValue.__inputType__) {
+							totalCombinatorialNodes += 1;
+							executionFactor *= widgetValue.values.length;
+						}
+						totalExecuted += executionFactor;
 					}
 				}
 			}
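The counters added to graphToPrompt estimate how much work a combinatorial prompt expands into: each combinatorial widget multiplies executionFactor by its number of values, and totalExecuted feeds the confirm() guard added to queuePrompt further down. A rough sketch of why that number grows so quickly, using made-up axis sizes:

from math import prod

# Hypothetical grid: three combinatorial widgets with 4, 3 and 5 values each.
axis_sizes = [4, 3, 5]
total_runs = prod(axis_sizes)   # 60 executions for the full grid
print(f"{len(axis_sizes)} axes -> {total_runs} runs")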
@@ -1473,7 +1519,15 @@ export class ComfyApp {
 			}
 		}

-		return { workflow, output, executionOrder: executionOrderIds };
+		return {
+			prompt: {
+				workflow,
+				output,
+			},
+			executionOrder: executionOrderIds,
+			totalCombinatorialNodes,
+			totalExecuted
+		};
 	}

 	#formatPromptError(error) {
@@ -1530,7 +1584,12 @@ export class ComfyApp {
 				({ number, batchCount } = this.#queueItems.pop());

 				for (let i = 0; i < batchCount; i++) {
-					const p = await this.graphToPrompt();
+					const result = await this.graphToPrompt();
+					const warnExecutedAmount = 256;
+					if (result.totalExecuted > warnExecutedAmount && !confirm("You are about to execute " + result.totalExecuted + " nodes total across " + result.totalCombinatorialNodes + " combinatorial axes. Are you sure you want to do this?")) {
+						continue
+					}
+					const p = result.prompt;

 					try {
 						this.runningPrompt = p;
@@ -1647,6 +1706,8 @@ export class ComfyApp {
 		this.nodeOutputs = {};
 		this.nodeGrids = {};
 		this.nodePreviewImages = {}
+		this.progress = null;
+		this.batchProgress = null;
 		this.lastPromptError = null;
 		this.lastExecutionError = null;
 		this.runningNodeId = null;