mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 07:10:52 +08:00
Fix combinatorial outputs passing
This commit is contained in:
parent
a225674fc0
commit
b6c5f6ae9c
73
execution.py
73
execution.py
@ -27,6 +27,13 @@ class CombinatorialBatches:
|
||||
combinations: List
|
||||
|
||||
|
||||
def find(d, pred):
|
||||
for i, x in d.items():
|
||||
if pred(x):
|
||||
return i, x
|
||||
return None, None
|
||||
|
||||
|
||||
def get_input_data_batches(input_data_all):
|
||||
"""Given input data that can contain combinatorial input values, returns all
|
||||
the possible batches that can be made by combining the different input
|
||||
@ -34,6 +41,7 @@ def get_input_data_batches(input_data_all):
|
||||
|
||||
input_to_index = {}
|
||||
index_to_values = []
|
||||
index_to_axis = {}
|
||||
index_to_coords = []
|
||||
|
||||
# Sort by input name first so the order which batch inputs are applied can
|
||||
@ -45,10 +53,20 @@ def get_input_data_batches(input_data_all):
|
||||
for input_name in sorted_input_names:
|
||||
value = input_data_all[input_name]
|
||||
if isinstance(value, dict) and "combinatorial" in value:
|
||||
input_to_index[input_name] = i
|
||||
index_to_values.append(value["values"])
|
||||
index_to_coords.append(list(range(len(value["values"]))))
|
||||
i += 1
|
||||
if "axis_id" in value:
|
||||
found_i = next((k for k, v in index_to_axis.items() if v == value["axis_id"]), None)
|
||||
else:
|
||||
found_i = None
|
||||
|
||||
if found_i is not None:
|
||||
input_to_index[input_name] = found_i
|
||||
else:
|
||||
input_to_index[input_name] = i
|
||||
index_to_values.append(value["values"])
|
||||
index_to_coords.append(list(range(len(value["values"]))))
|
||||
if "axis_id" in value:
|
||||
index_to_axis[i] = value["axis_id"]
|
||||
i += 1
|
||||
|
||||
if len(index_to_values) == 0:
|
||||
# No combinatorial options.
|
||||
@ -99,11 +117,15 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
# Make the outputs into a list for map-over-list use
|
||||
# (they are themselves lists so flatten them afterwards)
|
||||
input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches]
|
||||
input_values = flatten(input_values)
|
||||
input_values = { "combinatorial": True, "values": flatten(input_values) }
|
||||
input_data_all[x] = input_values
|
||||
elif is_combinatorial_input(input_data):
|
||||
if required_or_optional:
|
||||
input_data_all[x] = { "combinatorial": True, "values": input_data["values"] }
|
||||
input_data_all[x] = {
|
||||
"combinatorial": True,
|
||||
"values": input_data["values"],
|
||||
"axis_id": input_data.get("axis_id")
|
||||
}
|
||||
else:
|
||||
if required_or_optional:
|
||||
input_data_all[x] = [input_data]
|
||||
@ -121,6 +143,28 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
|
||||
input_data_all_batches = get_input_data_batches(input_data_all)
|
||||
|
||||
def format_dict(d):
|
||||
s = []
|
||||
for k,v in d.items():
|
||||
st = f"{k}: "
|
||||
if isinstance(v, list):
|
||||
st += f"list[len: {len(v)}]["
|
||||
i = []
|
||||
for v2 in v:
|
||||
i.append(v2.__class__.__name__)
|
||||
st += ",".join(i) + "]"
|
||||
else:
|
||||
st += str(type(v))
|
||||
s.append(st)
|
||||
return "( " + ", ".join(s) + " )"
|
||||
|
||||
|
||||
print("---------------------------------")
|
||||
from pprint import pp
|
||||
for batch in input_data_all_batches.batches:
|
||||
print(format_dict(batch));
|
||||
print("---------------------------------")
|
||||
|
||||
return input_data_all_batches
|
||||
|
||||
def slice_lists_into_dict(d, i):
|
||||
@ -152,7 +196,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, callbac
|
||||
nodes.before_node_execution()
|
||||
results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i)))
|
||||
if callback is not None:
|
||||
callback(i, max_len_input)
|
||||
callback(i + 1, max_len_input)
|
||||
return results
|
||||
|
||||
def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
|
||||
@ -180,7 +224,7 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
|
||||
|
||||
for batch_num, batch in enumerate(input_data_all_batches.batches):
|
||||
def cb(inner_num, inner_total):
|
||||
send_batch_progress(inner_num + 1)
|
||||
send_batch_progress(inner_num)
|
||||
|
||||
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb)
|
||||
|
||||
@ -229,12 +273,11 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
|
||||
"node": unique_id,
|
||||
"output": outputs_ui_to_send,
|
||||
"prompt_id": prompt_id,
|
||||
"batch_num": batch_num,
|
||||
"total_batches": total_batches
|
||||
"batch_num": inner_totals,
|
||||
"total_batches": total_inner_batches
|
||||
}
|
||||
if input_data_all_batches.indices:
|
||||
message["indices"] = input_data_all_batches.indices[batch_num]
|
||||
message["combination"] = input_data_all_batches.combinations[batch_num]
|
||||
server.send_sync("executed", message, server.client_id)
|
||||
|
||||
return all_outputs, all_outputs_ui
|
||||
@ -276,7 +319,6 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
if input_data_all_batches.indices:
|
||||
combinations = {
|
||||
"input_to_index": input_data_all_batches.input_to_index,
|
||||
"index_to_values": input_data_all_batches.index_to_values,
|
||||
"indices": input_data_all_batches.indices
|
||||
}
|
||||
mes = {
|
||||
@ -306,6 +348,10 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
except Exception as ex:
|
||||
typ, _, tb = sys.exc_info()
|
||||
exception_type = full_type_name(typ)
|
||||
|
||||
print("!!! Exception during processing !!!")
|
||||
print(traceback.format_exc())
|
||||
|
||||
input_data_formatted = []
|
||||
if input_data_all_batches is not None:
|
||||
d = {}
|
||||
@ -321,9 +367,6 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
d[node_id] = [[format_value(x) for x in l] for l in batch_outputs]
|
||||
output_data_formatted.append(d)
|
||||
|
||||
print("!!! Exception during processing !!!")
|
||||
print(traceback.format_exc())
|
||||
|
||||
error_details = {
|
||||
"node_id": unique_id,
|
||||
"exception_message": str(ex),
|
||||
|
||||
@ -2,6 +2,16 @@ import { app } from "/scripts/app.js";
|
||||
|
||||
// Show grids from combinatorial outputs
|
||||
|
||||
async function loadImageAsync(imageURL) {
|
||||
return new Promise((resolve) => {
|
||||
const e = new Image();
|
||||
e.setAttribute('crossorigin', 'anonymous');
|
||||
e.addEventListener("load", () => { resolve(e); });
|
||||
e.src = imageURL;
|
||||
return e;
|
||||
});
|
||||
}
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.ShowGrid",
|
||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
||||
@ -44,9 +54,15 @@ app.registerExtension({
|
||||
const axisSelectors = rootElem.querySelector(".axis-selectors");
|
||||
const imageTable = rootElem.querySelector(".image-table");
|
||||
|
||||
this.imageSize = 512;
|
||||
this.imageWidth = this.imageSize
|
||||
this.imageHeight = this.imageSize
|
||||
this.naturalWidth = this.imageSize
|
||||
this.naturalHeight = this.imageSize
|
||||
|
||||
const footerHtml = `
|
||||
<label for="image-size">Image size</label>
|
||||
<input class="image-size" id="image-size" type="range" min="64" max="1024" step="1" value="512">
|
||||
<input class="image-size" id="image-size" type="range" min="64" max="1024" step="1" value="${this.imageSize}">
|
||||
</input>
|
||||
`
|
||||
const footerElem = this._gridPanel.addHTML(footerHtml, "grid-footer", true);
|
||||
@ -99,19 +115,38 @@ app.registerExtension({
|
||||
});
|
||||
}
|
||||
|
||||
const refreshGrid = (xAxis, yAxis) => {
|
||||
const refreshGrid = async (xAxis, yAxis) => {
|
||||
this.xAxis = xAxis;
|
||||
this.yAxis = yAxis;
|
||||
this.xAxisData = getAxisData(this.xAxis);
|
||||
this.yAxisData = getAxisData(this.yAxis);
|
||||
|
||||
selectAxis(false, this.xAxisData.id)
|
||||
selectAxis(true, this.yAxisData.id)
|
||||
selectAxis(false, this.xAxisData.selectorID)
|
||||
selectAxis(true, this.yAxisData.selectorID)
|
||||
|
||||
if (xAxis === yAxis) {
|
||||
this.yAxisData = getAxisData(-1);
|
||||
}
|
||||
|
||||
this.imageWidth = this.imageSize
|
||||
this.imageHeight = this.imageSize
|
||||
this.naturalWidth = this.imageSize
|
||||
this.naturalHeight = this.imageSize
|
||||
|
||||
const firstImages = getImagesAt(0, 0);
|
||||
if (firstImages.length > 0) {
|
||||
const src = "/view?" + new URLSearchParams(firstImages[0].image).toString() + app.getPreviewFormatParam();
|
||||
const imgElem = await loadImageAsync(src);
|
||||
this.naturalWidth = imgElem.naturalWidth
|
||||
this.naturalHeight = imgElem.naturalHeight
|
||||
|
||||
const ratio = Math.min(this.imageSize / this.naturalWidth, this.imageSize / this.naturalHeight);
|
||||
const newWidth = this.naturalWidth * ratio;
|
||||
const newHeight = this.naturalHeight * ratio;
|
||||
this.imageWidth = newWidth;
|
||||
this.imageHeight = newHeight;
|
||||
}
|
||||
|
||||
imageTable.innerHTML = "";
|
||||
|
||||
const thead = document.createElement("thead")
|
||||
@ -162,8 +197,8 @@ app.registerExtension({
|
||||
const td = document.createElement("td");
|
||||
|
||||
const img = document.createElement("img");
|
||||
img.style.width = `${this.imageSize}px`
|
||||
img.style.height = `${this.imageSize}px`
|
||||
img.style.width = `${this.imageWidth}px`
|
||||
img.style.height = `${this.imageHeight}px`
|
||||
const gridImages = getImagesAt(x, y);
|
||||
if (gridImages.length > 0) {
|
||||
img.src = "/view?" + new URLSearchParams(gridImages[0].image).toString() + app.getPreviewFormatParam();
|
||||
@ -187,7 +222,7 @@ app.registerExtension({
|
||||
group.innerHTML = `${axisName.toUpperCase()} Axis:  `;
|
||||
|
||||
const addAxis = (index, axis) => {
|
||||
const axisID = `${axisName}-${axis.id}`;
|
||||
const axisID = `${axisName}-${axis.selectorID}`;
|
||||
|
||||
const input = document.createElement("input")
|
||||
input.setAttribute("type", "radio")
|
||||
@ -211,7 +246,7 @@ app.registerExtension({
|
||||
label.innerHTML = String(axis.label);
|
||||
label.addEventListener("click", () => {
|
||||
console.warn("SETAXIS", axis);
|
||||
selectAxis(isY, axis.id, true);
|
||||
selectAxis(isY, axis.selectorID, true);
|
||||
})
|
||||
|
||||
group.appendChild(input)
|
||||
@ -228,17 +263,22 @@ app.registerExtension({
|
||||
axisSelectors.appendChild(group);
|
||||
}
|
||||
|
||||
this.imageSize = 256;
|
||||
|
||||
imageSizeInput.addEventListener("input", () => {
|
||||
this.imageSize = parseInt(imageSizeInput.value);
|
||||
|
||||
const ratio = Math.min(this.imageSize / natWidth, this.imageSize / natHeight);
|
||||
const newWidth = this.naturalWidth * ratio;
|
||||
const newHeight = this.naturalHeight * ratio;
|
||||
this.imageWidth = newWidth;
|
||||
this.imageHeight = newHeight;
|
||||
|
||||
for (const img of imageTable.querySelectorAll("img")) {
|
||||
img.style.width = `${this.imageSize}px`
|
||||
img.style.height = `${this.imageSize}px`
|
||||
img.style.width = `${this.imageWidth}px`
|
||||
img.style.height = `${this.imageHeight}px`
|
||||
}
|
||||
})
|
||||
|
||||
refreshGrid(1, 2);
|
||||
refreshGrid(0, Math.min(1, grid.axes.length));
|
||||
|
||||
document.body.appendChild(this._gridPanel);
|
||||
})
|
||||
|
||||
@ -197,6 +197,7 @@ app.registerExtension({
|
||||
this.isVirtualNode = true;
|
||||
this.properties ||= {}
|
||||
this.properties.valuesType = "single";
|
||||
this.properties.axisName = "";
|
||||
this.properties.listValue = "";
|
||||
this.properties.rangeStepBy = 64;
|
||||
this.properties.rangeSteps = 2;
|
||||
@ -228,6 +229,12 @@ app.registerExtension({
|
||||
}
|
||||
|
||||
let values;
|
||||
let axisID = null;
|
||||
let axisName = null;
|
||||
if (this.properties.axisName != "") {
|
||||
axisID = this.id;
|
||||
axisName = this.properties.axisName
|
||||
}
|
||||
|
||||
switch (this.properties.valuesType) {
|
||||
case "list":
|
||||
@ -239,13 +246,13 @@ app.registerExtension({
|
||||
else if (inputType === "FLOAT") {
|
||||
values = values.map(v => parseFloat(v))
|
||||
}
|
||||
widget.value = { __inputType__: "combinatorial", values: values }
|
||||
widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName }
|
||||
break;
|
||||
case "range":
|
||||
const isNumberWidget = widget.type === "number" || widget.origType === "number";
|
||||
if (isNumberWidget) {
|
||||
values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps);
|
||||
widget.value = { __inputType__: "combinatorial", values: values }
|
||||
widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName }
|
||||
break;
|
||||
}
|
||||
case "single":
|
||||
@ -259,6 +266,10 @@ app.registerExtension({
|
||||
|
||||
onPropertyChanged(property, value) {
|
||||
if (property === "valuesType") {
|
||||
const isSingle = value === "single"
|
||||
if (this.axisNameWidget)
|
||||
this.axisNameWidget.disabled = isSingle
|
||||
|
||||
const isList = value === "list"
|
||||
if (this.mainWidget)
|
||||
this.mainWidget.disabled = isList
|
||||
@ -366,6 +377,9 @@ app.registerExtension({
|
||||
|
||||
this.valuesTypeWidget = this.addWidget("combo", "Values type", this.properties.valuesType, "valuesType", { values: valuesTypeChoices });
|
||||
|
||||
this.axisNameWidget = this.addWidget("text", "Axis Name", this.properties.axisName, "axisName");
|
||||
this.axisNameWidget.disabled = this.properties.valuesType === "single";
|
||||
|
||||
this.listWidget = this.addWidget("text", "Choices", this.properties.listValue, "listValue");
|
||||
this.listWidget.disabled = this.properties.valuesType !== "list";
|
||||
|
||||
|
||||
@ -890,23 +890,28 @@ export class ComfyApp {
|
||||
ctx.globalAlpha = 1;
|
||||
}
|
||||
|
||||
if (self.progress && node.id === +self.runningNodeId) {
|
||||
let offset = 0
|
||||
let height = 6;
|
||||
if (node.id === +self.runningNodeId) {
|
||||
if (self.progress) {
|
||||
let offset = 0
|
||||
let height = 6;
|
||||
|
||||
if (self.batchProgress) {
|
||||
offset = 4;
|
||||
height = 4;
|
||||
if (self.batchProgress) {
|
||||
offset = 4;
|
||||
height = 4;
|
||||
}
|
||||
|
||||
ctx.fillStyle = "green";
|
||||
ctx.fillRect(0, offset, size[0] * (self.progress.value / self.progress.max), height);
|
||||
|
||||
if (self.batchProgress) {
|
||||
ctx.fillStyle = "#3ca2c3";
|
||||
ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), height);
|
||||
}
|
||||
}
|
||||
|
||||
ctx.fillStyle = "green";
|
||||
ctx.fillRect(0, offset, size[0] * (self.progress.value / self.progress.max), height);
|
||||
|
||||
if (self.batchProgress) {
|
||||
else if (self.batchProgress) {
|
||||
ctx.fillStyle = "#3ca2c3";
|
||||
ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), height);
|
||||
ctx.fillRect(0, 0, size[0] * (self.batchProgress.value / self.batchProgress.max), 6);
|
||||
}
|
||||
|
||||
ctx.fillStyle = bgcolor;
|
||||
}
|
||||
|
||||
@ -1082,14 +1087,18 @@ export class ComfyApp {
|
||||
for (const inputName of sortedKeys) {
|
||||
const input = promptInput.inputs[inputName];
|
||||
if (typeof input === "object" && "__inputType__" in input) {
|
||||
axes.push({
|
||||
nodeID,
|
||||
nodeClass,
|
||||
id: `${nodeID}-${inputName}`.replace(" ", "-"),
|
||||
label: `${nodeClass}: ${inputName}`,
|
||||
inputName,
|
||||
values: input.values
|
||||
})
|
||||
if (input.axis_id == null || !axes.some(a => a.id != null && a.id === input.axis_id)) {
|
||||
let label = input.axis_name || `${nodeClass}: ${inputName}`;
|
||||
axes.push({
|
||||
id: input.axis_id,
|
||||
nodeID,
|
||||
nodeClass,
|
||||
selectorID: `${nodeID}-${inputName}`.replace(" ", "-"),
|
||||
label,
|
||||
inputName,
|
||||
values: input.values
|
||||
})
|
||||
}
|
||||
}
|
||||
else if (isInputLink(input)) {
|
||||
const inputNodeID = input[0]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user