Works for treating CLIP/model in LoraLoader as one axis but it's messy

This commit is contained in:
space-nuko 2023-06-09 20:29:13 -05:00
parent 5f9cc806f5
commit a8f3d7a872
4 changed files with 164 additions and 43 deletions

View File

@ -8,6 +8,7 @@ import traceback
import gc
import time
import itertools
import uuid
from typing import List, Dict
import dataclasses
from dataclasses import dataclass
@ -40,52 +41,104 @@ def get_input_data_batches(input_data_all):
values together."""
input_to_index = {}
input_to_values = {}
index_to_values = []
index_to_axis = {}
input_to_axis = {}
index_to_coords = []
# Axis ID to inherit
inherit_id = True
axis_id = None
# Sort by input name first so the order which batch inputs are applied can
# be easily calculated (node execution order first, then alphabetical input
# name second)
sorted_input_names = sorted(input_data_all.keys())
i = 0
for input_name in sorted_input_names:
value = input_data_all[input_name]
if isinstance(value, dict) and "combinatorial" in value:
if "axis_id" in value:
found_i = next((k for k, v in index_to_axis.items() if v == value["axis_id"]), None)
else:
found_i = None
input_to_axis[input_name] = {
"axis_id": value["axis_id"],
"join_axis": value.get("join_axis", False)
}
if found_i is not None:
input_to_index[input_name] = found_i
i = 0
def add_index(input_name):
nonlocal i, input_data_all, input_to_index, index_to_coords
value = input_data_all[input_name]
input_to_index[input_name] = i
index_to_values.append(value["values"])
index_to_coords.append(list(range(len(value["values"]))))
ret = i
i += 1
return ret
for input_name in sorted_input_names:
value = input_data_all[input_name]
if isinstance(value, dict) and "combinatorial" in value:
if "axis_id" in value:
if axis_id is None:
axis_id = value["axis_id"]
elif axis_id != value["axis_id"]:
inherit_id = False
found_name = next((k for k, v in input_to_axis.items() if v["axis_id"] == value["axis_id"]), None)
else:
input_to_index[input_name] = i
index_to_values.append(value["values"])
index_to_coords.append(list(range(len(value["values"]))))
if "axis_id" in value:
index_to_axis[i] = value["axis_id"]
i += 1
inherit_id = False
found_name = None
if found_name is not None:
join = input_to_axis[found_name]["join_axis"]
found_i = input_to_index.get(found_name)
if found_i is None:
found_i = add_index(found_name)
input_to_index[input_name] = found_i
if not join:
input_to_values[input_name] = value["values"]
else:
add_index(input_name)
if len(index_to_values) == 0:
# No combinatorial options.
return CombinatorialBatches([input_data_all], input_to_index, index_to_values, None, None)
return CombinatorialBatches([{ "inputs": input_data_all }], input_to_index, index_to_values, None, None)
batches = []
if not inherit_id or axis_id is None:
axis_id = str(uuid.uuid4())
from pprint import pp
pp(input_to_index)
pp(input_to_values)
pp(index_to_values)
indices = list(itertools.product(*index_to_coords))
combinations = list(itertools.product(*index_to_values))
for combination in combinations:
pp(indices)
for i, indices_set in enumerate(indices):
combination = combinations[i]
batch = {}
for input_name, value in input_data_all.items():
if isinstance(value, dict) and "combinatorial" in value:
combination_index = input_to_index[input_name]
batch[input_name] = [combination[combination_index]]
index = indices_set[combination_index]
if input_name in input_to_values:
value = input_to_values[input_name][index]
else:
value = combination[combination_index]
batch[input_name] = [value]
else:
# already made into a list by get_input_data
batch[input_name] = value
batches.append(batch)
batches.append({
"inputs": batch,
"axis_id": axis_id
})
return CombinatorialBatches(batches, input_to_index, index_to_values, indices, combinations)
@ -103,9 +156,11 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
if input_unique_id not in outputs:
return None
output_data = outputs[input_unique_id]
# This is a list of outputs for each batch of combinatorial inputs.
# Without any combinatorial inputs, it's a list of length 1.
outputs_for_all_batches = outputs[input_unique_id]
outputs_for_all_batches = output_data["batches"]
def flatten(list_of_lists):
return list(itertools.chain.from_iterable(list_of_lists))
@ -114,11 +169,38 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
# Single batch, no combinatorial stuff
input_data_all[x] = outputs_for_all_batches[0][output_index]
else:
from pprint import pp
print("GETINPUTDATA")
print(x)
print(input_unique_id)
# Make the outputs into a list for map-over-list use
# (they are themselves lists so flatten them afterwards)
input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches]
input_values = { "combinatorial": True, "values": flatten(input_values) }
input_values = {
"combinatorial": True,
"values": flatten(input_values),
# always treat multiple outputs from a node as belonging to
# the same grid "axis". situation this is supposed to prevent:
#
# LoraLoader outputs both a modified CLIP and MODEL. to
# ensure the outputs are enumerated combinatorially with
# others, they should be marked combinatorial.
#
# however, this does *not* mean the executor should
# enumerate every combination of CLIP and MODEL that can
# possibly be output *from the same node*. as in, the CLIP
# from one set of LoRA weights being combined with the MODEL
# from a different set of weights, as you'd never encounter
# that combination with regular use of the LoraLoader node.
#
# thus if a combinatorial set of outputs is detected, group
# them under the same axis so each of the outputs are
# updated in pairs/triplets/etc. instead of combinatorially
"axis_id": output_data["axis_id"]
}
input_data_all[x] = input_values
print("--------------------")
elif is_combinatorial_input(input_data):
if required_or_optional:
input_data_all[x] = {
@ -151,18 +233,27 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
st += f"list[len: {len(v)}]["
i = []
for v2 in v:
i.append(v2.__class__.__name__)
if isinstance(v2, (int, float, bool)):
i.append(str(v2))
else:
i.append(v2.__class__.__name__)
st += ",".join(i) + "]"
else:
st += str(type(v))
if isinstance(v, (int, float, bool)):
st += str(v)
else:
st += str(type(v))
s.append(st)
return "( " + ", ".join(s) + " )"
print("---------------------------------")
from pprint import pp
for batch in input_data_all_batches.batches:
print(format_dict(batch));
print(format_dict(batch["inputs"]))
# pp(input_data_all)
# pp(input_data_all_batches.batches)
print(input_data_all_batches.input_to_index)
# print(input_data_all_batches.index_to_values)
print("---------------------------------")
return input_data_all_batches
@ -202,11 +293,12 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, callbac
def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
all_outputs = []
all_outputs_ui = []
axis_id = None
total_batches = len(input_data_all_batches.batches)
total_inner_batches = 0
for batch in input_data_all_batches.batches:
total_inner_batches += max(len(x) for x in batch.values())
total_inner_batches += max(len(x) for x in batch["inputs"].values())
inner_totals = 0
@ -226,9 +318,13 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
def cb(inner_num, inner_total):
send_batch_progress(inner_num)
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb)
batch_inputs = batch["inputs"]
return_values = map_node_over_list(obj, batch_inputs, obj.FUNCTION, allow_interrupt=True, callback=cb)
inner_totals += max(len(x) for x in batch.values())
if axis_id is None and "axis_id" in batch:
axis_id = batch["axis_id"]
inner_totals += max(len(x) for x in batch_inputs.values())
uis = []
results = []
@ -280,7 +376,7 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
message["indices"] = input_data_all_batches.indices[batch_num]
server.send_sync("executed", message, server.client_id)
return all_outputs, all_outputs_ui
return all_outputs, all_outputs_ui, axis_id
def format_value(x):
if x is None:
@ -330,8 +426,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
obj = class_def()
output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id)
outputs[unique_id] = output_data_from_batches
output_data_from_batches, output_ui_from_batches, output_axis_id = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id)
outputs[unique_id] = {
"batches": output_data_from_batches,
"axis_id": output_axis_id
}
if any(output_ui_from_batches):
outputs_ui[unique_id] = output_ui_from_batches
elif unique_id in outputs_ui:
@ -356,7 +455,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
if input_data_all_batches is not None:
d = {}
for batch in input_data_all_batches.batches:
for name, inputs in batch.items():
for name, inputs in batch["inputs"].items():
d[name] = [format_value(x) for x in inputs]
input_data_formatted.append(d)
@ -416,7 +515,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
try:
#is_changed = class_def.IS_CHANGED(**input_data_all)
for batch in input_data_all_batches.batches:
if map_node_over_list(class_def, batch, "IS_CHANGED"):
if map_node_over_list(class_def, batch["inputs"], "IS_CHANGED"):
is_changed = True
break
prompt[unique_id]['is_changed'] = is_changed
@ -757,7 +856,7 @@ def validate_inputs(prompt, item, validated):
input_data_all_batches = get_input_data(inputs, obj_class, unique_id)
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
for batch in input_data_all_batches.batches:
ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS")
ret = map_node_over_list(obj_class, batch["inputs"], "VALIDATE_INPUTS")
for r in ret:
if r != True:
details = f"{x}"

View File

@ -23,7 +23,7 @@ app.registerExtension({
nodeType.prototype.onNodeCreated = function () {
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;
this.addWidget("button", "Show Grid", "Show Grid", () => {
this.showGridWidget = this.addWidget("button", "Show Grid", "Show Grid", () => {
const grid = app.nodeGrids[this.id];
if (grid == null) {
console.warn("No grid to show!");
@ -282,6 +282,14 @@ app.registerExtension({
document.body.appendChild(this._gridPanel);
})
this.showGridWidget.disabled = true;
}
const onExecuted = nodeType.prototype.onExecuted;
nodeType.prototype.onExecuted = function (output) {
const r = onExecuted ? onExecuted.apply(this, arguments) : undefined;
this.showGridWidget.disabled = app.nodeGrids[this.id] == null;
}
}
})

View File

@ -246,13 +246,25 @@ app.registerExtension({
else if (inputType === "FLOAT") {
values = values.map(v => parseFloat(v))
}
widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName }
widget.value = {
__inputType__: "combinatorial",
values: values,
axis_id: axisID,
axis_name: axisName,
join_axis: Boolean(axisName)
}
break;
case "range":
const isNumberWidget = widget.type === "number" || widget.origType === "number";
if (isNumberWidget) {
values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps);
widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName }
widget.value = {
__inputType__: "combinatorial",
values: values,
axis_id: axisID,
axis_name: axisName,
join_axis: Boolean(axisName)
}
break;
}
case "single":

View File

@ -999,14 +999,16 @@ export class ComfyApp {
});
api.addEventListener("executed", ({ detail }) => {
this.nodeOutputs[detail.node] = detail.output;
if (detail.output != null) {
this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt)
}
const node = this.graph.getNodeById(detail.node);
if (node) {
if (node.onExecuted)
node.onExecuted(detail.output);
if (detail.batch_num === detail.total_batches) {
this.nodeOutputs[detail.node] = detail.output;
if (detail.output != null) {
this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt)
}
const node = this.graph.getNodeById(detail.node);
if (node) {
if (node.onExecuted)
node.onExecuted(detail.output);
}
}
if (this.batchProgress != null) {
this.batchProgress.value = detail.batch_num