Fix combinatorial outputs passing

This commit is contained in:
space-nuko 2023-06-09 16:32:14 -05:00
parent a225674fc0
commit b6c5f6ae9c
4 changed files with 157 additions and 51 deletions

View File

@ -27,6 +27,13 @@ class CombinatorialBatches:
combinations: List combinations: List
def find(d, pred):
for i, x in d.items():
if pred(x):
return i, x
return None, None
def get_input_data_batches(input_data_all): def get_input_data_batches(input_data_all):
"""Given input data that can contain combinatorial input values, returns all """Given input data that can contain combinatorial input values, returns all
the possible batches that can be made by combining the different input the possible batches that can be made by combining the different input
@ -34,6 +41,7 @@ def get_input_data_batches(input_data_all):
input_to_index = {} input_to_index = {}
index_to_values = [] index_to_values = []
index_to_axis = {}
index_to_coords = [] index_to_coords = []
# Sort by input name first so the order which batch inputs are applied can # Sort by input name first so the order which batch inputs are applied can
@ -45,10 +53,20 @@ def get_input_data_batches(input_data_all):
for input_name in sorted_input_names: for input_name in sorted_input_names:
value = input_data_all[input_name] value = input_data_all[input_name]
if isinstance(value, dict) and "combinatorial" in value: if isinstance(value, dict) and "combinatorial" in value:
input_to_index[input_name] = i if "axis_id" in value:
index_to_values.append(value["values"]) found_i = next((k for k, v in index_to_axis.items() if v == value["axis_id"]), None)
index_to_coords.append(list(range(len(value["values"])))) else:
i += 1 found_i = None
if found_i is not None:
input_to_index[input_name] = found_i
else:
input_to_index[input_name] = i
index_to_values.append(value["values"])
index_to_coords.append(list(range(len(value["values"]))))
if "axis_id" in value:
index_to_axis[i] = value["axis_id"]
i += 1
if len(index_to_values) == 0: if len(index_to_values) == 0:
# No combinatorial options. # No combinatorial options.
@ -99,11 +117,15 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
# Make the outputs into a list for map-over-list use # Make the outputs into a list for map-over-list use
# (they are themselves lists so flatten them afterwards) # (they are themselves lists so flatten them afterwards)
input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches] input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches]
input_values = flatten(input_values) input_values = { "combinatorial": True, "values": flatten(input_values) }
input_data_all[x] = input_values input_data_all[x] = input_values
elif is_combinatorial_input(input_data): elif is_combinatorial_input(input_data):
if required_or_optional: if required_or_optional:
input_data_all[x] = { "combinatorial": True, "values": input_data["values"] } input_data_all[x] = {
"combinatorial": True,
"values": input_data["values"],
"axis_id": input_data.get("axis_id")
}
else: else:
if required_or_optional: if required_or_optional:
input_data_all[x] = [input_data] input_data_all[x] = [input_data]
@ -121,6 +143,28 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all_batches = get_input_data_batches(input_data_all) input_data_all_batches = get_input_data_batches(input_data_all)
def format_dict(d):
s = []
for k,v in d.items():
st = f"{k}: "
if isinstance(v, list):
st += f"list[len: {len(v)}]["
i = []
for v2 in v:
i.append(v2.__class__.__name__)
st += ",".join(i) + "]"
else:
st += str(type(v))
s.append(st)
return "( " + ", ".join(s) + " )"
print("---------------------------------")
from pprint import pp
for batch in input_data_all_batches.batches:
print(format_dict(batch));
print("---------------------------------")
return input_data_all_batches return input_data_all_batches
def slice_lists_into_dict(d, i): def slice_lists_into_dict(d, i):
@ -152,7 +196,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, callbac
nodes.before_node_execution() nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i))) results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i)))
if callback is not None: if callback is not None:
callback(i, max_len_input) callback(i + 1, max_len_input)
return results return results
def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
@ -180,7 +224,7 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
for batch_num, batch in enumerate(input_data_all_batches.batches): for batch_num, batch in enumerate(input_data_all_batches.batches):
def cb(inner_num, inner_total): def cb(inner_num, inner_total):
send_batch_progress(inner_num + 1) send_batch_progress(inner_num)
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb) return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb)
@ -229,12 +273,11 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
"node": unique_id, "node": unique_id,
"output": outputs_ui_to_send, "output": outputs_ui_to_send,
"prompt_id": prompt_id, "prompt_id": prompt_id,
"batch_num": batch_num, "batch_num": inner_totals,
"total_batches": total_batches "total_batches": total_inner_batches
} }
if input_data_all_batches.indices: if input_data_all_batches.indices:
message["indices"] = input_data_all_batches.indices[batch_num] message["indices"] = input_data_all_batches.indices[batch_num]
message["combination"] = input_data_all_batches.combinations[batch_num]
server.send_sync("executed", message, server.client_id) server.send_sync("executed", message, server.client_id)
return all_outputs, all_outputs_ui return all_outputs, all_outputs_ui
@ -276,7 +319,6 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
if input_data_all_batches.indices: if input_data_all_batches.indices:
combinations = { combinations = {
"input_to_index": input_data_all_batches.input_to_index, "input_to_index": input_data_all_batches.input_to_index,
"index_to_values": input_data_all_batches.index_to_values,
"indices": input_data_all_batches.indices "indices": input_data_all_batches.indices
} }
mes = { mes = {
@ -306,6 +348,10 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
except Exception as ex: except Exception as ex:
typ, _, tb = sys.exc_info() typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ) exception_type = full_type_name(typ)
print("!!! Exception during processing !!!")
print(traceback.format_exc())
input_data_formatted = [] input_data_formatted = []
if input_data_all_batches is not None: if input_data_all_batches is not None:
d = {} d = {}
@ -321,9 +367,6 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
d[node_id] = [[format_value(x) for x in l] for l in batch_outputs] d[node_id] = [[format_value(x) for x in l] for l in batch_outputs]
output_data_formatted.append(d) output_data_formatted.append(d)
print("!!! Exception during processing !!!")
print(traceback.format_exc())
error_details = { error_details = {
"node_id": unique_id, "node_id": unique_id,
"exception_message": str(ex), "exception_message": str(ex),

View File

@ -2,6 +2,16 @@ import { app } from "/scripts/app.js";
// Show grids from combinatorial outputs // Show grids from combinatorial outputs
async function loadImageAsync(imageURL) {
return new Promise((resolve) => {
const e = new Image();
e.setAttribute('crossorigin', 'anonymous');
e.addEventListener("load", () => { resolve(e); });
e.src = imageURL;
return e;
});
}
app.registerExtension({ app.registerExtension({
name: "Comfy.ShowGrid", name: "Comfy.ShowGrid",
async beforeRegisterNodeDef(nodeType, nodeData, app) { async beforeRegisterNodeDef(nodeType, nodeData, app) {
@ -44,9 +54,15 @@ app.registerExtension({
const axisSelectors = rootElem.querySelector(".axis-selectors"); const axisSelectors = rootElem.querySelector(".axis-selectors");
const imageTable = rootElem.querySelector(".image-table"); const imageTable = rootElem.querySelector(".image-table");
this.imageSize = 512;
this.imageWidth = this.imageSize
this.imageHeight = this.imageSize
this.naturalWidth = this.imageSize
this.naturalHeight = this.imageSize
const footerHtml = ` const footerHtml = `
<label for="image-size">Image size</label> <label for="image-size">Image size</label>
<input class="image-size" id="image-size" type="range" min="64" max="1024" step="1" value="512"> <input class="image-size" id="image-size" type="range" min="64" max="1024" step="1" value="${this.imageSize}">
</input> </input>
` `
const footerElem = this._gridPanel.addHTML(footerHtml, "grid-footer", true); const footerElem = this._gridPanel.addHTML(footerHtml, "grid-footer", true);
@ -99,19 +115,38 @@ app.registerExtension({
}); });
} }
const refreshGrid = (xAxis, yAxis) => { const refreshGrid = async (xAxis, yAxis) => {
this.xAxis = xAxis; this.xAxis = xAxis;
this.yAxis = yAxis; this.yAxis = yAxis;
this.xAxisData = getAxisData(this.xAxis); this.xAxisData = getAxisData(this.xAxis);
this.yAxisData = getAxisData(this.yAxis); this.yAxisData = getAxisData(this.yAxis);
selectAxis(false, this.xAxisData.id) selectAxis(false, this.xAxisData.selectorID)
selectAxis(true, this.yAxisData.id) selectAxis(true, this.yAxisData.selectorID)
if (xAxis === yAxis) { if (xAxis === yAxis) {
this.yAxisData = getAxisData(-1); this.yAxisData = getAxisData(-1);
} }
this.imageWidth = this.imageSize
this.imageHeight = this.imageSize
this.naturalWidth = this.imageSize
this.naturalHeight = this.imageSize
const firstImages = getImagesAt(0, 0);
if (firstImages.length > 0) {
const src = "/view?" + new URLSearchParams(firstImages[0].image).toString() + app.getPreviewFormatParam();
const imgElem = await loadImageAsync(src);
this.naturalWidth = imgElem.naturalWidth
this.naturalHeight = imgElem.naturalHeight
const ratio = Math.min(this.imageSize / this.naturalWidth, this.imageSize / this.naturalHeight);
const newWidth = this.naturalWidth * ratio;
const newHeight = this.naturalHeight * ratio;
this.imageWidth = newWidth;
this.imageHeight = newHeight;
}
imageTable.innerHTML = ""; imageTable.innerHTML = "";
const thead = document.createElement("thead") const thead = document.createElement("thead")
@ -162,8 +197,8 @@ app.registerExtension({
const td = document.createElement("td"); const td = document.createElement("td");
const img = document.createElement("img"); const img = document.createElement("img");
img.style.width = `${this.imageSize}px` img.style.width = `${this.imageWidth}px`
img.style.height = `${this.imageSize}px` img.style.height = `${this.imageHeight}px`
const gridImages = getImagesAt(x, y); const gridImages = getImagesAt(x, y);
if (gridImages.length > 0) { if (gridImages.length > 0) {
img.src = "/view?" + new URLSearchParams(gridImages[0].image).toString() + app.getPreviewFormatParam(); img.src = "/view?" + new URLSearchParams(gridImages[0].image).toString() + app.getPreviewFormatParam();
@ -187,7 +222,7 @@ app.registerExtension({
group.innerHTML = `${axisName.toUpperCase()} Axis:&nbsp `; group.innerHTML = `${axisName.toUpperCase()} Axis:&nbsp `;
const addAxis = (index, axis) => { const addAxis = (index, axis) => {
const axisID = `${axisName}-${axis.id}`; const axisID = `${axisName}-${axis.selectorID}`;
const input = document.createElement("input") const input = document.createElement("input")
input.setAttribute("type", "radio") input.setAttribute("type", "radio")
@ -211,7 +246,7 @@ app.registerExtension({
label.innerHTML = String(axis.label); label.innerHTML = String(axis.label);
label.addEventListener("click", () => { label.addEventListener("click", () => {
console.warn("SETAXIS", axis); console.warn("SETAXIS", axis);
selectAxis(isY, axis.id, true); selectAxis(isY, axis.selectorID, true);
}) })
group.appendChild(input) group.appendChild(input)
@ -228,17 +263,22 @@ app.registerExtension({
axisSelectors.appendChild(group); axisSelectors.appendChild(group);
} }
this.imageSize = 256;
imageSizeInput.addEventListener("input", () => { imageSizeInput.addEventListener("input", () => {
this.imageSize = parseInt(imageSizeInput.value); this.imageSize = parseInt(imageSizeInput.value);
const ratio = Math.min(this.imageSize / natWidth, this.imageSize / natHeight);
const newWidth = this.naturalWidth * ratio;
const newHeight = this.naturalHeight * ratio;
this.imageWidth = newWidth;
this.imageHeight = newHeight;
for (const img of imageTable.querySelectorAll("img")) { for (const img of imageTable.querySelectorAll("img")) {
img.style.width = `${this.imageSize}px` img.style.width = `${this.imageWidth}px`
img.style.height = `${this.imageSize}px` img.style.height = `${this.imageHeight}px`
} }
}) })
refreshGrid(1, 2); refreshGrid(0, Math.min(1, grid.axes.length));
document.body.appendChild(this._gridPanel); document.body.appendChild(this._gridPanel);
}) })

View File

@ -197,6 +197,7 @@ app.registerExtension({
this.isVirtualNode = true; this.isVirtualNode = true;
this.properties ||= {} this.properties ||= {}
this.properties.valuesType = "single"; this.properties.valuesType = "single";
this.properties.axisName = "";
this.properties.listValue = ""; this.properties.listValue = "";
this.properties.rangeStepBy = 64; this.properties.rangeStepBy = 64;
this.properties.rangeSteps = 2; this.properties.rangeSteps = 2;
@ -228,6 +229,12 @@ app.registerExtension({
} }
let values; let values;
let axisID = null;
let axisName = null;
if (this.properties.axisName != "") {
axisID = this.id;
axisName = this.properties.axisName
}
switch (this.properties.valuesType) { switch (this.properties.valuesType) {
case "list": case "list":
@ -239,13 +246,13 @@ app.registerExtension({
else if (inputType === "FLOAT") { else if (inputType === "FLOAT") {
values = values.map(v => parseFloat(v)) values = values.map(v => parseFloat(v))
} }
widget.value = { __inputType__: "combinatorial", values: values } widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName }
break; break;
case "range": case "range":
const isNumberWidget = widget.type === "number" || widget.origType === "number"; const isNumberWidget = widget.type === "number" || widget.origType === "number";
if (isNumberWidget) { if (isNumberWidget) {
values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps); values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps);
widget.value = { __inputType__: "combinatorial", values: values } widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName }
break; break;
} }
case "single": case "single":
@ -259,6 +266,10 @@ app.registerExtension({
onPropertyChanged(property, value) { onPropertyChanged(property, value) {
if (property === "valuesType") { if (property === "valuesType") {
const isSingle = value === "single"
if (this.axisNameWidget)
this.axisNameWidget.disabled = isSingle
const isList = value === "list" const isList = value === "list"
if (this.mainWidget) if (this.mainWidget)
this.mainWidget.disabled = isList this.mainWidget.disabled = isList
@ -366,6 +377,9 @@ app.registerExtension({
this.valuesTypeWidget = this.addWidget("combo", "Values type", this.properties.valuesType, "valuesType", { values: valuesTypeChoices }); this.valuesTypeWidget = this.addWidget("combo", "Values type", this.properties.valuesType, "valuesType", { values: valuesTypeChoices });
this.axisNameWidget = this.addWidget("text", "Axis Name", this.properties.axisName, "axisName");
this.axisNameWidget.disabled = this.properties.valuesType === "single";
this.listWidget = this.addWidget("text", "Choices", this.properties.listValue, "listValue"); this.listWidget = this.addWidget("text", "Choices", this.properties.listValue, "listValue");
this.listWidget.disabled = this.properties.valuesType !== "list"; this.listWidget.disabled = this.properties.valuesType !== "list";

View File

@ -890,23 +890,28 @@ export class ComfyApp {
ctx.globalAlpha = 1; ctx.globalAlpha = 1;
} }
if (self.progress && node.id === +self.runningNodeId) { if (node.id === +self.runningNodeId) {
let offset = 0 if (self.progress) {
let height = 6; let offset = 0
let height = 6;
if (self.batchProgress) { if (self.batchProgress) {
offset = 4; offset = 4;
height = 4; height = 4;
}
ctx.fillStyle = "green";
ctx.fillRect(0, offset, size[0] * (self.progress.value / self.progress.max), height);
if (self.batchProgress) {
ctx.fillStyle = "#3ca2c3";
ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), height);
}
} }
else if (self.batchProgress) {
ctx.fillStyle = "green";
ctx.fillRect(0, offset, size[0] * (self.progress.value / self.progress.max), height);
if (self.batchProgress) {
ctx.fillStyle = "#3ca2c3"; ctx.fillStyle = "#3ca2c3";
ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), height); ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), 6);
} }
ctx.fillStyle = bgcolor; ctx.fillStyle = bgcolor;
} }
@ -1082,14 +1087,18 @@ export class ComfyApp {
for (const inputName of sortedKeys) { for (const inputName of sortedKeys) {
const input = promptInput.inputs[inputName]; const input = promptInput.inputs[inputName];
if (typeof input === "object" && "__inputType__" in input) { if (typeof input === "object" && "__inputType__" in input) {
axes.push({ if (input.axis_id == null || !axes.some(a => a.id != null && a.id === input.axis_id)) {
nodeID, let label = input.axis_name || `${nodeClass}: ${inputName}`;
nodeClass, axes.push({
id: `${nodeID}-${inputName}`.replace(" ", "-"), id: input.axis_id,
label: `${nodeClass}: ${inputName}`, nodeID,
inputName, nodeClass,
values: input.values selectorID: `${nodeID}-${inputName}`.replace(" ", "-"),
}) label,
inputName,
values: input.values
})
}
} }
else if (isInputLink(input)) { else if (isInputLink(input)) {
const inputNodeID = input[0] const inputNodeID = input[0]