Works for treating CLIP/model in LoraLoader as one axis but it's messy

This commit is contained in:
space-nuko 2023-06-09 20:29:13 -05:00
parent 5f9cc806f5
commit a8f3d7a872
4 changed files with 164 additions and 43 deletions

View File

@ -8,6 +8,7 @@ import traceback
import gc import gc
import time import time
import itertools import itertools
import uuid
from typing import List, Dict from typing import List, Dict
import dataclasses import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
@ -40,52 +41,104 @@ def get_input_data_batches(input_data_all):
values together.""" values together."""
input_to_index = {} input_to_index = {}
input_to_values = {}
index_to_values = [] index_to_values = []
index_to_axis = {} input_to_axis = {}
index_to_coords = [] index_to_coords = []
# Axis ID to inherit
inherit_id = True
axis_id = None
# Sort by input name first so the order which batch inputs are applied can # Sort by input name first so the order which batch inputs are applied can
# be easily calculated (node execution order first, then alphabetical input # be easily calculated (node execution order first, then alphabetical input
# name second) # name second)
sorted_input_names = sorted(input_data_all.keys()) sorted_input_names = sorted(input_data_all.keys())
i = 0
for input_name in sorted_input_names: for input_name in sorted_input_names:
value = input_data_all[input_name] value = input_data_all[input_name]
if isinstance(value, dict) and "combinatorial" in value: if isinstance(value, dict) and "combinatorial" in value:
if "axis_id" in value: if "axis_id" in value:
found_i = next((k for k, v in index_to_axis.items() if v == value["axis_id"]), None) input_to_axis[input_name] = {
else: "axis_id": value["axis_id"],
found_i = None "join_axis": value.get("join_axis", False)
}
if found_i is not None: i = 0
input_to_index[input_name] = found_i
def add_index(input_name):
nonlocal i, input_data_all, input_to_index, index_to_coords
value = input_data_all[input_name]
input_to_index[input_name] = i
index_to_values.append(value["values"])
index_to_coords.append(list(range(len(value["values"]))))
ret = i
i += 1
return ret
for input_name in sorted_input_names:
value = input_data_all[input_name]
if isinstance(value, dict) and "combinatorial" in value:
if "axis_id" in value:
if axis_id is None:
axis_id = value["axis_id"]
elif axis_id != value["axis_id"]:
inherit_id = False
found_name = next((k for k, v in input_to_axis.items() if v["axis_id"] == value["axis_id"]), None)
else: else:
input_to_index[input_name] = i inherit_id = False
index_to_values.append(value["values"]) found_name = None
index_to_coords.append(list(range(len(value["values"]))))
if "axis_id" in value: if found_name is not None:
index_to_axis[i] = value["axis_id"] join = input_to_axis[found_name]["join_axis"]
i += 1 found_i = input_to_index.get(found_name)
if found_i is None:
found_i = add_index(found_name)
input_to_index[input_name] = found_i
if not join:
input_to_values[input_name] = value["values"]
else:
add_index(input_name)
if len(index_to_values) == 0: if len(index_to_values) == 0:
# No combinatorial options. # No combinatorial options.
return CombinatorialBatches([input_data_all], input_to_index, index_to_values, None, None) return CombinatorialBatches([{ "inputs": input_data_all }], input_to_index, index_to_values, None, None)
batches = [] batches = []
if not inherit_id or axis_id is None:
axis_id = str(uuid.uuid4())
from pprint import pp
pp(input_to_index)
pp(input_to_values)
pp(index_to_values)
indices = list(itertools.product(*index_to_coords)) indices = list(itertools.product(*index_to_coords))
combinations = list(itertools.product(*index_to_values)) combinations = list(itertools.product(*index_to_values))
for combination in combinations:
pp(indices)
for i, indices_set in enumerate(indices):
combination = combinations[i]
batch = {} batch = {}
for input_name, value in input_data_all.items(): for input_name, value in input_data_all.items():
if isinstance(value, dict) and "combinatorial" in value: if isinstance(value, dict) and "combinatorial" in value:
combination_index = input_to_index[input_name] combination_index = input_to_index[input_name]
batch[input_name] = [combination[combination_index]] index = indices_set[combination_index]
if input_name in input_to_values:
value = input_to_values[input_name][index]
else:
value = combination[combination_index]
batch[input_name] = [value]
else: else:
# already made into a list by get_input_data # already made into a list by get_input_data
batch[input_name] = value batch[input_name] = value
batches.append(batch) batches.append({
"inputs": batch,
"axis_id": axis_id
})
return CombinatorialBatches(batches, input_to_index, index_to_values, indices, combinations) return CombinatorialBatches(batches, input_to_index, index_to_values, indices, combinations)
@ -103,9 +156,11 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
if input_unique_id not in outputs: if input_unique_id not in outputs:
return None return None
output_data = outputs[input_unique_id]
# This is a list of outputs for each batch of combinatorial inputs. # This is a list of outputs for each batch of combinatorial inputs.
# Without any combinatorial inputs, it's a list of length 1. # Without any combinatorial inputs, it's a list of length 1.
outputs_for_all_batches = outputs[input_unique_id] outputs_for_all_batches = output_data["batches"]
def flatten(list_of_lists): def flatten(list_of_lists):
return list(itertools.chain.from_iterable(list_of_lists)) return list(itertools.chain.from_iterable(list_of_lists))
@ -114,11 +169,38 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
# Single batch, no combinatorial stuff # Single batch, no combinatorial stuff
input_data_all[x] = outputs_for_all_batches[0][output_index] input_data_all[x] = outputs_for_all_batches[0][output_index]
else: else:
from pprint import pp
print("GETINPUTDATA")
print(x)
print(input_unique_id)
# Make the outputs into a list for map-over-list use # Make the outputs into a list for map-over-list use
# (they are themselves lists so flatten them afterwards) # (they are themselves lists so flatten them afterwards)
input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches] input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches]
input_values = { "combinatorial": True, "values": flatten(input_values) } input_values = {
"combinatorial": True,
"values": flatten(input_values),
# always treat multiple outputs from a node as belonging to
# the same grid "axis". situation this is supposed to prevent:
#
# LoraLoader outputs both a modified CLIP and MODEL. to
# ensure the outputs are enumerated combinatorially with
# others, they should be marked combinatorial.
#
# however, this does *not* mean the executor should
# enumerate every combination of CLIP and MODEL that can
# possibly be output *from the same node*. as in, the CLIP
# from one set of LoRA weights being combined with the MODEL
# from a different set of weights, as you'd never encounter
# that combination with regular use of the LoraLoader node.
#
# thus if a combinatorial set of outputs is detected, group
# them under the same axis so each of the outputs are
# updated in pairs/triplets/etc. instead of combinatorially
"axis_id": output_data["axis_id"]
}
input_data_all[x] = input_values input_data_all[x] = input_values
print("--------------------")
elif is_combinatorial_input(input_data): elif is_combinatorial_input(input_data):
if required_or_optional: if required_or_optional:
input_data_all[x] = { input_data_all[x] = {
@ -151,18 +233,27 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
st += f"list[len: {len(v)}][" st += f"list[len: {len(v)}]["
i = [] i = []
for v2 in v: for v2 in v:
i.append(v2.__class__.__name__) if isinstance(v2, (int, float, bool)):
i.append(str(v2))
else:
i.append(v2.__class__.__name__)
st += ",".join(i) + "]" st += ",".join(i) + "]"
else: else:
st += str(type(v)) if isinstance(v, (int, float, bool)):
st += str(v)
else:
st += str(type(v))
s.append(st) s.append(st)
return "( " + ", ".join(s) + " )" return "( " + ", ".join(s) + " )"
print("---------------------------------") print("---------------------------------")
from pprint import pp from pprint import pp
for batch in input_data_all_batches.batches: for batch in input_data_all_batches.batches:
print(format_dict(batch)); print(format_dict(batch["inputs"]))
# pp(input_data_all)
# pp(input_data_all_batches.batches)
print(input_data_all_batches.input_to_index)
# print(input_data_all_batches.index_to_values)
print("---------------------------------") print("---------------------------------")
return input_data_all_batches return input_data_all_batches
@ -202,11 +293,12 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, callbac
def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
all_outputs = [] all_outputs = []
all_outputs_ui = [] all_outputs_ui = []
axis_id = None
total_batches = len(input_data_all_batches.batches) total_batches = len(input_data_all_batches.batches)
total_inner_batches = 0 total_inner_batches = 0
for batch in input_data_all_batches.batches: for batch in input_data_all_batches.batches:
total_inner_batches += max(len(x) for x in batch.values()) total_inner_batches += max(len(x) for x in batch["inputs"].values())
inner_totals = 0 inner_totals = 0
@ -226,9 +318,13 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
def cb(inner_num, inner_total): def cb(inner_num, inner_total):
send_batch_progress(inner_num) send_batch_progress(inner_num)
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb) batch_inputs = batch["inputs"]
return_values = map_node_over_list(obj, batch_inputs, obj.FUNCTION, allow_interrupt=True, callback=cb)
inner_totals += max(len(x) for x in batch.values()) if axis_id is None and "axis_id" in batch:
axis_id = batch["axis_id"]
inner_totals += max(len(x) for x in batch_inputs.values())
uis = [] uis = []
results = [] results = []
@ -280,7 +376,7 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
message["indices"] = input_data_all_batches.indices[batch_num] message["indices"] = input_data_all_batches.indices[batch_num]
server.send_sync("executed", message, server.client_id) server.send_sync("executed", message, server.client_id)
return all_outputs, all_outputs_ui return all_outputs, all_outputs_ui, axis_id
def format_value(x): def format_value(x):
if x is None: if x is None:
@ -330,8 +426,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
obj = class_def() obj = class_def()
output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id) output_data_from_batches, output_ui_from_batches, output_axis_id = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id)
outputs[unique_id] = output_data_from_batches outputs[unique_id] = {
"batches": output_data_from_batches,
"axis_id": output_axis_id
}
if any(output_ui_from_batches): if any(output_ui_from_batches):
outputs_ui[unique_id] = output_ui_from_batches outputs_ui[unique_id] = output_ui_from_batches
elif unique_id in outputs_ui: elif unique_id in outputs_ui:
@ -356,7 +455,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
if input_data_all_batches is not None: if input_data_all_batches is not None:
d = {} d = {}
for batch in input_data_all_batches.batches: for batch in input_data_all_batches.batches:
for name, inputs in batch.items(): for name, inputs in batch["inputs"].items():
d[name] = [format_value(x) for x in inputs] d[name] = [format_value(x) for x in inputs]
input_data_formatted.append(d) input_data_formatted.append(d)
@ -416,7 +515,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
try: try:
#is_changed = class_def.IS_CHANGED(**input_data_all) #is_changed = class_def.IS_CHANGED(**input_data_all)
for batch in input_data_all_batches.batches: for batch in input_data_all_batches.batches:
if map_node_over_list(class_def, batch, "IS_CHANGED"): if map_node_over_list(class_def, batch["inputs"], "IS_CHANGED"):
is_changed = True is_changed = True
break break
prompt[unique_id]['is_changed'] = is_changed prompt[unique_id]['is_changed'] = is_changed
@ -757,7 +856,7 @@ def validate_inputs(prompt, item, validated):
input_data_all_batches = get_input_data(inputs, obj_class, unique_id) input_data_all_batches = get_input_data(inputs, obj_class, unique_id)
#ret = obj_class.VALIDATE_INPUTS(**input_data_all) #ret = obj_class.VALIDATE_INPUTS(**input_data_all)
for batch in input_data_all_batches.batches: for batch in input_data_all_batches.batches:
ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS") ret = map_node_over_list(obj_class, batch["inputs"], "VALIDATE_INPUTS")
for r in ret: for r in ret:
if r != True: if r != True:
details = f"{x}" details = f"{x}"

