From b6c5f6ae9c1f5659ca8f78fa3bb352f82ccb1c70 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 9 Jun 2023 16:32:14 -0500 Subject: [PATCH] Fix combinatorial outputs passing --- execution.py | 73 +++++++++++++++++++++++------ web/extensions/core/showGrid.js | 66 +++++++++++++++++++++----- web/extensions/core/widgetInputs.js | 18 ++++++- web/scripts/app.js | 51 +++++++++++--------- 4 files changed, 157 insertions(+), 51 deletions(-) diff --git a/execution.py b/execution.py index be8e903e7..b55adb87b 100644 --- a/execution.py +++ b/execution.py @@ -27,6 +27,13 @@ class CombinatorialBatches: combinations: List +def find(d, pred): + for i, x in d.items(): + if pred(x): + return i, x + return None, None + + def get_input_data_batches(input_data_all): """Given input data that can contain combinatorial input values, returns all the possible batches that can be made by combining the different input @@ -34,6 +41,7 @@ def get_input_data_batches(input_data_all): input_to_index = {} index_to_values = [] + index_to_axis = {} index_to_coords = [] # Sort by input name first so the order which batch inputs are applied can @@ -45,10 +53,20 @@ def get_input_data_batches(input_data_all): for input_name in sorted_input_names: value = input_data_all[input_name] if isinstance(value, dict) and "combinatorial" in value: - input_to_index[input_name] = i - index_to_values.append(value["values"]) - index_to_coords.append(list(range(len(value["values"])))) - i += 1 + if "axis_id" in value: + found_i = next((k for k, v in index_to_axis.items() if v == value["axis_id"]), None) + else: + found_i = None + + if found_i is not None: + input_to_index[input_name] = found_i + else: + input_to_index[input_name] = i + index_to_values.append(value["values"]) + index_to_coords.append(list(range(len(value["values"])))) + if "axis_id" in value: + index_to_axis[i] = value["axis_id"] + i += 1 if len(index_to_values) == 0: # No combinatorial options. @@ -99,11 +117,15 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da # Make the outputs into a list for map-over-list use # (they are themselves lists so flatten them afterwards) input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches] - input_values = flatten(input_values) + input_values = { "combinatorial": True, "values": flatten(input_values) } input_data_all[x] = input_values elif is_combinatorial_input(input_data): if required_or_optional: - input_data_all[x] = { "combinatorial": True, "values": input_data["values"] } + input_data_all[x] = { + "combinatorial": True, + "values": input_data["values"], + "axis_id": input_data.get("axis_id") + } else: if required_or_optional: input_data_all[x] = [input_data] @@ -121,6 +143,28 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all_batches = get_input_data_batches(input_data_all) + def format_dict(d): + s = [] + for k,v in d.items(): + st = f"{k}: " + if isinstance(v, list): + st += f"list[len: {len(v)}][" + i = [] + for v2 in v: + i.append(v2.__class__.__name__) + st += ",".join(i) + "]" + else: + st += str(type(v)) + s.append(st) + return "( " + ", ".join(s) + " )" + + + print("---------------------------------") + from pprint import pp + for batch in input_data_all_batches.batches: + print(format_dict(batch)); + print("---------------------------------") + return input_data_all_batches def slice_lists_into_dict(d, i): @@ -152,7 +196,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, callbac nodes.before_node_execution() results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i))) if callback is not None: - callback(i, max_len_input) + callback(i + 1, max_len_input) return results def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): @@ -180,7 +224,7 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): for batch_num, batch in enumerate(input_data_all_batches.batches): def cb(inner_num, inner_total): - send_batch_progress(inner_num + 1) + send_batch_progress(inner_num) return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb) @@ -229,12 +273,11 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): "node": unique_id, "output": outputs_ui_to_send, "prompt_id": prompt_id, - "batch_num": batch_num, - "total_batches": total_batches + "batch_num": inner_totals, + "total_batches": total_inner_batches } if input_data_all_batches.indices: message["indices"] = input_data_all_batches.indices[batch_num] - message["combination"] = input_data_all_batches.combinations[batch_num] server.send_sync("executed", message, server.client_id) return all_outputs, all_outputs_ui @@ -276,7 +319,6 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute if input_data_all_batches.indices: combinations = { "input_to_index": input_data_all_batches.input_to_index, - "index_to_values": input_data_all_batches.index_to_values, "indices": input_data_all_batches.indices } mes = { @@ -306,6 +348,10 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute except Exception as ex: typ, _, tb = sys.exc_info() exception_type = full_type_name(typ) + + print("!!! Exception during processing !!!") + print(traceback.format_exc()) + input_data_formatted = [] if input_data_all_batches is not None: d = {} @@ -321,9 +367,6 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute d[node_id] = [[format_value(x) for x in l] for l in batch_outputs] output_data_formatted.append(d) - print("!!! Exception during processing !!!") - print(traceback.format_exc()) - error_details = { "node_id": unique_id, "exception_message": str(ex), diff --git a/web/extensions/core/showGrid.js b/web/extensions/core/showGrid.js index 908e74056..ea178e4db 100644 --- a/web/extensions/core/showGrid.js +++ b/web/extensions/core/showGrid.js @@ -2,6 +2,16 @@ import { app } from "/scripts/app.js"; // Show grids from combinatorial outputs +async function loadImageAsync(imageURL) { + return new Promise((resolve) => { + const e = new Image(); + e.setAttribute('crossorigin', 'anonymous'); + e.addEventListener("load", () => { resolve(e); }); + e.src = imageURL; + return e; + }); +} + app.registerExtension({ name: "Comfy.ShowGrid", async beforeRegisterNodeDef(nodeType, nodeData, app) { @@ -44,9 +54,15 @@ app.registerExtension({ const axisSelectors = rootElem.querySelector(".axis-selectors"); const imageTable = rootElem.querySelector(".image-table"); + this.imageSize = 512; + this.imageWidth = this.imageSize + this.imageHeight = this.imageSize + this.naturalWidth = this.imageSize + this.naturalHeight = this.imageSize + const footerHtml = ` - + ` const footerElem = this._gridPanel.addHTML(footerHtml, "grid-footer", true); @@ -99,19 +115,38 @@ app.registerExtension({ }); } - const refreshGrid = (xAxis, yAxis) => { + const refreshGrid = async (xAxis, yAxis) => { this.xAxis = xAxis; this.yAxis = yAxis; this.xAxisData = getAxisData(this.xAxis); this.yAxisData = getAxisData(this.yAxis); - selectAxis(false, this.xAxisData.id) - selectAxis(true, this.yAxisData.id) + selectAxis(false, this.xAxisData.selectorID) + selectAxis(true, this.yAxisData.selectorID) if (xAxis === yAxis) { this.yAxisData = getAxisData(-1); } + this.imageWidth = this.imageSize + this.imageHeight = this.imageSize + this.naturalWidth = this.imageSize + this.naturalHeight = this.imageSize + + const firstImages = getImagesAt(0, 0); + if (firstImages.length > 0) { + const src = "/view?" + new URLSearchParams(firstImages[0].image).toString() + app.getPreviewFormatParam(); + const imgElem = await loadImageAsync(src); + this.naturalWidth = imgElem.naturalWidth + this.naturalHeight = imgElem.naturalHeight + + const ratio = Math.min(this.imageSize / this.naturalWidth, this.imageSize / this.naturalHeight); + const newWidth = this.naturalWidth * ratio; + const newHeight = this.naturalHeight * ratio; + this.imageWidth = newWidth; + this.imageHeight = newHeight; + } + imageTable.innerHTML = ""; const thead = document.createElement("thead") @@ -162,8 +197,8 @@ app.registerExtension({ const td = document.createElement("td"); const img = document.createElement("img"); - img.style.width = `${this.imageSize}px` - img.style.height = `${this.imageSize}px` + img.style.width = `${this.imageWidth}px` + img.style.height = `${this.imageHeight}px` const gridImages = getImagesAt(x, y); if (gridImages.length > 0) { img.src = "/view?" + new URLSearchParams(gridImages[0].image).toString() + app.getPreviewFormatParam(); @@ -187,7 +222,7 @@ app.registerExtension({ group.innerHTML = `${axisName.toUpperCase()} Axis:  `; const addAxis = (index, axis) => { - const axisID = `${axisName}-${axis.id}`; + const axisID = `${axisName}-${axis.selectorID}`; const input = document.createElement("input") input.setAttribute("type", "radio") @@ -211,7 +246,7 @@ app.registerExtension({ label.innerHTML = String(axis.label); label.addEventListener("click", () => { console.warn("SETAXIS", axis); - selectAxis(isY, axis.id, true); + selectAxis(isY, axis.selectorID, true); }) group.appendChild(input) @@ -228,17 +263,22 @@ app.registerExtension({ axisSelectors.appendChild(group); } - this.imageSize = 256; - imageSizeInput.addEventListener("input", () => { this.imageSize = parseInt(imageSizeInput.value); + + const ratio = Math.min(this.imageSize / natWidth, this.imageSize / natHeight); + const newWidth = this.naturalWidth * ratio; + const newHeight = this.naturalHeight * ratio; + this.imageWidth = newWidth; + this.imageHeight = newHeight; + for (const img of imageTable.querySelectorAll("img")) { - img.style.width = `${this.imageSize}px` - img.style.height = `${this.imageSize}px` + img.style.width = `${this.imageWidth}px` + img.style.height = `${this.imageHeight}px` } }) - refreshGrid(1, 2); + refreshGrid(0, Math.min(1, grid.axes.length)); document.body.appendChild(this._gridPanel); }) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index c89f86b50..cedf7fdd9 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -197,6 +197,7 @@ app.registerExtension({ this.isVirtualNode = true; this.properties ||= {} this.properties.valuesType = "single"; + this.properties.axisName = ""; this.properties.listValue = ""; this.properties.rangeStepBy = 64; this.properties.rangeSteps = 2; @@ -228,6 +229,12 @@ app.registerExtension({ } let values; + let axisID = null; + let axisName = null; + if (this.properties.axisName != "") { + axisID = this.id; + axisName = this.properties.axisName + } switch (this.properties.valuesType) { case "list": @@ -239,13 +246,13 @@ app.registerExtension({ else if (inputType === "FLOAT") { values = values.map(v => parseFloat(v)) } - widget.value = { __inputType__: "combinatorial", values: values } + widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName } break; case "range": const isNumberWidget = widget.type === "number" || widget.origType === "number"; if (isNumberWidget) { values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps); - widget.value = { __inputType__: "combinatorial", values: values } + widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName } break; } case "single": @@ -259,6 +266,10 @@ app.registerExtension({ onPropertyChanged(property, value) { if (property === "valuesType") { + const isSingle = value === "single" + if (this.axisNameWidget) + this.axisNameWidget.disabled = isSingle + const isList = value === "list" if (this.mainWidget) this.mainWidget.disabled = isList @@ -366,6 +377,9 @@ app.registerExtension({ this.valuesTypeWidget = this.addWidget("combo", "Values type", this.properties.valuesType, "valuesType", { values: valuesTypeChoices }); + this.axisNameWidget = this.addWidget("text", "Axis Name", this.properties.axisName, "axisName"); + this.axisNameWidget.disabled = this.properties.valuesType === "single"; + this.listWidget = this.addWidget("text", "Choices", this.properties.listValue, "listValue"); this.listWidget.disabled = this.properties.valuesType !== "list"; diff --git a/web/scripts/app.js b/web/scripts/app.js index ccb844259..555fe70c5 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -890,23 +890,28 @@ export class ComfyApp { ctx.globalAlpha = 1; } - if (self.progress && node.id === +self.runningNodeId) { - let offset = 0 - let height = 6; + if (node.id === +self.runningNodeId) { + if (self.progress) { + let offset = 0 + let height = 6; - if (self.batchProgress) { - offset = 4; - height = 4; + if (self.batchProgress) { + offset = 4; + height = 4; + } + + ctx.fillStyle = "green"; + ctx.fillRect(0, offset, size[0] * (self.progress.value / self.progress.max), height); + + if (self.batchProgress) { + ctx.fillStyle = "#3ca2c3"; + ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), height); + } } - - ctx.fillStyle = "green"; - ctx.fillRect(0, offset, size[0] * (self.progress.value / self.progress.max), height); - - if (self.batchProgress) { + else if (self.batchProgress) { ctx.fillStyle = "#3ca2c3"; - ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), height); + ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), 6); } - ctx.fillStyle = bgcolor; } @@ -1082,14 +1087,18 @@ export class ComfyApp { for (const inputName of sortedKeys) { const input = promptInput.inputs[inputName]; if (typeof input === "object" && "__inputType__" in input) { - axes.push({ - nodeID, - nodeClass, - id: `${nodeID}-${inputName}`.replace(" ", "-"), - label: `${nodeClass}: ${inputName}`, - inputName, - values: input.values - }) + if (input.axis_id == null || !axes.some(a => a.id != null && a.id === input.axis_id)) { + let label = input.axis_name || `${nodeClass}: ${inputName}`; + axes.push({ + id: input.axis_id, + nodeID, + nodeClass, + selectorID: `${nodeID}-${inputName}`.replace(" ", "-"), + label, + inputName, + values: input.values + }) + } } else if (isInputLink(input)) { const inputNodeID = input[0]