Show batch progress

This commit is contained in:
space-nuko 2023-06-09 15:01:33 -05:00
parent e360f4b05b
commit a225674fc0
3 changed files with 104 additions and 32 deletions

View File

@ -133,33 +133,16 @@ def slice_lists_into_dict(d, i):
d_new[k] = v[i if len(v) > i else -1] d_new[k] = v[i if len(v) > i else -1]
return d_new return d_new
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, callback=None):
# check if node wants the lists # check if node wants the lists
intput_is_list = False input_is_list = False
if hasattr(obj, "INPUT_IS_LIST"): if hasattr(obj, "INPUT_IS_LIST"):
intput_is_list = obj.INPUT_IS_LIST input_is_list = obj.INPUT_IS_LIST
max_len_input = max([len(x) for x in input_data_all.values()])
def format_dict(d):
s = []
for k,v in d.items():
st = f"{k}: "
if isinstance(v, list):
st += f"list[len: {len(v)}]["
i = []
for v2 in v:
i.append(v2.__class__.__name__)
st += ",".join(i) + "]"
else:
st += str(type(v))
s.append(st)
return "( " + ", ".join(s) + " )"
max_len_input = max(len(x) for x in input_data_all.values()) max_len_input = max(len(x) for x in input_data_all.values())
results = [] results = []
if intput_is_list: if input_is_list:
if allow_interrupt: if allow_interrupt:
nodes.before_node_execution() nodes.before_node_execution()
results.append(getattr(obj, func)(**input_data_all)) results.append(getattr(obj, func)(**input_data_all))
@ -168,6 +151,8 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
if allow_interrupt: if allow_interrupt:
nodes.before_node_execution() nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i))) results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i)))
if callback is not None:
callback(i, max_len_input)
return results return results
def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
@ -175,8 +160,31 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
all_outputs_ui = [] all_outputs_ui = []
total_batches = len(input_data_all_batches.batches) total_batches = len(input_data_all_batches.batches)
total_inner_batches = 0
for batch in input_data_all_batches.batches:
total_inner_batches += max(len(x) for x in batch.values())
inner_totals = 0
def send_batch_progress(inner_num):
if server.client_id is not None:
message = {
"node": unique_id,
"prompt_id": prompt_id,
"batch_num": inner_totals + inner_num,
"total_batches": total_inner_batches
}
server.send_sync("batch_progress", message, server.client_id)
send_batch_progress(0)
for batch_num, batch in enumerate(input_data_all_batches.batches): for batch_num, batch in enumerate(input_data_all_batches.batches):
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True) def cb(inner_num, inner_total):
send_batch_progress(inner_num + 1)
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb)
inner_totals += max(len(x) for x in batch.values())
uis = [] uis = []
results = [] results = []

View File

