This commit is contained in:
space-nuko 2023-06-09 17:09:30 -05:00
parent b6c5f6ae9c
commit df0c752520
3 changed files with 71 additions and 29 deletions

View File

@ -87,8 +87,39 @@ def get_input_data_batches(input_data_all):
batch[input_name] = value
batches.append(batch)
print("------------------=+++++++++++++++++")
for batch in batches:
print(format_dict(batch))
print(format_dict(input_to_index))
print(format_dict({ "v": index_to_values }))
print(index_to_coords)
print("------------------=+++++++++++++++++")
return CombinatorialBatches(batches, input_to_index, index_to_values, indices, combinations)
def format_dict(d):
    """Render *d* as a compact one-line summary for debug printing.

    Numbers and bools are shown literally. Lists are shown as
    ``list[len: N][...]`` where each element is its literal value
    (numbers/bools) or its class name. Any other value is shown as its
    ``type(...)`` repr. Returns a string like ``( k1: v1, k2: v2 )``.
    """
    parts = []
    for key, val in d.items():
        if isinstance(val, list):
            # Summarize each element: literal for scalars, class name otherwise.
            elems = [
                str(item) if isinstance(item, (int, float, bool)) else item.__class__.__name__
                for item in val
            ]
            rendered = f"list[len: {len(val)}][" + ",".join(elems) + "]"
        elif isinstance(val, (int, float, bool)):
            rendered = str(val)
        else:
            rendered = str(type(val))
        parts.append(f"{key}: {rendered}")
    return "( " + ", ".join(parts) + " )"
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
"""Given input data from the prompt, returns a list of input data dicts for
each combinatorial batch."""
@ -117,7 +148,15 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
# Make the outputs into a list for map-over-list use
# (they are themselves lists so flatten them afterwards)
input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches]
input_values = { "combinatorial": True, "values": flatten(input_values) }
print("COMB")
print(str(input_unique_id))
print(str(output_index))
print(format_dict({ "values": input_values }))
input_values = {
"combinatorial": True,
"values": flatten(input_values),
"axis_id": prompt[input_unique_id].get("axis_id")
}
input_data_all[x] = input_values
elif is_combinatorial_input(input_data):
if required_or_optional:
@ -143,22 +182,6 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all_batches = get_input_data_batches(input_data_all)
def format_dict(d):
    """Summarize *d* on one line for debugging.

    Lists render as ``list[len: N][cls1,cls2,...]`` using each element's
    class name; every other value renders as its ``type(...)`` repr.
    Returns a string shaped like ``( k1: v1, k2: v2 )``.
    """
    rendered = []
    for key, val in d.items():
        if isinstance(val, list):
            names = ",".join(type(item).__name__ for item in val)
            rendered.append(f"{key}: list[len: {len(val)}][{names}]")
        else:
            rendered.append(f"{key}: {type(val)}")
    return "( " + ", ".join(rendered) + " )"
print("---------------------------------")
from pprint import pp
for batch in input_data_all_batches.batches:
@ -274,7 +297,7 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
"output": outputs_ui_to_send,
"prompt_id": prompt_id,
"batch_num": inner_totals,
"total_batches": total_inner_batches
"total_batches": total_inner_batches,
}
if input_data_all_batches.indices:
message["indices"] = input_data_all_batches.indices[batch_num]
@ -411,7 +434,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
is_changed_old = old_prompt[unique_id]['is_changed']
if 'is_changed' not in prompt[unique_id]:
input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs)
input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs, prompt)
if input_data_all_batches is not None:
try:
#is_changed = class_def.IS_CHANGED(**input_data_all)
@ -754,7 +777,7 @@ def validate_inputs(prompt, item, validated):
inputs[x] = r[1]
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all_batches = get_input_data(inputs, obj_class, unique_id)
input_data_all_batches = get_input_data(inputs, obj_class, unique_id, {}, prompt)
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
for batch in input_data_all_batches.batches:
ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS")

View File

@ -23,7 +23,7 @@ app.registerExtension({
nodeType.prototype.onNodeCreated = function () {
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;
this.addWidget("button", "Show Grid", "Show Grid", () => {
this.showGridWidget = this.addWidget("button", "Show Grid", "Show Grid", () => {
const grid = app.nodeGrids[this.id];
if (grid == null) {
console.warn("No grid to show!");
@ -282,6 +282,14 @@ app.registerExtension({
document.body.appendChild(this._gridPanel);
})
this.showGridWidget.disabled = true;
}
const onExecuted = nodeType.prototype.onExecuted;
nodeType.prototype.onExecuted = function (output) {
const r = onExecuted ? onExecuted.apply(this, arguments) : undefined;
this.showGridWidget.disabled = app.nodeGrids[this.id] == null;
}
}
})

View File

@ -999,14 +999,16 @@ export class ComfyApp {
});
api.addEventListener("executed", ({ detail }) => {
this.nodeOutputs[detail.node] = detail.output;
if (detail.output != null) {
this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt)
}
const node = this.graph.getNodeById(detail.node);
if (node) {
if (node.onExecuted)
node.onExecuted(detail.output);
if (detail.batch_num === detail.total_batches) {
this.nodeOutputs[detail.node] = detail.output;
if (detail.output != null) {
this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt)
}
const node = this.graph.getNodeById(detail.node);
if (node) {
if (node.onExecuted)
node.onExecuted(detail.output);
}
}
if (this.batchProgress != null) {
this.batchProgress.value = detail.batch_num
@ -1473,6 +1475,7 @@ export class ComfyApp {
const inputs = {};
const widgets = node.widgets;
let axis_id = null;
// Store all widget values
if (widgets) {
@ -1484,6 +1487,13 @@ export class ComfyApp {
if (typeof widgetValue === "object" && widgetValue.__inputType__) {
totalCombinatorialNodes += 1;
executionFactor *= widgetValue.values.length;
if (widgetValue.axis_id != null) {
if (axis_id != null && axis_id != widgetValue.axis_id) {
throw new RuntimeError("Each node's outputs can only belong to one axis at a time");
}
axis_id = widgetValue.axis_id;
}
}
totalExecuted += executionFactor;
}
@ -1513,6 +1523,7 @@ export class ComfyApp {
output[String(node.id)] = {
inputs,
class_type: node.comfyClass,
axis_id,
};
}