Show batch progress

This commit is contained in:
space-nuko 2023-06-09 15:01:33 -05:00
parent e360f4b05b
commit a225674fc0
3 changed files with 104 additions and 32 deletions

View File

@ -133,33 +133,16 @@ def slice_lists_into_dict(d, i):
d_new[k] = v[i if len(v) > i else -1]
return d_new
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, callback=None):
# check if node wants the lists
intput_is_list = False
input_is_list = False
if hasattr(obj, "INPUT_IS_LIST"):
intput_is_list = obj.INPUT_IS_LIST
max_len_input = max([len(x) for x in input_data_all.values()])
def format_dict(d):
s = []
for k,v in d.items():
st = f"{k}: "
if isinstance(v, list):
st += f"list[len: {len(v)}]["
i = []
for v2 in v:
i.append(v2.__class__.__name__)
st += ",".join(i) + "]"
else:
st += str(type(v))
s.append(st)
return "( " + ", ".join(s) + " )"
input_is_list = obj.INPUT_IS_LIST
max_len_input = max(len(x) for x in input_data_all.values())
results = []
if intput_is_list:
if input_is_list:
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all))
@ -168,6 +151,8 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
if allow_interrupt:
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i)))
if callback is not None:
callback(i, max_len_input)
return results
def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
@ -175,8 +160,31 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
all_outputs_ui = []
total_batches = len(input_data_all_batches.batches)
total_inner_batches = 0
for batch in input_data_all_batches.batches:
total_inner_batches += max(len(x) for x in batch.values())
inner_totals = 0
def send_batch_progress(inner_num):
if server.client_id is not None:
message = {
"node": unique_id,
"prompt_id": prompt_id,
"batch_num": inner_totals + inner_num,
"total_batches": total_inner_batches
}
server.send_sync("batch_progress", message, server.client_id)
send_batch_progress(0)
for batch_num, batch in enumerate(input_data_all_batches.batches):
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True)
def cb(inner_num, inner_total):
send_batch_progress(inner_num + 1)
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb)
inner_totals += max(len(x) for x in batch.values())
uis = []
results = []

View File

@ -108,6 +108,9 @@ class ComfyApi extends EventTarget {
case "progress":
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
break;
case "batch_progress":
this.dispatchEvent(new CustomEvent("batch_progress", { detail: msg.data }));
break;
case "executing":
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data }));
break;

View File

@ -891,8 +891,22 @@ export class ComfyApp {
}
if (self.progress && node.id === +self.runningNodeId) {
let offset = 0
let height = 6;
if (self.batchProgress) {
offset = 4;
height = 4;
}
ctx.fillStyle = "green";
ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
ctx.fillRect(0, offset, size[0] * (self.progress.value / self.progress.max), height);
if (self.batchProgress) {
ctx.fillStyle = "#3ca2c3";
ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), height);
}
ctx.fillStyle = bgcolor;
}
@ -953,8 +967,22 @@ export class ComfyApp {
this.graph.setDirtyCanvas(true, false);
});
api.addEventListener("batch_progress", ({ detail }) => {
if (detail.total_batches <= 1) {
this.batchProgress = null;
}
else {
this.batchProgress = {
value: detail.batch_num,
max: detail.total_batches
}
}
this.graph.setDirtyCanvas(true, false);
});
api.addEventListener("executing", ({ detail }) => {
this.progress = null;
this.batchProgress = null;
this.runningNodeId = detail.node;
this.graph.setDirtyCanvas(true, false);
if (detail.node != null) {
@ -975,6 +1003,10 @@ export class ComfyApp {
if (node.onExecuted)
node.onExecuted(detail.output);
}
if (this.batchProgress != null) {
this.batchProgress.value = detail.batch_num
this.batchProgress.max = detail.total_batches
}
});
api.addEventListener("execution_start", ({ detail }) => {
@ -1005,7 +1037,7 @@ export class ComfyApp {
/*
* Based on inputs in the prompt marked as combinatorial,
* construct a grid from the results;
* construct a grid from the results
*/
#resolveGrid(outputNode, output, runningPrompt) {
let axes = []
@ -1016,8 +1048,8 @@ export class ComfyApp {
if (allImages.length === 0)
return null;
console.error("PROMPT", runningPrompt);
console.error("OUTPUT", output);
console.log("PROMPT", runningPrompt);
console.log("OUTPUT", output);
const isInputLink = (input) => {
return Array.isArray(input)
@ -1028,8 +1060,12 @@ export class ComfyApp {
// Axes closer to the output (executed later) are discovered first
const queue = [outputNode]
const seen = new Set();
while (queue.length > 0) {
const nodeID = queue.pop();
if (seen.has(nodeID))
continue;
seen.add(nodeID);
const promptInput = runningPrompt.output[nodeID];
const nodeClass = promptInput.class_type
@ -1069,7 +1105,7 @@ export class ComfyApp {
// number of combinatorial choices for that axis, and this happens
// recursively for each axis
console.error("AXES", axes)
console.log("AXES", axes)
// Grid position
const currentCoords = Array.from(Array(0))
@ -1093,7 +1129,7 @@ export class ComfyApp {
}
const grid = { axes, images };
console.error("GRID", grid);
console.log("GRID", grid);
return grid;
}
@ -1401,8 +1437,12 @@ export class ComfyApp {
async graphToPrompt() {
const workflow = this.graph.serialize();
const output = {};
// Process nodes in order of execution
let totalExecuted = 0;
let totalCombinatorialNodes = 0;
let executionFactor = 1;
// Process nodes in order of execution
const executionOrder = Array.from(this.graph.computeExecutionOrder(false));
const executionOrderIds = executionOrder.map(n => n.id);
@ -1430,7 +1470,13 @@ export class ComfyApp {
for (const i in widgets) {
const widget = widgets[i];
if (!widget.options || widget.options.serialize !== false) {
inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
const widgetValue = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
inputs[widget.name] = widgetValue;
if (typeof widgetValue === "object" && widgetValue.__inputType__) {
totalCombinatorialNodes += 1;
executionFactor *= widgetValue.values.length;
}
totalExecuted += executionFactor;
}
}
}
@ -1473,7 +1519,15 @@ export class ComfyApp {
}
}
return { workflow, output, executionOrder: executionOrderIds };
return {
prompt: {
workflow,
output,
},
executionOrder: executionOrderIds,
totalCombinatorialNodes,
totalExecuted
};
}
#formatPromptError(error) {
@ -1530,7 +1584,12 @@ export class ComfyApp {
({ number, batchCount } = this.#queueItems.pop());
for (let i = 0; i < batchCount; i++) {
const p = await this.graphToPrompt();
const result = await this.graphToPrompt();
const warnExecutedAmount = 256;
if (result.totalExecuted > warnExecutedAmount && !confirm("You are about to execute " + result.totalExecuted + " nodes total across " + result.totalCombinatorialNodes + " combinatorial axes. Are you sure you want to do this?")) {
continue
}
const p = result.prompt;
try {
this.runningPrompt = p;
@ -1647,6 +1706,8 @@ export class ComfyApp {
this.nodeOutputs = {};
this.nodeGrids = {};
this.nodePreviewImages = {}
this.progress = null;
this.batchProgress = null;
this.lastPromptError = null;
this.lastExecutionError = null;
this.runningNodeId = null;