Fix combinatorial outputs passing

This commit is contained in:
space-nuko 2023-06-09 16:32:14 -05:00
parent a225674fc0
commit b6c5f6ae9c
4 changed files with 157 additions and 51 deletions

View File

@ -27,6 +27,13 @@ class CombinatorialBatches:
combinations: List
def find(d, pred):
for i, x in d.items():
if pred(x):
return i, x
return None, None
def get_input_data_batches(input_data_all):
"""Given input data that can contain combinatorial input values, returns all
the possible batches that can be made by combining the different input
@ -34,6 +41,7 @@ def get_input_data_batches(input_data_all):
input_to_index = {}
index_to_values = []
index_to_axis = {}
index_to_coords = []
# Sort by input name first so the order which batch inputs are applied can
@ -45,10 +53,20 @@ def get_input_data_batches(input_data_all):
for input_name in sorted_input_names:
value = input_data_all[input_name]
if isinstance(value, dict) and "combinatorial" in value:
input_to_index[input_name] = i
index_to_values.append(value["values"])
index_to_coords.append(list(range(len(value["values"]))))
i += 1
if "axis_id" in value:
found_i = next((k for k, v in index_to_axis.items() if v == value["axis_id"]), None)
else:
found_i = None
if found_i is not None:
input_to_index[input_name] = found_i
else:
input_to_index[input_name] = i
index_to_values.append(value["values"])
index_to_coords.append(list(range(len(value["values"]))))
if "axis_id" in value:
index_to_axis[i] = value["axis_id"]
i += 1
if len(index_to_values) == 0:
# No combinatorial options.
@ -99,11 +117,15 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
# Make the outputs into a list for map-over-list use
# (they are themselves lists so flatten them afterwards)
input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches]
input_values = flatten(input_values)
input_values = { "combinatorial": True, "values": flatten(input_values) }
input_data_all[x] = input_values
elif is_combinatorial_input(input_data):
if required_or_optional:
input_data_all[x] = { "combinatorial": True, "values": input_data["values"] }
input_data_all[x] = {
"combinatorial": True,
"values": input_data["values"],
"axis_id": input_data.get("axis_id")
}
else:
if required_or_optional:
input_data_all[x] = [input_data]
@ -121,6 +143,28 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all_batches = get_input_data_batches(input_data_all)
def format_dict(d):
s = []
for k,v in d.items():
st = f"{k}: "
if isinstance(v, list):
st += f"list[len: {len(v)}]["
i = []
for v2 in v:
i.append(v2.__class__.__name__)
st += ",".join(i) + "]"
else:
st += str(type(v))
s.append(st)
return "( " + ", ".join(s) + " )"
print("---------------------------------")
from pprint import pp
for batch in input_data_all_batches.batches:
print(format_dict(batch));
print("---------------------------------")
return input_data_all_batches
def slice_lists_into_dict(d, i):
@ -152,7 +196,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, callbac
nodes.before_node_execution()
results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i)))
if callback is not None:
callback(i, max_len_input)
callback(i + 1, max_len_input)
return results
def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
@ -180,7 +224,7 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
for batch_num, batch in enumerate(input_data_all_batches.batches):
def cb(inner_num, inner_total):
send_batch_progress(inner_num + 1)
send_batch_progress(inner_num)
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb)
@ -229,12 +273,11 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
"node": unique_id,
"output": outputs_ui_to_send,
"prompt_id": prompt_id,
"batch_num": batch_num,
"total_batches": total_batches
"batch_num": inner_totals,
"total_batches": total_inner_batches
}
if input_data_all_batches.indices:
message["indices"] = input_data_all_batches.indices[batch_num]
message["combination"] = input_data_all_batches.combinations[batch_num]
server.send_sync("executed", message, server.client_id)
return all_outputs, all_outputs_ui
@ -276,7 +319,6 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
if input_data_all_batches.indices:
combinations = {
"input_to_index": input_data_all_batches.input_to_index,
"index_to_values": input_data_all_batches.index_to_values,
"indices": input_data_all_batches.indices
}
mes = {
@ -306,6 +348,10 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
except Exception as ex:
typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ)
print("!!! Exception during processing !!!")
print(traceback.format_exc())
input_data_formatted = []
if input_data_all_batches is not None:
d = {}
@ -321,9 +367,6 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
d[node_id] = [[format_value(x) for x in l] for l in batch_outputs]
output_data_formatted.append(d)
print("!!! Exception during processing !!!")
print(traceback.format_exc())
error_details = {
"node_id": unique_id,
"exception_message": str(ex),

View File

@ -2,6 +2,16 @@ import { app } from "/scripts/app.js";
// Show grids from combinatorial outputs
async function loadImageAsync(imageURL) {
return new Promise((resolve) => {
const e = new Image();
e.setAttribute('crossorigin', 'anonymous');
e.addEventListener("load", () => { resolve(e); });
e.src = imageURL;
return e;
});
}
app.registerExtension({
name: "Comfy.ShowGrid",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
@ -44,9 +54,15 @@ app.registerExtension({
const axisSelectors = rootElem.querySelector(".axis-selectors");
const imageTable = rootElem.querySelector(".image-table");
this.imageSize = 512;
this.imageWidth = this.imageSize
this.imageHeight = this.imageSize
this.naturalWidth = this.imageSize
this.naturalHeight = this.imageSize
const footerHtml = `
<label for="image-size">Image size</label>
<input class="image-size" id="image-size" type="range" min="64" max="1024" step="1" value="512">
<input class="image-size" id="image-size" type="range" min="64" max="1024" step="1" value="${this.imageSize}">
</input>
`
const footerElem = this._gridPanel.addHTML(footerHtml, "grid-footer", true);
@ -99,19 +115,38 @@ app.registerExtension({
});
}
const refreshGrid = (xAxis, yAxis) => {
const refreshGrid = async (xAxis, yAxis) => {
this.xAxis = xAxis;
this.yAxis = yAxis;
this.xAxisData = getAxisData(this.xAxis);
this.yAxisData = getAxisData(this.yAxis);
selectAxis(false, this.xAxisData.id)
selectAxis(true, this.yAxisData.id)
selectAxis(false, this.xAxisData.selectorID)
selectAxis(true, this.yAxisData.selectorID)
if (xAxis === yAxis) {
this.yAxisData = getAxisData(-1);
}
this.imageWidth = this.imageSize
this.imageHeight = this.imageSize
this.naturalWidth = this.imageSize
this.naturalHeight = this.imageSize
const firstImages = getImagesAt(0, 0);
if (firstImages.length > 0) {
const src = "/view?" + new URLSearchParams(firstImages[0].image).toString() + app.getPreviewFormatParam();
const imgElem = await loadImageAsync(src);
this.naturalWidth = imgElem.naturalWidth
this.naturalHeight = imgElem.naturalHeight
const ratio = Math.min(this.imageSize / this.naturalWidth, this.imageSize / this.naturalHeight);
const newWidth = this.naturalWidth * ratio;
const newHeight = this.naturalHeight * ratio;
this.imageWidth = newWidth;
this.imageHeight = newHeight;
}
imageTable.innerHTML = "";
const thead = document.createElement("thead")
@ -162,8 +197,8 @@ app.registerExtension({
const td = document.createElement("td");
const img = document.createElement("img");
img.style.width = `${this.imageSize}px`
img.style.height = `${this.imageSize}px`
img.style.width = `${this.imageWidth}px`
img.style.height = `${this.imageHeight}px`
const gridImages = getImagesAt(x, y);
if (gridImages.length > 0) {
img.src = "/view?" + new URLSearchParams(gridImages[0].image).toString() + app.getPreviewFormatParam();
@ -187,7 +222,7 @@ app.registerExtension({
group.innerHTML = `${axisName.toUpperCase()} Axis:&nbsp `;
const addAxis = (index, axis) => {
const axisID = `${axisName}-${axis.id}`;
const axisID = `${axisName}-${axis.selectorID}`;
const input = document.createElement("input")
input.setAttribute("type", "radio")
@ -211,7 +246,7 @@ app.registerExtension({
label.innerHTML = String(axis.label);
label.addEventListener("click", () => {
console.warn("SETAXIS", axis);
selectAxis(isY, axis.id, true);
selectAxis(isY, axis.selectorID, true);
})
group.appendChild(input)
@ -228,17 +263,22 @@ app.registerExtension({
axisSelectors.appendChild(group);
}
this.imageSize = 256;
imageSizeInput.addEventListener("input", () => {
this.imageSize = parseInt(imageSizeInput.value);
const ratio = Math.min(this.imageSize / natWidth, this.imageSize / natHeight);
const newWidth = this.naturalWidth * ratio;
const newHeight = this.naturalHeight * ratio;
this.imageWidth = newWidth;
this.imageHeight = newHeight;
for (const img of imageTable.querySelectorAll("img")) {
img.style.width = `${this.imageSize}px`
img.style.height = `${this.imageSize}px`
img.style.width = `${this.imageWidth}px`
img.style.height = `${this.imageHeight}px`
}
})
refreshGrid(1, 2);
refreshGrid(0, Math.min(1, grid.axes.length));
document.body.appendChild(this._gridPanel);
})

View File

@ -197,6 +197,7 @@ app.registerExtension({
this.isVirtualNode = true;
this.properties ||= {}
this.properties.valuesType = "single";
this.properties.axisName = "";
this.properties.listValue = "";
this.properties.rangeStepBy = 64;
this.properties.rangeSteps = 2;
@ -228,6 +229,12 @@ app.registerExtension({
}
let values;
let axisID = null;
let axisName = null;
if (this.properties.axisName != "") {
axisID = this.id;
axisName = this.properties.axisName
}
switch (this.properties.valuesType) {
case "list":
@ -239,13 +246,13 @@ app.registerExtension({
else if (inputType === "FLOAT") {
values = values.map(v => parseFloat(v))
}
widget.value = { __inputType__: "combinatorial", values: values }
widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName }
break;
case "range":
const isNumberWidget = widget.type === "number" || widget.origType === "number";
if (isNumberWidget) {
values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps);
widget.value = { __inputType__: "combinatorial", values: values }
widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName }
break;
}
case "single":
@ -259,6 +266,10 @@ app.registerExtension({
onPropertyChanged(property, value) {
if (property === "valuesType") {
const isSingle = value === "single"
if (this.axisNameWidget)
this.axisNameWidget.disabled = isSingle
const isList = value === "list"
if (this.mainWidget)
this.mainWidget.disabled = isList
@ -366,6 +377,9 @@ app.registerExtension({
this.valuesTypeWidget = this.addWidget("combo", "Values type", this.properties.valuesType, "valuesType", { values: valuesTypeChoices });
this.axisNameWidget = this.addWidget("text", "Axis Name", this.properties.axisName, "axisName");
this.axisNameWidget.disabled = this.properties.valuesType === "single";
this.listWidget = this.addWidget("text", "Choices", this.properties.listValue, "listValue");
this.listWidget.disabled = this.properties.valuesType !== "list";

View File

@ -890,23 +890,28 @@ export class ComfyApp {
ctx.globalAlpha = 1;
}
if (self.progress && node.id === +self.runningNodeId) {
let offset = 0
let height = 6;
if (node.id === +self.runningNodeId) {
if (self.progress) {
let offset = 0
let height = 6;
if (self.batchProgress) {
offset = 4;
height = 4;
if (self.batchProgress) {
offset = 4;
height = 4;
}
ctx.fillStyle = "green";
ctx.fillRect(0, offset, size[0] * (self.progress.value / self.progress.max), height);
if (self.batchProgress) {
ctx.fillStyle = "#3ca2c3";
ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), height);
}
}
ctx.fillStyle = "green";
ctx.fillRect(0, offset, size[0] * (self.progress.value / self.progress.max), height);
if (self.batchProgress) {
else if (self.batchProgress) {
ctx.fillStyle = "#3ca2c3";
ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), height);
ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), 6);
}
ctx.fillStyle = bgcolor;
}
@ -1082,14 +1087,18 @@ export class ComfyApp {
for (const inputName of sortedKeys) {
const input = promptInput.inputs[inputName];
if (typeof input === "object" && "__inputType__" in input) {
axes.push({
nodeID,
nodeClass,
id: `${nodeID}-${inputName}`.replace(" ", "-"),
label: `${nodeClass}: ${inputName}`,
inputName,
values: input.values
})
if (input.axis_id == null || !axes.some(a => a.id != null && a.id === input.axis_id)) {
let label = input.axis_name || `${nodeClass}: ${inputName}`;
axes.push({
id: input.axis_id,
nodeID,
nodeClass,
selectorID: `${nodeID}-${inputName}`.replace(" ", "-"),
label,
inputName,
values: input.values
})
}
}
else if (isInputLink(input)) {
const inputNodeID = input[0]