mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 07:10:52 +08:00
Works for treating CLIP/model in LoraLoader as one axis but it's messy
This commit is contained in:
parent
5f9cc806f5
commit
a8f3d7a872
163
execution.py
163
execution.py
@ -8,6 +8,7 @@ import traceback
|
||||
import gc
|
||||
import time
|
||||
import itertools
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
@ -40,52 +41,104 @@ def get_input_data_batches(input_data_all):
|
||||
values together."""
|
||||
|
||||
input_to_index = {}
|
||||
input_to_values = {}
|
||||
index_to_values = []
|
||||
index_to_axis = {}
|
||||
input_to_axis = {}
|
||||
index_to_coords = []
|
||||
|
||||
# Axis ID to inherit
|
||||
inherit_id = True
|
||||
axis_id = None
|
||||
|
||||
# Sort by input name first so the order which batch inputs are applied can
|
||||
# be easily calculated (node execution order first, then alphabetical input
|
||||
# name second)
|
||||
sorted_input_names = sorted(input_data_all.keys())
|
||||
|
||||
i = 0
|
||||
for input_name in sorted_input_names:
|
||||
value = input_data_all[input_name]
|
||||
if isinstance(value, dict) and "combinatorial" in value:
|
||||
if "axis_id" in value:
|
||||
found_i = next((k for k, v in index_to_axis.items() if v == value["axis_id"]), None)
|
||||
else:
|
||||
found_i = None
|
||||
input_to_axis[input_name] = {
|
||||
"axis_id": value["axis_id"],
|
||||
"join_axis": value.get("join_axis", False)
|
||||
}
|
||||
|
||||
if found_i is not None:
|
||||
input_to_index[input_name] = found_i
|
||||
i = 0
|
||||
|
||||
def add_index(input_name):
|
||||
nonlocal i, input_data_all, input_to_index, index_to_coords
|
||||
value = input_data_all[input_name]
|
||||
input_to_index[input_name] = i
|
||||
index_to_values.append(value["values"])
|
||||
index_to_coords.append(list(range(len(value["values"]))))
|
||||
ret = i
|
||||
i += 1
|
||||
return ret
|
||||
|
||||
for input_name in sorted_input_names:
|
||||
value = input_data_all[input_name]
|
||||
if isinstance(value, dict) and "combinatorial" in value:
|
||||
if "axis_id" in value:
|
||||
if axis_id is None:
|
||||
axis_id = value["axis_id"]
|
||||
elif axis_id != value["axis_id"]:
|
||||
inherit_id = False
|
||||
|
||||
found_name = next((k for k, v in input_to_axis.items() if v["axis_id"] == value["axis_id"]), None)
|
||||
else:
|
||||
input_to_index[input_name] = i
|
||||
index_to_values.append(value["values"])
|
||||
index_to_coords.append(list(range(len(value["values"]))))
|
||||
if "axis_id" in value:
|
||||
index_to_axis[i] = value["axis_id"]
|
||||
i += 1
|
||||
inherit_id = False
|
||||
found_name = None
|
||||
|
||||
if found_name is not None:
|
||||
join = input_to_axis[found_name]["join_axis"]
|
||||
found_i = input_to_index.get(found_name)
|
||||
if found_i is None:
|
||||
found_i = add_index(found_name)
|
||||
input_to_index[input_name] = found_i
|
||||
if not join:
|
||||
input_to_values[input_name] = value["values"]
|
||||
else:
|
||||
add_index(input_name)
|
||||
|
||||
if len(index_to_values) == 0:
|
||||
# No combinatorial options.
|
||||
return CombinatorialBatches([input_data_all], input_to_index, index_to_values, None, None)
|
||||
return CombinatorialBatches([{ "inputs": input_data_all }], input_to_index, index_to_values, None, None)
|
||||
|
||||
batches = []
|
||||
|
||||
if not inherit_id or axis_id is None:
|
||||
axis_id = str(uuid.uuid4())
|
||||
|
||||
from pprint import pp
|
||||
pp(input_to_index)
|
||||
pp(input_to_values)
|
||||
pp(index_to_values)
|
||||
|
||||
indices = list(itertools.product(*index_to_coords))
|
||||
combinations = list(itertools.product(*index_to_values))
|
||||
for combination in combinations:
|
||||
|
||||
pp(indices)
|
||||
|
||||
for i, indices_set in enumerate(indices):
|
||||
combination = combinations[i]
|
||||
batch = {}
|
||||
for input_name, value in input_data_all.items():
|
||||
if isinstance(value, dict) and "combinatorial" in value:
|
||||
combination_index = input_to_index[input_name]
|
||||
batch[input_name] = [combination[combination_index]]
|
||||
index = indices_set[combination_index]
|
||||
if input_name in input_to_values:
|
||||
value = input_to_values[input_name][index]
|
||||
else:
|
||||
value = combination[combination_index]
|
||||
batch[input_name] = [value]
|
||||
else:
|
||||
# already made into a list by get_input_data
|
||||
batch[input_name] = value
|
||||
batches.append(batch)
|
||||
batches.append({
|
||||
"inputs": batch,
|
||||
"axis_id": axis_id
|
||||
})
|
||||
|
||||
return CombinatorialBatches(batches, input_to_index, index_to_values, indices, combinations)
|
||||
|
||||
@ -103,9 +156,11 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
if input_unique_id not in outputs:
|
||||
return None
|
||||
|
||||
output_data = outputs[input_unique_id]
|
||||
|
||||
# This is a list of outputs for each batch of combinatorial inputs.
|
||||
# Without any combinatorial inputs, it's a list of length 1.
|
||||
outputs_for_all_batches = outputs[input_unique_id]
|
||||
outputs_for_all_batches = output_data["batches"]
|
||||
|
||||
def flatten(list_of_lists):
|
||||
return list(itertools.chain.from_iterable(list_of_lists))
|
||||
@ -114,11 +169,38 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
# Single batch, no combinatorial stuff
|
||||
input_data_all[x] = outputs_for_all_batches[0][output_index]
|
||||
else:
|
||||
from pprint import pp
|
||||
print("GETINPUTDATA")
|
||||
print(x)
|
||||
print(input_unique_id)
|
||||
# Make the outputs into a list for map-over-list use
|
||||
# (they are themselves lists so flatten them afterwards)
|
||||
input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches]
|
||||
input_values = { "combinatorial": True, "values": flatten(input_values) }
|
||||
input_values = {
|
||||
"combinatorial": True,
|
||||
"values": flatten(input_values),
|
||||
|
||||
# always treat multiple outputs from a node as belonging to
|
||||
# the same grid "axis". situation this is supposed to prevent:
|
||||
#
|
||||
# LoraLoader outputs both a modified CLIP and MODEL. to
|
||||
# ensure the outputs are enumerated combinatorially with
|
||||
# others, they should be marked combinatorial.
|
||||
#
|
||||
# however, this does *not* mean the executor should
|
||||
# enumerate every combination of CLIP and MODEL that can
|
||||
# possibly be output *from the same node*. as in, the CLIP
|
||||
# from one set of LoRA weights being combined with the MODEL
|
||||
# from a different set of weights, as you'd never encounter
|
||||
# that combination with regular use of the LoraLoader node.
|
||||
#
|
||||
# thus if a combinatorial set of outputs is detected, group
|
||||
# them under the same axis so each of the outputs are
|
||||
# updated in pairs/triplets/etc. instead of combinatorially
|
||||
"axis_id": output_data["axis_id"]
|
||||
}
|
||||
input_data_all[x] = input_values
|
||||
print("--------------------")
|
||||
elif is_combinatorial_input(input_data):
|
||||
if required_or_optional:
|
||||
input_data_all[x] = {
|
||||
@ -151,18 +233,27 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
st += f"list[len: {len(v)}]["
|
||||
i = []
|
||||
for v2 in v:
|
||||
i.append(v2.__class__.__name__)
|
||||
if isinstance(v2, (int, float, bool)):
|
||||
i.append(str(v2))
|
||||
else:
|
||||
i.append(v2.__class__.__name__)
|
||||
st += ",".join(i) + "]"
|
||||
else:
|
||||
st += str(type(v))
|
||||
if isinstance(v, (int, float, bool)):
|
||||
st += str(v)
|
||||
else:
|
||||
st += str(type(v))
|
||||
s.append(st)
|
||||
return "( " + ", ".join(s) + " )"
|
||||
|
||||
|
||||
print("---------------------------------")
|
||||
from pprint import pp
|
||||
for batch in input_data_all_batches.batches:
|
||||
print(format_dict(batch));
|
||||
print(format_dict(batch["inputs"]))
|
||||
# pp(input_data_all)
|
||||
# pp(input_data_all_batches.batches)
|
||||
print(input_data_all_batches.input_to_index)
|
||||
# print(input_data_all_batches.index_to_values)
|
||||
print("---------------------------------")
|
||||
|
||||
return input_data_all_batches
|
||||
@ -202,11 +293,12 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, callbac
|
||||
def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
|
||||
all_outputs = []
|
||||
all_outputs_ui = []
|
||||
axis_id = None
|
||||
total_batches = len(input_data_all_batches.batches)
|
||||
|
||||
total_inner_batches = 0
|
||||
for batch in input_data_all_batches.batches:
|
||||
total_inner_batches += max(len(x) for x in batch.values())
|
||||
total_inner_batches += max(len(x) for x in batch["inputs"].values())
|
||||
|
||||
inner_totals = 0
|
||||
|
||||
@ -226,9 +318,13 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
|
||||
def cb(inner_num, inner_total):
|
||||
send_batch_progress(inner_num)
|
||||
|
||||
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb)
|
||||
batch_inputs = batch["inputs"]
|
||||
return_values = map_node_over_list(obj, batch_inputs, obj.FUNCTION, allow_interrupt=True, callback=cb)
|
||||
|
||||
inner_totals += max(len(x) for x in batch.values())
|
||||
if axis_id is None and "axis_id" in batch:
|
||||
axis_id = batch["axis_id"]
|
||||
|
||||
inner_totals += max(len(x) for x in batch_inputs.values())
|
||||
|
||||
uis = []
|
||||
results = []
|
||||
@ -280,7 +376,7 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
|
||||
message["indices"] = input_data_all_batches.indices[batch_num]
|
||||
server.send_sync("executed", message, server.client_id)
|
||||
|
||||
return all_outputs, all_outputs_ui
|
||||
return all_outputs, all_outputs_ui, axis_id
|
||||
|
||||
def format_value(x):
|
||||
if x is None:
|
||||
@ -330,8 +426,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
|
||||
obj = class_def()
|
||||
|
||||
output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id)
|
||||
outputs[unique_id] = output_data_from_batches
|
||||
output_data_from_batches, output_ui_from_batches, output_axis_id = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id)
|
||||
outputs[unique_id] = {
|
||||
"batches": output_data_from_batches,
|
||||
"axis_id": output_axis_id
|
||||
}
|
||||
if any(output_ui_from_batches):
|
||||
outputs_ui[unique_id] = output_ui_from_batches
|
||||
elif unique_id in outputs_ui:
|
||||
@ -356,7 +455,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
if input_data_all_batches is not None:
|
||||
d = {}
|
||||
for batch in input_data_all_batches.batches:
|
||||
for name, inputs in batch.items():
|
||||
for name, inputs in batch["inputs"].items():
|
||||
d[name] = [format_value(x) for x in inputs]
|
||||
input_data_formatted.append(d)
|
||||
|
||||
@ -416,7 +515,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
||||
try:
|
||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||
for batch in input_data_all_batches.batches:
|
||||
if map_node_over_list(class_def, batch, "IS_CHANGED"):
|
||||
if map_node_over_list(class_def, batch["inputs"], "IS_CHANGED"):
|
||||
is_changed = True
|
||||
break
|
||||
prompt[unique_id]['is_changed'] = is_changed
|
||||
@ -757,7 +856,7 @@ def validate_inputs(prompt, item, validated):
|
||||
input_data_all_batches = get_input_data(inputs, obj_class, unique_id)
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
for batch in input_data_all_batches.batches:
|
||||
ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS")
|
||||
ret = map_node_over_list(obj_class, batch["inputs"], "VALIDATE_INPUTS")
|
||||
for r in ret:
|
||||
if r != True:
|
||||
details = f"{x}"
|
||||
|
||||
@ -23,7 +23,7 @@ app.registerExtension({
|
||||
nodeType.prototype.onNodeCreated = function () {
|
||||
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;
|
||||
|
||||
this.addWidget("button", "Show Grid", "Show Grid", () => {
|
||||
this.showGridWidget = this.addWidget("button", "Show Grid", "Show Grid", () => {
|
||||
const grid = app.nodeGrids[this.id];
|
||||
if (grid == null) {
|
||||
console.warn("No grid to show!");
|
||||
@ -282,6 +282,14 @@ app.registerExtension({
|
||||
|
||||
document.body.appendChild(this._gridPanel);
|
||||
})
|
||||
|
||||
this.showGridWidget.disabled = true;
|
||||
}
|
||||
|
||||
const onExecuted = nodeType.prototype.onExecuted;
|
||||
nodeType.prototype.onExecuted = function (output) {
|
||||
const r = onExecuted ? onExecuted.apply(this, arguments) : undefined;
|
||||
this.showGridWidget.disabled = app.nodeGrids[this.id] == null;
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@ -246,13 +246,25 @@ app.registerExtension({
|
||||
else if (inputType === "FLOAT") {
|
||||
values = values.map(v => parseFloat(v))
|
||||
}
|
||||
widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName }
|
||||
widget.value = {
|
||||
__inputType__: "combinatorial",
|
||||
values: values,
|
||||
axis_id: axisID,
|
||||
axis_name: axisName,
|
||||
join_axis: Boolean(axisName)
|
||||
}
|
||||
break;
|
||||
case "range":
|
||||
const isNumberWidget = widget.type === "number" || widget.origType === "number";
|
||||
if (isNumberWidget) {
|
||||
values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps);
|
||||
widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName }
|
||||
widget.value = {
|
||||
__inputType__: "combinatorial",
|
||||
values: values,
|
||||
axis_id: axisID,
|
||||
axis_name: axisName,
|
||||
join_axis: Boolean(axisName)
|
||||
}
|
||||
break;
|
||||
}
|
||||
case "single":
|
||||
|
||||
@ -999,14 +999,16 @@ export class ComfyApp {
|
||||
});
|
||||
|
||||
api.addEventListener("executed", ({ detail }) => {
|
||||
this.nodeOutputs[detail.node] = detail.output;
|
||||
if (detail.output != null) {
|
||||
this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt)
|
||||
}
|
||||
const node = this.graph.getNodeById(detail.node);
|
||||
if (node) {
|
||||
if (node.onExecuted)
|
||||
node.onExecuted(detail.output);
|
||||
if (detail.batch_num === detail.total_batches) {
|
||||
this.nodeOutputs[detail.node] = detail.output;
|
||||
if (detail.output != null) {
|
||||
this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt)
|
||||
}
|
||||
const node = this.graph.getNodeById(detail.node);
|
||||
if (node) {
|
||||
if (node.onExecuted)
|
||||
node.onExecuted(detail.output);
|
||||
}
|
||||
}
|
||||
if (this.batchProgress != null) {
|
||||
this.batchProgress.value = detail.batch_num
|
||||
|
||||
Loading…
Reference in New Issue
Block a user