diff --git a/execution.py b/execution.py
index 6893188e5..be8e903e7 100644
--- a/execution.py
+++ b/execution.py
@@ -133,33 +133,16 @@ def slice_lists_into_dict(d, i):
         d_new[k] = v[i if len(v) > i else -1]
     return d_new

-def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
+def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, callback=None):
     # check if node wants the lists
-    intput_is_list = False
+    input_is_list = False
     if hasattr(obj, "INPUT_IS_LIST"):
-        intput_is_list = obj.INPUT_IS_LIST
-
-    max_len_input = max([len(x) for x in input_data_all.values()])
-
-    def format_dict(d):
-        s = []
-        for k,v in d.items():
-            st = f"{k}: "
-            if isinstance(v, list):
-                st += f"list[len: {len(v)}]["
-                i = []
-                for v2 in v:
-                    i.append(v2.__class__.__name__)
-                st += ",".join(i) + "]"
-            else:
-                st += str(type(v))
-            s.append(st)
-        return "( " + ", ".join(s) + " )"
+        input_is_list = obj.INPUT_IS_LIST

     max_len_input = max(len(x) for x in input_data_all.values())

     results = []
-    if intput_is_list:
+    if input_is_list:
         if allow_interrupt:
             nodes.before_node_execution()
         results.append(getattr(obj, func)(**input_data_all))
@@ -168,6 +151,8 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
             if allow_interrupt:
                 nodes.before_node_execution()
             results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i)))
+            if callback is not None:
+                callback(i, max_len_input)
     return results

 def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
@@ -175,8 +160,31 @@
     all_outputs_ui = []

     total_batches = len(input_data_all_batches.batches)
+    total_inner_batches = 0
+    for batch in input_data_all_batches.batches:
+        total_inner_batches += max(len(x) for x in batch.values())
+
+    inner_totals = 0
+
+    def send_batch_progress(inner_num):
+        if server.client_id is not None:
+            message = {
+                "node": unique_id,
+                "prompt_id": prompt_id,
+                "batch_num": inner_totals + inner_num,
+                "total_batches": total_inner_batches
+            }
+            server.send_sync("batch_progress", message, server.client_id)
+
+    send_batch_progress(0)
+
     for batch_num, batch in enumerate(input_data_all_batches.batches):
-        return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True)
+        def cb(inner_num, inner_total):
+            send_batch_progress(inner_num + 1)
+
+        return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb)
+
+        inner_totals += max(len(x) for x in batch.values())

         uis = []
         results = []
diff --git a/web/scripts/api.js b/web/scripts/api.js
index 7897df5bb..04f1cf107 100644
--- a/web/scripts/api.js
+++ b/web/scripts/api.js
@@ -108,6 +108,9 @@
					case "progress":
						this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
						break;
+					case "batch_progress":
+						this.dispatchEvent(new CustomEvent("batch_progress", { detail: msg.data }));
+						break;
					case "executing":
						this.dispatchEvent(new CustomEvent("executing", { detail: msg.data }));
						break;
diff --git a/web/scripts/app.js b/web/scripts/app.js
index 7e9264dbd..ccb844259 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -891,8 +891,22 @@
			}

			if (self.progress && node.id === +self.runningNodeId) {
+				let offset = 0;
+				let height = 6;
+
+				if (self.batchProgress) {
+					offset = 4;
+					height = 4;
+				}
+
				ctx.fillStyle = "green";
-				ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6);
+				ctx.fillRect(0, offset, size[0] * (self.progress.value / self.progress.max), height);
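+
+				// Overall batch progress (if any) is drawn as a second bar along
+				// the top edge, directly above the per-node progress bar; the
+				// offset/height adjustment above makes room for it.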
+				if (self.batchProgress) {
+					ctx.fillStyle = "#3ca2c3";
+					ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), height);
+				}
+
				ctx.fillStyle = bgcolor;
			}

@@ -953,8 +967,22 @@
			this.graph.setDirtyCanvas(true, false);
		});

+		api.addEventListener("batch_progress", ({ detail }) => {
+			if (detail.total_batches <= 1) {
+				this.batchProgress = null;
+			}
+			else {
+				this.batchProgress = {
+					value: detail.batch_num,
+					max: detail.total_batches
+				};
+			}
+			this.graph.setDirtyCanvas(true, false);
+		});
+
		api.addEventListener("executing", ({ detail }) => {
			this.progress = null;
+			this.batchProgress = null;
			this.runningNodeId = detail.node;
			this.graph.setDirtyCanvas(true, false);
			if (detail.node != null) {
@@ -975,6 +1003,10 @@
				if (node.onExecuted)
					node.onExecuted(detail.output);
			}
+			if (this.batchProgress != null) {
+				this.batchProgress.value = detail.batch_num;
+				this.batchProgress.max = detail.total_batches;
+			}
		});

		api.addEventListener("execution_start", ({ detail }) => {
@@ -1005,7 +1037,7 @@

	/*
	 * Based on inputs in the prompt marked as combinatorial,
-	 * construct a grid from the results;
+	 * construct a grid from the results
	 */
	#resolveGrid(outputNode, output, runningPrompt) {
		let axes = []
@@ -1016,8 +1048,8 @@
		if (allImages.length === 0)
			return null;

-		console.error("PROMPT", runningPrompt);
-		console.error("OUTPUT", output);
+		console.log("PROMPT", runningPrompt);
+		console.log("OUTPUT", output);

		const isInputLink = (input) => {
			return Array.isArray(input)
@@ -1028,8 +1060,12 @@

		// Axes closer to the output (executed later) are discovered first
		const queue = [outputNode]
+		const seen = new Set();
		while (queue.length > 0) {
			const nodeID = queue.pop();
+			if (seen.has(nodeID))
+				continue;
+			seen.add(nodeID);
			const promptInput = runningPrompt.output[nodeID];
			const nodeClass = promptInput.class_type

@@ -1069,7 +1105,7 @@
		// number of combinatorial choices for that axis, and this happens
		// recursively for each axis

-		console.error("AXES", axes)
+		console.log("AXES", axes)

		// Grid position
		const currentCoords = Array.from(Array(0))
@@ -1093,7 +1129,7 @@
		}

		const grid = { axes, images };
-		console.error("GRID", grid);
+		console.log("GRID", grid);

		return grid;
	}
@@ -1401,8 +1437,12 @@
	async graphToPrompt() {
		const workflow = this.graph.serialize();
		const output = {};
-		// Process nodes in order of execution
+		let totalExecuted = 0;
+		let totalCombinatorialNodes = 0;
+		let executionFactor = 1;
+
+		// Process nodes in order of execution
		const executionOrder = Array.from(this.graph.computeExecutionOrder(false));
		const executionOrderIds = executionOrder.map(n => n.id);

@@ -1430,7 +1470,13 @@
				for (const i in widgets) {
					const widget = widgets[i];
					if (!widget.options || widget.options.serialize !== false) {
-						inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
+						const widgetValue = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value;
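+						// Combinatorial widget values serialize to objects tagged with
+						// __inputType__; each one multiplies the number of executions.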
+						inputs[widget.name] = widgetValue;
+						if (widgetValue && typeof widgetValue === "object" && widgetValue.__inputType__) {
+							totalCombinatorialNodes += 1;
+							executionFactor *= widgetValue.values.length;
+						}
+						totalExecuted += executionFactor;
					}
				}
			}
@@ -1473,7 +1519,15 @@
			}
		}

-		return { workflow, output, executionOrder: executionOrderIds };
+		return {
+			prompt: {
+				workflow,
+				output,
+			},
+			executionOrder: executionOrderIds,
+			totalCombinatorialNodes,
+			totalExecuted
+		};
	}

	#formatPromptError(error) {
@@ -1530,7 +1584,12 @@
				({ number, batchCount } = this.#queueItems.pop());

				for (let i = 0; i < batchCount; i++) {
-					const p = await this.graphToPrompt();
+					const result = await this.graphToPrompt();
+					const warnExecutedAmount = 256;
+					if (result.totalExecuted > warnExecutedAmount && !confirm("You are about to execute " + result.totalExecuted + " nodes total across " + result.totalCombinatorialNodes + " combinatorial axes. Are you sure you want to do this?")) {
+						continue;
+					}
+					const p = result.prompt;

					try {
						this.runningPrompt = p;
@@ -1647,6 +1706,8 @@
		this.nodeOutputs = {};
		this.nodeGrids = {};
		this.nodePreviewImages = {}
+		this.progress = null;
+		this.batchProgress = null;
		this.lastPromptError = null;
		this.lastExecutionError = null;
		this.runningNodeId = null;
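Note on consuming the new event: "batch_progress" is dispatched through the same CustomEvent path as the existing "progress" message, with a payload of node, prompt_id, batch_num, and total_batches, so a frontend extension can subscribe to it the same way app.js does. A minimal sketch, assuming an extension file living under web/extensions/ (the import path and the logging are illustrative, not part of this diff):

import { api } from "../../scripts/api.js";

api.addEventListener("batch_progress", ({ detail }) => {
	// batch_num counts inner batches completed so far across all outer
	// batches of the running node; total_batches is the precomputed total,
	// so detail.batch_num / detail.total_batches gives the overall fraction.
	console.log(`node ${detail.node}: batch ${detail.batch_num}/${detail.total_batches}`);
});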