From b6c5f6ae9c1f5659ca8f78fa3bb352f82ccb1c70 Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Fri, 9 Jun 2023 16:32:14 -0500
Subject: [PATCH] Fix combinatorial outputs passing
---
execution.py | 73 +++++++++++++++++++++++------
web/extensions/core/showGrid.js | 66 +++++++++++++++++++++-----
web/extensions/core/widgetInputs.js | 18 ++++++-
web/scripts/app.js | 51 +++++++++++---------
4 files changed, 157 insertions(+), 51 deletions(-)
diff --git a/execution.py b/execution.py
index be8e903e7..b55adb87b 100644
--- a/execution.py
+++ b/execution.py
@@ -27,6 +27,13 @@ class CombinatorialBatches:
combinations: List
+def find(d, pred):
+ for i, x in d.items():
+ if pred(x):
+ return i, x
+ return None, None
+
+
def get_input_data_batches(input_data_all):
"""Given input data that can contain combinatorial input values, returns all
the possible batches that can be made by combining the different input
@@ -34,6 +41,7 @@ def get_input_data_batches(input_data_all):
input_to_index = {}
index_to_values = []
+ index_to_axis = {}
index_to_coords = []
# Sort by input name first so the order which batch inputs are applied can
@@ -45,10 +53,20 @@ def get_input_data_batches(input_data_all):
for input_name in sorted_input_names:
value = input_data_all[input_name]
if isinstance(value, dict) and "combinatorial" in value:
- input_to_index[input_name] = i
- index_to_values.append(value["values"])
- index_to_coords.append(list(range(len(value["values"]))))
- i += 1
+ if "axis_id" in value:
+ found_i = next((k for k, v in index_to_axis.items() if v == value["axis_id"]), None)
+ else:
+ found_i = None
+
+ if found_i is not None:
+ input_to_index[input_name] = found_i
+ else:
+ input_to_index[input_name] = i
+ index_to_values.append(value["values"])
+ index_to_coords.append(list(range(len(value["values"]))))
+ if "axis_id" in value:
+ index_to_axis[i] = value["axis_id"]
+ i += 1
if len(index_to_values) == 0:
# No combinatorial options.
@@ -99,11 +117,15 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
# Make the outputs into a list for map-over-list use
# (they are themselves lists so flatten them afterwards)
input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches]
- input_values = flatten(input_values)
+ input_values = { "combinatorial": True, "values": flatten(input_values) }
input_data_all[x] = input_values
elif is_combinatorial_input(input_data):
if required_or_optional:
- input_data_all[x] = { "combinatorial": True, "values": input_data["values"] }
+ input_data_all[x] = {
+ "combinatorial": True,
+ "values": input_data["values"],
+ "axis_id": input_data.get("axis_id")
+ }
else:
if required_or_optional:
input_data_all[x] = [input_data]
@@ -121,6 +143,28 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all_batches = get_input_data_batches(input_data_all)
+ def format_dict(d):
+ s = []
+ for k,v in d.items():
+ st = f"{k}: "
+ if isinstance(v, list):
+ st += f"list[len: {len(v)}]["
+ i = []
+ for v2 in v:
+ i.append(v2.__class__.__name__)
+ st += ",".join(i) + "]"
+ else:
+ st += str(type(v))
+ s.append(st)
+ return "( " + ", ".join(s) + " )"
+
+
+ print("---------------------------------")
+ from pprint import pp
+ for batch in input_data_all_batches.batches:
+ print(format_dict(batch));
+ print("---------------------------------")
+
return input_data_all_batches
def slice_lists_into_dict(d, i):
@@ -152,7 +196,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, callbac
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i)))
if callback is not None:
- callback(i, max_len_input)
+ callback(i + 1, max_len_input)
return results
def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
@@ -180,7 +224,7 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
for batch_num, batch in enumerate(input_data_all_batches.batches):
def cb(inner_num, inner_total):
- send_batch_progress(inner_num + 1)
+ send_batch_progress(inner_num)
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb)
@@ -229,12 +273,11 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
"node": unique_id,
"output": outputs_ui_to_send,
"prompt_id": prompt_id,
- "batch_num": batch_num,
- "total_batches": total_batches
+ "batch_num": inner_totals,
+ "total_batches": total_inner_batches
}
if input_data_all_batches.indices:
message["indices"] = input_data_all_batches.indices[batch_num]
- message["combination"] = input_data_all_batches.combinations[batch_num]
server.send_sync("executed", message, server.client_id)
return all_outputs, all_outputs_ui
@@ -276,7 +319,6 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
if input_data_all_batches.indices:
combinations = {
"input_to_index": input_data_all_batches.input_to_index,
- "index_to_values": input_data_all_batches.index_to_values,
"indices": input_data_all_batches.indices
}
mes = {
@@ -306,6 +348,10 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
except Exception as ex:
typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ)
+
+ print("!!! Exception during processing !!!")
+ print(traceback.format_exc())
+
input_data_formatted = []
if input_data_all_batches is not None:
d = {}
@@ -321,9 +367,6 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
d[node_id] = [[format_value(x) for x in l] for l in batch_outputs]
output_data_formatted.append(d)
- print("!!! Exception during processing !!!")
- print(traceback.format_exc())
-
error_details = {
"node_id": unique_id,
"exception_message": str(ex),
diff --git a/web/extensions/core/showGrid.js b/web/extensions/core/showGrid.js
index 908e74056..ea178e4db 100644
--- a/web/extensions/core/showGrid.js
+++ b/web/extensions/core/showGrid.js
@@ -2,6 +2,16 @@ import { app } from "/scripts/app.js";
// Show grids from combinatorial outputs
+async function loadImageAsync(imageURL) {
+ return new Promise((resolve) => {
+ const e = new Image();
+ e.setAttribute('crossorigin', 'anonymous');
+ e.addEventListener("load", () => { resolve(e); });
+ e.src = imageURL;
+ return e;
+ });
+}
+
app.registerExtension({
name: "Comfy.ShowGrid",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
@@ -44,9 +54,15 @@ app.registerExtension({
const axisSelectors = rootElem.querySelector(".axis-selectors");
const imageTable = rootElem.querySelector(".image-table");
+ this.imageSize = 512;
+ this.imageWidth = this.imageSize
+ this.imageHeight = this.imageSize
+ this.naturalWidth = this.imageSize
+ this.naturalHeight = this.imageSize
+
const footerHtml = `
-
+
`
const footerElem = this._gridPanel.addHTML(footerHtml, "grid-footer", true);
@@ -99,19 +115,38 @@ app.registerExtension({
});
}
- const refreshGrid = (xAxis, yAxis) => {
+ const refreshGrid = async (xAxis, yAxis) => {
this.xAxis = xAxis;
this.yAxis = yAxis;
this.xAxisData = getAxisData(this.xAxis);
this.yAxisData = getAxisData(this.yAxis);
- selectAxis(false, this.xAxisData.id)
- selectAxis(true, this.yAxisData.id)
+ selectAxis(false, this.xAxisData.selectorID)
+ selectAxis(true, this.yAxisData.selectorID)
if (xAxis === yAxis) {
this.yAxisData = getAxisData(-1);
}
+ this.imageWidth = this.imageSize
+ this.imageHeight = this.imageSize
+ this.naturalWidth = this.imageSize
+ this.naturalHeight = this.imageSize
+
+ const firstImages = getImagesAt(0, 0);
+ if (firstImages.length > 0) {
+ const src = "/view?" + new URLSearchParams(firstImages[0].image).toString() + app.getPreviewFormatParam();
+ const imgElem = await loadImageAsync(src);
+ this.naturalWidth = imgElem.naturalWidth
+ this.naturalHeight = imgElem.naturalHeight
+
+ const ratio = Math.min(this.imageSize / this.naturalWidth, this.imageSize / this.naturalHeight);
+ const newWidth = this.naturalWidth * ratio;
+ const newHeight = this.naturalHeight * ratio;
+ this.imageWidth = newWidth;
+ this.imageHeight = newHeight;
+ }
+
imageTable.innerHTML = "";
const thead = document.createElement("thead")
@@ -162,8 +197,8 @@ app.registerExtension({
const td = document.createElement("td");
const img = document.createElement("img");
- img.style.width = `${this.imageSize}px`
- img.style.height = `${this.imageSize}px`
+ img.style.width = `${this.imageWidth}px`
+ img.style.height = `${this.imageHeight}px`
const gridImages = getImagesAt(x, y);
if (gridImages.length > 0) {
img.src = "/view?" + new URLSearchParams(gridImages[0].image).toString() + app.getPreviewFormatParam();
@@ -187,7 +222,7 @@ app.registerExtension({
group.innerHTML = `${axisName.toUpperCase()} Axis:  `;
const addAxis = (index, axis) => {
- const axisID = `${axisName}-${axis.id}`;
+ const axisID = `${axisName}-${axis.selectorID}`;
const input = document.createElement("input")
input.setAttribute("type", "radio")
@@ -211,7 +246,7 @@ app.registerExtension({
label.innerHTML = String(axis.label);
label.addEventListener("click", () => {
console.warn("SETAXIS", axis);
- selectAxis(isY, axis.id, true);
+ selectAxis(isY, axis.selectorID, true);
})
group.appendChild(input)
@@ -228,17 +263,22 @@ app.registerExtension({
axisSelectors.appendChild(group);
}
- this.imageSize = 256;
-
imageSizeInput.addEventListener("input", () => {
this.imageSize = parseInt(imageSizeInput.value);
+
+ const ratio = Math.min(this.imageSize / natWidth, this.imageSize / natHeight);
+ const newWidth = this.naturalWidth * ratio;
+ const newHeight = this.naturalHeight * ratio;
+ this.imageWidth = newWidth;
+ this.imageHeight = newHeight;
+
for (const img of imageTable.querySelectorAll("img")) {
- img.style.width = `${this.imageSize}px`
- img.style.height = `${this.imageSize}px`
+ img.style.width = `${this.imageWidth}px`
+ img.style.height = `${this.imageHeight}px`
}
})
- refreshGrid(1, 2);
+ refreshGrid(0, Math.min(1, grid.axes.length));
document.body.appendChild(this._gridPanel);
})
diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js
index c89f86b50..cedf7fdd9 100644
--- a/web/extensions/core/widgetInputs.js
+++ b/web/extensions/core/widgetInputs.js
@@ -197,6 +197,7 @@ app.registerExtension({
this.isVirtualNode = true;
this.properties ||= {}
this.properties.valuesType = "single";
+ this.properties.axisName = "";
this.properties.listValue = "";
this.properties.rangeStepBy = 64;
this.properties.rangeSteps = 2;
@@ -228,6 +229,12 @@ app.registerExtension({
}
let values;
+ let axisID = null;
+ let axisName = null;
+ if (this.properties.axisName != "") {
+ axisID = this.id;
+ axisName = this.properties.axisName
+ }
switch (this.properties.valuesType) {
case "list":
@@ -239,13 +246,13 @@ app.registerExtension({
else if (inputType === "FLOAT") {
values = values.map(v => parseFloat(v))
}
- widget.value = { __inputType__: "combinatorial", values: values }
+ widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName }
break;
case "range":
const isNumberWidget = widget.type === "number" || widget.origType === "number";
if (isNumberWidget) {
values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps);
- widget.value = { __inputType__: "combinatorial", values: values }
+ widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName }
break;
}
case "single":
@@ -259,6 +266,10 @@ app.registerExtension({
onPropertyChanged(property, value) {
if (property === "valuesType") {
+ const isSingle = value === "single"
+ if (this.axisNameWidget)
+ this.axisNameWidget.disabled = isSingle
+
const isList = value === "list"
if (this.mainWidget)
this.mainWidget.disabled = isList
@@ -366,6 +377,9 @@ app.registerExtension({
this.valuesTypeWidget = this.addWidget("combo", "Values type", this.properties.valuesType, "valuesType", { values: valuesTypeChoices });
+ this.axisNameWidget = this.addWidget("text", "Axis Name", this.properties.axisName, "axisName");
+ this.axisNameWidget.disabled = this.properties.valuesType === "single";
+
this.listWidget = this.addWidget("text", "Choices", this.properties.listValue, "listValue");
this.listWidget.disabled = this.properties.valuesType !== "list";
diff --git a/web/scripts/app.js b/web/scripts/app.js
index ccb844259..555fe70c5 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -890,23 +890,28 @@ export class ComfyApp {
ctx.globalAlpha = 1;
}
- if (self.progress && node.id === +self.runningNodeId) {
- let offset = 0
- let height = 6;
+ if (node.id === +self.runningNodeId) {
+ if (self.progress) {
+ let offset = 0
+ let height = 6;
- if (self.batchProgress) {
- offset = 4;
- height = 4;
+ if (self.batchProgress) {
+ offset = 4;
+ height = 4;
+ }
+
+ ctx.fillStyle = "green";
+ ctx.fillRect(0, offset, size[0] * (self.progress.value / self.progress.max), height);
+
+ if (self.batchProgress) {
+ ctx.fillStyle = "#3ca2c3";
+ ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), height);
+ }
}
-
- ctx.fillStyle = "green";
- ctx.fillRect(0, offset, size[0] * (self.progress.value / self.progress.max), height);
-
- if (self.batchProgress) {
+ else if (self.batchProgress) {
ctx.fillStyle = "#3ca2c3";
- ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), height);
+ ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), 6);
}
-
ctx.fillStyle = bgcolor;
}
@@ -1082,14 +1087,18 @@ export class ComfyApp {
for (const inputName of sortedKeys) {
const input = promptInput.inputs[inputName];
if (typeof input === "object" && "__inputType__" in input) {
- axes.push({
- nodeID,
- nodeClass,
- id: `${nodeID}-${inputName}`.replace(" ", "-"),
- label: `${nodeClass}: ${inputName}`,
- inputName,
- values: input.values
- })
+ if (input.axis_id == null || !axes.some(a => a.id != null && a.id === input.axis_id)) {
+ let label = input.axis_name || `${nodeClass}: ${inputName}`;
+ axes.push({
+ id: input.axis_id,
+ nodeID,
+ nodeClass,
+ selectorID: `${nodeID}-${inputName}`.replace(" ", "-"),
+ label,
+ inputName,
+ values: input.values
+ })
+ }
}
else if (isInputLink(input)) {
const inputNodeID = input[0]