@ -108,6 +108,9 @@ class ComfyApi extends EventTarget {
case "progress": case "progress":
this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
break; break;
case "batch_progress":
this.dispatchEvent(new CustomEvent("batch_progress", { detail: msg.data }));
break;
case "executing": case "executing":
this.dispatchEvent(new CustomEvent("executing", { detail: msg.data })); this.dispatchEvent(new CustomEvent("executing", { detail: msg.data }));
break; break;

View File

@ -891,8 +891,22 @@ export class ComfyApp {
} }
if (self.progress && node.id === +self.runningNodeId) { if (self.progress && node.id === +self.runningNodeId) {
let offset = 0
let height = 6;
if (self.batchProgress) {
offset = 4;
height = 4;
}
ctx.fillStyle = "green"; ctx.fillStyle = "green";
ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6); ctx.fillRect(0, offset, size[0] * (self.progress.value / self.progress.max), height);
if (self.batchProgress) {
ctx.fillStyle = "#3ca2c3";
ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), height);
}
ctx.fillStyle = bgcolor; ctx.fillStyle = bgcolor;
} }
@ -953,8 +967,22 @@ export class ComfyApp {
this.graph.setDirtyCanvas(true, false); this.graph.setDirtyCanvas(true, false);
}); });
api.addEventListener("batch_progress", ({ detail }) => {
if (detail.total_batches <= 1) {
this.batchProgress = null;
}
else {
this.batchProgress = {
value: detail.batch_num,
max: detail.total_batches
}
}
this.graph.setDirtyCanvas(true, false);
});
api.addEventListener("executing", ({ detail }) => { api.addEventListener("executing", ({ detail }) => {
this.progress = null; this.progress = null;
this.batchProgress = null;
this.runningNodeId = detail.node; this.runningNodeId = detail.node;
this.graph.setDirtyCanvas(true, false); this.graph.setDirtyCanvas(true, false);
if (detail.node != null) { if (detail.node != null) {
@ -975,6 +1003,10 @@ export class ComfyApp {
if (node.onExecuted) if (node.onExecuted)
node.onExecuted(detail.output); node.onExecuted(detail.output);
} }
if (this.batchProgress != null) {
this.batchProgress.value = detail.batch_num
this.batchProgress.max = detail.total_batches
}
}); });
api.addEventListener("execution_start", ({ detail }) => { api.addEventListener("execution_start", ({ detail }) => {
@ -1005,7 +1037,7 @@ export class ComfyApp {
/* /*
* Based on inputs in the prompt marked as combinatorial, * Based on inputs in the prompt marked as combinatorial,
* construct a grid from the results; * construct a grid from the results
*/ */
#resolveGrid(outputNode, output, runningPrompt) { #resolveGrid(outputNode, output, runningPrompt) {
let axes = [] let axes = []
@ -1016,8 +1048,8 @@ export class ComfyApp {
if (allImages.length === 0) if (allImages.length === 0)
return null; return null;
console.error("PROMPT", runningPrompt); console.log("PROMPT", runningPrompt);
console.error("OUTPUT", output); console.log("OUTPUT", output);
const isInputLink = (input) => { const isInputLink = (input) => {
return Array.isArray(input) return Array.isArray(input)
@ -1028,8 +1060,12 @@ export class ComfyApp {
// Axes closer to the output (executed later) are discovered first // Axes closer to the output (executed later) are discovered first
const queue = [outputNode] const queue = [outputNode]
const seen = new Set();
while (queue.length > 0) { while (queue.length > 0) {
const nodeID = queue.pop(); const nodeID = queue.pop();
if (seen.has(nodeID))
continue;
seen.add(nodeID);
const promptInput = runningPrompt.output[nodeID]; const promptInput = runningPrompt.output[nodeID];
const nodeClass = promptInput.class_type const nodeClass = promptInput.class_type
@ -1069,7 +1105,7 @@ export class ComfyApp {
// number of combinatorial choices for that axis, and this happens // number of combinatorial choices for that axis, and this happens
// recursively for each axis // recursively for each axis
console.error("AXES", axes) console.log("AXES", axes)
// Grid position // Grid position
const currentCoords = Array.from(Array(0)) const currentCoords = Array.from(Array(0))
@ -1093,7 +1129,7 @@ export class ComfyApp {
} }
const grid = { axes, images }; const grid = { axes, images };
console.error("GRID", grid); console.log("GRID", grid);
return grid; return grid;
} }
@ -1401,8 +1437,12 @@ export class ComfyApp {
async graphToPrompt() { async graphToPrompt() {
const workflow = this.graph.serialize(); const workflow = this.graph.serialize();
const output = {}; const output = {};
// Process nodes in order of execution
let totalExecuted = 0;
let totalCombinatorialNodes = 0;
let executionFactor = 1;
// Process nodes in order of execution
const executionOrder = Array.from(this.graph.computeExecutionOrder(false)); const executionOrder = Array.from(this.graph.computeExecutionOrder(false));
const executionOrderIds = executionOrder.map(n => n.id); const executionOrderIds = executionOrder.map(n => n.id);
@ -1430,7 +1470,13 @@ export class ComfyApp {
for (const i in widgets) { for (const i in widgets) {
const widget = widgets[i]; const widget = widgets[i];
if (!widget.options || widget.options.serialize !== false) { if (!widget.options || widget.options.serialize !== false) {
inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value; const widgetValue = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
inputs[widget.name] = widgetValue;
if (typeof widgetValue === "object" && widgetValue.__inputType__) {
totalCombinatorialNodes += 1;
executionFactor *= widgetValue.values.length;
}
totalExecuted += executionFactor;
} }
} }
} }
@ -1473,7 +1519,15 @@ export class ComfyApp {
} }
} }
return { workflow, output, executionOrder: executionOrderIds }; return {
prompt: {
workflow,
output,
},
executionOrder: executionOrderIds,
totalCombinatorialNodes,
totalExecuted
};
} }
#formatPromptError(error) { #formatPromptError(error) {
@ -1530,7 +1584,12 @@ export class ComfyApp {
({ number, batchCount } = this.#queueItems.pop()); ({ number, batchCount } = this.#queueItems.pop());
for (let i = 0; i < batchCount; i++) { for (let i = 0; i < batchCount; i++) {
const p = await this.graphToPrompt(); const result = await this.graphToPrompt();
const warnExecutedAmount = 256;
if (result.totalExecuted > warnExecutedAmount && !confirm("You are about to execute " + result.totalExecuted + " nodes total across " + result.totalCombinatorialNodes + " combinatorial axes. Are you sure you want to do this?")) {
continue
}
const p = result.prompt;
try { try {
this.runningPrompt = p; this.runningPrompt = p;
@ -1647,6 +1706,8 @@ export class ComfyApp {
this.nodeOutputs = {}; this.nodeOutputs = {};
this.nodeGrids = {}; this.nodeGrids = {};
this.nodePreviewImages = {} this.nodePreviewImages = {}
this.progress = null;
this.batchProgress = null;
this.lastPromptError = null; this.lastPromptError = null;
this.lastExecutionError = null; this.lastExecutionError = null;
this.runningNodeId = null; this.runningNodeId = null;