View File

@ -23,7 +23,7 @@ app.registerExtension({
nodeType.prototype.onNodeCreated = function () { nodeType.prototype.onNodeCreated = function () {
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;
this.addWidget("button", "Show Grid", "Show Grid", () => { this.showGridWidget = this.addWidget("button", "Show Grid", "Show Grid", () => {
const grid = app.nodeGrids[this.id]; const grid = app.nodeGrids[this.id];
if (grid == null) { if (grid == null) {
console.warn("No grid to show!"); console.warn("No grid to show!");
@ -282,6 +282,14 @@ app.registerExtension({
document.body.appendChild(this._gridPanel); document.body.appendChild(this._gridPanel);
}) })
this.showGridWidget.disabled = true;
}
const onExecuted = nodeType.prototype.onExecuted;
nodeType.prototype.onExecuted = function (output) {
const r = onExecuted ? onExecuted.apply(this, arguments) : undefined;
this.showGridWidget.disabled = app.nodeGrids[this.id] == null;
} }
} }
}) })

View File

@ -246,13 +246,25 @@ app.registerExtension({
else if (inputType === "FLOAT") { else if (inputType === "FLOAT") {
values = values.map(v => parseFloat(v)) values = values.map(v => parseFloat(v))
} }
widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName } widget.value = {
__inputType__: "combinatorial",
values: values,
axis_id: axisID,
axis_name: axisName,
join_axis: Boolean(axisName)
}
break; break;
case "range": case "range":
const isNumberWidget = widget.type === "number" || widget.origType === "number"; const isNumberWidget = widget.type === "number" || widget.origType === "number";
if (isNumberWidget) { if (isNumberWidget) {
values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps); values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps);
widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName } widget.value = {
__inputType__: "combinatorial",
values: values,
axis_id: axisID,
axis_name: axisName,
join_axis: Boolean(axisName)
}
break; break;
} }
case "single": case "single":

View File

@ -999,14 +999,16 @@ export class ComfyApp {
}); });
api.addEventListener("executed", ({ detail }) => { api.addEventListener("executed", ({ detail }) => {
this.nodeOutputs[detail.node] = detail.output; if (detail.batch_num === detail.total_batches) {
if (detail.output != null) { this.nodeOutputs[detail.node] = detail.output;
this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt) if (detail.output != null) {
} this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt)
const node = this.graph.getNodeById(detail.node); }
if (node) { const node = this.graph.getNodeById(detail.node);
if (node.onExecuted) if (node) {
node.onExecuted(detail.output); if (node.onExecuted)
node.onExecuted(detail.output);
}
} }
if (this.batchProgress != null) { if (this.batchProgress != null) {
this.batchProgress.value = detail.batch_num this.batchProgress.value = detail.batch_num