mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 05:10:18 +08:00
Works for treating CLIP/model in LoraLoader as one axis but it's messy
This commit is contained in:
parent
5f9cc806f5
commit
a8f3d7a872
163
execution.py
163
execution.py
@ -8,6 +8,7 @@ import traceback
|
|||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
import itertools
|
import itertools
|
||||||
|
import uuid
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -40,52 +41,104 @@ def get_input_data_batches(input_data_all):
|
|||||||
values together."""
|
values together."""
|
||||||
|
|
||||||
input_to_index = {}
|
input_to_index = {}
|
||||||
|
input_to_values = {}
|
||||||
index_to_values = []
|
index_to_values = []
|
||||||
index_to_axis = {}
|
input_to_axis = {}
|
||||||
index_to_coords = []
|
index_to_coords = []
|
||||||
|
|
||||||
|
# Axis ID to inherit
|
||||||
|
inherit_id = True
|
||||||
|
axis_id = None
|
||||||
|
|
||||||
# Sort by input name first so the order which batch inputs are applied can
|
# Sort by input name first so the order which batch inputs are applied can
|
||||||
# be easily calculated (node execution order first, then alphabetical input
|
# be easily calculated (node execution order first, then alphabetical input
|
||||||
# name second)
|
# name second)
|
||||||
sorted_input_names = sorted(input_data_all.keys())
|
sorted_input_names = sorted(input_data_all.keys())
|
||||||
|
|
||||||
i = 0
|
|
||||||
for input_name in sorted_input_names:
|
for input_name in sorted_input_names:
|
||||||
value = input_data_all[input_name]
|
value = input_data_all[input_name]
|
||||||
if isinstance(value, dict) and "combinatorial" in value:
|
if isinstance(value, dict) and "combinatorial" in value:
|
||||||
if "axis_id" in value:
|
if "axis_id" in value:
|
||||||
found_i = next((k for k, v in index_to_axis.items() if v == value["axis_id"]), None)
|
input_to_axis[input_name] = {
|
||||||
else:
|
"axis_id": value["axis_id"],
|
||||||
found_i = None
|
"join_axis": value.get("join_axis", False)
|
||||||
|
}
|
||||||
|
|
||||||
if found_i is not None:
|
i = 0
|
||||||
input_to_index[input_name] = found_i
|
|
||||||
|
def add_index(input_name):
|
||||||
|
nonlocal i, input_data_all, input_to_index, index_to_coords
|
||||||
|
value = input_data_all[input_name]
|
||||||
|
input_to_index[input_name] = i
|
||||||
|
index_to_values.append(value["values"])
|
||||||
|
index_to_coords.append(list(range(len(value["values"]))))
|
||||||
|
ret = i
|
||||||
|
i += 1
|
||||||
|
return ret
|
||||||
|
|
||||||
|
for input_name in sorted_input_names:
|
||||||
|
value = input_data_all[input_name]
|
||||||
|
if isinstance(value, dict) and "combinatorial" in value:
|
||||||
|
if "axis_id" in value:
|
||||||
|
if axis_id is None:
|
||||||
|
axis_id = value["axis_id"]
|
||||||
|
elif axis_id != value["axis_id"]:
|
||||||
|
inherit_id = False
|
||||||
|
|
||||||
|
found_name = next((k for k, v in input_to_axis.items() if v["axis_id"] == value["axis_id"]), None)
|
||||||
else:
|
else:
|
||||||
input_to_index[input_name] = i
|
inherit_id = False
|
||||||
index_to_values.append(value["values"])
|
found_name = None
|
||||||
index_to_coords.append(list(range(len(value["values"]))))
|
|
||||||
if "axis_id" in value:
|
if found_name is not None:
|
||||||
index_to_axis[i] = value["axis_id"]
|
join = input_to_axis[found_name]["join_axis"]
|
||||||
i += 1
|
found_i = input_to_index.get(found_name)
|
||||||
|
if found_i is None:
|
||||||
|
found_i = add_index(found_name)
|
||||||
|
input_to_index[input_name] = found_i
|
||||||
|
if not join:
|
||||||
|
input_to_values[input_name] = value["values"]
|
||||||
|
else:
|
||||||
|
add_index(input_name)
|
||||||
|
|
||||||
if len(index_to_values) == 0:
|
if len(index_to_values) == 0:
|
||||||
# No combinatorial options.
|
# No combinatorial options.
|
||||||
return CombinatorialBatches([input_data_all], input_to_index, index_to_values, None, None)
|
return CombinatorialBatches([{ "inputs": input_data_all }], input_to_index, index_to_values, None, None)
|
||||||
|
|
||||||
batches = []
|
batches = []
|
||||||
|
|
||||||
|
if not inherit_id or axis_id is None:
|
||||||
|
axis_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
from pprint import pp
|
||||||
|
pp(input_to_index)
|
||||||
|
pp(input_to_values)
|
||||||
|
pp(index_to_values)
|
||||||
|
|
||||||
indices = list(itertools.product(*index_to_coords))
|
indices = list(itertools.product(*index_to_coords))
|
||||||
combinations = list(itertools.product(*index_to_values))
|
combinations = list(itertools.product(*index_to_values))
|
||||||
for combination in combinations:
|
|
||||||
|
pp(indices)
|
||||||
|
|
||||||
|
for i, indices_set in enumerate(indices):
|
||||||
|
combination = combinations[i]
|
||||||
batch = {}
|
batch = {}
|
||||||
for input_name, value in input_data_all.items():
|
for input_name, value in input_data_all.items():
|
||||||
if isinstance(value, dict) and "combinatorial" in value:
|
if isinstance(value, dict) and "combinatorial" in value:
|
||||||
combination_index = input_to_index[input_name]
|
combination_index = input_to_index[input_name]
|
||||||
batch[input_name] = [combination[combination_index]]
|
index = indices_set[combination_index]
|
||||||
|
if input_name in input_to_values:
|
||||||
|
value = input_to_values[input_name][index]
|
||||||
|
else:
|
||||||
|
value = combination[combination_index]
|
||||||
|
batch[input_name] = [value]
|
||||||
else:
|
else:
|
||||||
# already made into a list by get_input_data
|
# already made into a list by get_input_data
|
||||||
batch[input_name] = value
|
batch[input_name] = value
|
||||||
batches.append(batch)
|
batches.append({
|
||||||
|
"inputs": batch,
|
||||||
|
"axis_id": axis_id
|
||||||
|
})
|
||||||
|
|
||||||
return CombinatorialBatches(batches, input_to_index, index_to_values, indices, combinations)
|
return CombinatorialBatches(batches, input_to_index, index_to_values, indices, combinations)
|
||||||
|
|
||||||
@ -103,9 +156,11 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
|||||||
if input_unique_id not in outputs:
|
if input_unique_id not in outputs:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
output_data = outputs[input_unique_id]
|
||||||
|
|
||||||
# This is a list of outputs for each batch of combinatorial inputs.
|
# This is a list of outputs for each batch of combinatorial inputs.
|
||||||
# Without any combinatorial inputs, it's a list of length 1.
|
# Without any combinatorial inputs, it's a list of length 1.
|
||||||
outputs_for_all_batches = outputs[input_unique_id]
|
outputs_for_all_batches = output_data["batches"]
|
||||||
|
|
||||||
def flatten(list_of_lists):
|
def flatten(list_of_lists):
|
||||||
return list(itertools.chain.from_iterable(list_of_lists))
|
return list(itertools.chain.from_iterable(list_of_lists))
|
||||||
@ -114,11 +169,38 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
|||||||
# Single batch, no combinatorial stuff
|
# Single batch, no combinatorial stuff
|
||||||
input_data_all[x] = outputs_for_all_batches[0][output_index]
|
input_data_all[x] = outputs_for_all_batches[0][output_index]
|
||||||
else:
|
else:
|
||||||
|
from pprint import pp
|
||||||
|
print("GETINPUTDATA")
|
||||||
|
print(x)
|
||||||
|
print(input_unique_id)
|
||||||
# Make the outputs into a list for map-over-list use
|
# Make the outputs into a list for map-over-list use
|
||||||
# (they are themselves lists so flatten them afterwards)
|
# (they are themselves lists so flatten them afterwards)
|
||||||
input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches]
|
input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches]
|
||||||
input_values = { "combinatorial": True, "values": flatten(input_values) }
|
input_values = {
|
||||||
|
"combinatorial": True,
|
||||||
|
"values": flatten(input_values),
|
||||||
|
|
||||||
|
# always treat multiple outputs from a node as belonging to
|
||||||
|
# the same grid "axis". situation this is supposed to prevent:
|
||||||
|
#
|
||||||
|
# LoraLoader outputs both a modified CLIP and MODEL. to
|
||||||
|
# ensure the outputs are enumerated combinatorially with
|
||||||
|
# others, they should be marked combinatorial.
|
||||||
|
#
|
||||||
|
# however, this does *not* mean the executor should
|
||||||
|
# enumerate every combination of CLIP and MODEL that can
|
||||||
|
# possibly be output *from the same node*. as in, the CLIP
|
||||||
|
# from one set of LoRA weights being combined with the MODEL
|
||||||
|
# from a different set of weights, as you'd never encounter
|
||||||
|
# that combination with regular use of the LoraLoader node.
|
||||||
|
#
|
||||||
|
# thus if a combinatorial set of outputs is detected, group
|
||||||
|
# them under the same axis so each of the outputs are
|
||||||
|
# updated in pairs/triplets/etc. instead of combinatorially
|
||||||
|
"axis_id": output_data["axis_id"]
|
||||||
|
}
|
||||||
input_data_all[x] = input_values
|
input_data_all[x] = input_values
|
||||||
|
print("--------------------")
|
||||||
elif is_combinatorial_input(input_data):
|
elif is_combinatorial_input(input_data):
|
||||||
if required_or_optional:
|
if required_or_optional:
|
||||||
input_data_all[x] = {
|
input_data_all[x] = {
|
||||||
@ -151,18 +233,27 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
|||||||
st += f"list[len: {len(v)}]["
|
st += f"list[len: {len(v)}]["
|
||||||
i = []
|
i = []
|
||||||
for v2 in v:
|
for v2 in v:
|
||||||
i.append(v2.__class__.__name__)
|
if isinstance(v2, (int, float, bool)):
|
||||||
|
i.append(str(v2))
|
||||||
|
else:
|
||||||
|
i.append(v2.__class__.__name__)
|
||||||
st += ",".join(i) + "]"
|
st += ",".join(i) + "]"
|
||||||
else:
|
else:
|
||||||
st += str(type(v))
|
if isinstance(v, (int, float, bool)):
|
||||||
|
st += str(v)
|
||||||
|
else:
|
||||||
|
st += str(type(v))
|
||||||
s.append(st)
|
s.append(st)
|
||||||
return "( " + ", ".join(s) + " )"
|
return "( " + ", ".join(s) + " )"
|
||||||
|
|
||||||
|
|
||||||
print("---------------------------------")
|
print("---------------------------------")
|
||||||
from pprint import pp
|
from pprint import pp
|
||||||
for batch in input_data_all_batches.batches:
|
for batch in input_data_all_batches.batches:
|
||||||
print(format_dict(batch));
|
print(format_dict(batch["inputs"]))
|
||||||
|
# pp(input_data_all)
|
||||||
|
# pp(input_data_all_batches.batches)
|
||||||
|
print(input_data_all_batches.input_to_index)
|
||||||
|
# print(input_data_all_batches.index_to_values)
|
||||||
print("---------------------------------")
|
print("---------------------------------")
|
||||||
|
|
||||||
return input_data_all_batches
|
return input_data_all_batches
|
||||||
@ -202,11 +293,12 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, callbac
|
|||||||
def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
|
def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
|
||||||
all_outputs = []
|
all_outputs = []
|
||||||
all_outputs_ui = []
|
all_outputs_ui = []
|
||||||
|
axis_id = None
|
||||||
total_batches = len(input_data_all_batches.batches)
|
total_batches = len(input_data_all_batches.batches)
|
||||||
|
|
||||||
total_inner_batches = 0
|
total_inner_batches = 0
|
||||||
for batch in input_data_all_batches.batches:
|
for batch in input_data_all_batches.batches:
|
||||||
total_inner_batches += max(len(x) for x in batch.values())
|
total_inner_batches += max(len(x) for x in batch["inputs"].values())
|
||||||
|
|
||||||
inner_totals = 0
|
inner_totals = 0
|
||||||
|
|
||||||
@ -226,9 +318,13 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
|
|||||||
def cb(inner_num, inner_total):
|
def cb(inner_num, inner_total):
|
||||||
send_batch_progress(inner_num)
|
send_batch_progress(inner_num)
|
||||||
|
|
||||||
return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb)
|
batch_inputs = batch["inputs"]
|
||||||
|
return_values = map_node_over_list(obj, batch_inputs, obj.FUNCTION, allow_interrupt=True, callback=cb)
|
||||||
|
|
||||||
inner_totals += max(len(x) for x in batch.values())
|
if axis_id is None and "axis_id" in batch:
|
||||||
|
axis_id = batch["axis_id"]
|
||||||
|
|
||||||
|
inner_totals += max(len(x) for x in batch_inputs.values())
|
||||||
|
|
||||||
uis = []
|
uis = []
|
||||||
results = []
|
results = []
|
||||||
@ -280,7 +376,7 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id):
|
|||||||
message["indices"] = input_data_all_batches.indices[batch_num]
|
message["indices"] = input_data_all_batches.indices[batch_num]
|
||||||
server.send_sync("executed", message, server.client_id)
|
server.send_sync("executed", message, server.client_id)
|
||||||
|
|
||||||
return all_outputs, all_outputs_ui
|
return all_outputs, all_outputs_ui, axis_id
|
||||||
|
|
||||||
def format_value(x):
|
def format_value(x):
|
||||||
if x is None:
|
if x is None:
|
||||||
@ -330,8 +426,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
|||||||
|
|
||||||
obj = class_def()
|
obj = class_def()
|
||||||
|
|
||||||
output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id)
|
output_data_from_batches, output_ui_from_batches, output_axis_id = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id)
|
||||||
outputs[unique_id] = output_data_from_batches
|
outputs[unique_id] = {
|
||||||
|
"batches": output_data_from_batches,
|
||||||
|
"axis_id": output_axis_id
|
||||||
|
}
|
||||||
if any(output_ui_from_batches):
|
if any(output_ui_from_batches):
|
||||||
outputs_ui[unique_id] = output_ui_from_batches
|
outputs_ui[unique_id] = output_ui_from_batches
|
||||||
elif unique_id in outputs_ui:
|
elif unique_id in outputs_ui:
|
||||||
@ -356,7 +455,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
|||||||
if input_data_all_batches is not None:
|
if input_data_all_batches is not None:
|
||||||
d = {}
|
d = {}
|
||||||
for batch in input_data_all_batches.batches:
|
for batch in input_data_all_batches.batches:
|
||||||
for name, inputs in batch.items():
|
for name, inputs in batch["inputs"].items():
|
||||||
d[name] = [format_value(x) for x in inputs]
|
d[name] = [format_value(x) for x in inputs]
|
||||||
input_data_formatted.append(d)
|
input_data_formatted.append(d)
|
||||||
|
|
||||||
@ -416,7 +515,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
|||||||
try:
|
try:
|
||||||
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||||
for batch in input_data_all_batches.batches:
|
for batch in input_data_all_batches.batches:
|
||||||
if map_node_over_list(class_def, batch, "IS_CHANGED"):
|
if map_node_over_list(class_def, batch["inputs"], "IS_CHANGED"):
|
||||||
is_changed = True
|
is_changed = True
|
||||||
break
|
break
|
||||||
prompt[unique_id]['is_changed'] = is_changed
|
prompt[unique_id]['is_changed'] = is_changed
|
||||||
@ -757,7 +856,7 @@ def validate_inputs(prompt, item, validated):
|
|||||||
input_data_all_batches = get_input_data(inputs, obj_class, unique_id)
|
input_data_all_batches = get_input_data(inputs, obj_class, unique_id)
|
||||||
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||||
for batch in input_data_all_batches.batches:
|
for batch in input_data_all_batches.batches:
|
||||||
ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS")
|
ret = map_node_over_list(obj_class, batch["inputs"], "VALIDATE_INPUTS")
|
||||||
for r in ret:
|
for r in ret:
|
||||||
if r != True:
|
if r != True:
|
||||||
details = f"{x}"
|
details = f"{x}"
|
||||||
|
|||||||
@ -23,7 +23,7 @@ app.registerExtension({
|
|||||||
nodeType.prototype.onNodeCreated = function () {
|
nodeType.prototype.onNodeCreated = function () {
|
||||||
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;
|
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;
|
||||||
|
|
||||||
this.addWidget("button", "Show Grid", "Show Grid", () => {
|
this.showGridWidget = this.addWidget("button", "Show Grid", "Show Grid", () => {
|
||||||
const grid = app.nodeGrids[this.id];
|
const grid = app.nodeGrids[this.id];
|
||||||
if (grid == null) {
|
if (grid == null) {
|
||||||
console.warn("No grid to show!");
|
console.warn("No grid to show!");
|
||||||
@ -282,6 +282,14 @@ app.registerExtension({
|
|||||||
|
|
||||||
document.body.appendChild(this._gridPanel);
|
document.body.appendChild(this._gridPanel);
|
||||||
})
|
})
|
||||||
|
|
||||||
|
this.showGridWidget.disabled = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const onExecuted = nodeType.prototype.onExecuted;
|
||||||
|
nodeType.prototype.onExecuted = function (output) {
|
||||||
|
const r = onExecuted ? onExecuted.apply(this, arguments) : undefined;
|
||||||
|
this.showGridWidget.disabled = app.nodeGrids[this.id] == null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@ -246,13 +246,25 @@ app.registerExtension({
|
|||||||
else if (inputType === "FLOAT") {
|
else if (inputType === "FLOAT") {
|
||||||
values = values.map(v => parseFloat(v))
|
values = values.map(v => parseFloat(v))
|
||||||
}
|
}
|
||||||
widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName }
|
widget.value = {
|
||||||
|
__inputType__: "combinatorial",
|
||||||
|
values: values,
|
||||||
|
axis_id: axisID,
|
||||||
|
axis_name: axisName,
|
||||||
|
join_axis: Boolean(axisName)
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case "range":
|
case "range":
|
||||||
const isNumberWidget = widget.type === "number" || widget.origType === "number";
|
const isNumberWidget = widget.type === "number" || widget.origType === "number";
|
||||||
if (isNumberWidget) {
|
if (isNumberWidget) {
|
||||||
values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps);
|
values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps);
|
||||||
widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName }
|
widget.value = {
|
||||||
|
__inputType__: "combinatorial",
|
||||||
|
values: values,
|
||||||
|
axis_id: axisID,
|
||||||
|
axis_name: axisName,
|
||||||
|
join_axis: Boolean(axisName)
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case "single":
|
case "single":
|
||||||
|
|||||||
@ -999,14 +999,16 @@ export class ComfyApp {
|
|||||||
});
|
});
|
||||||
|
|
||||||
api.addEventListener("executed", ({ detail }) => {
|
api.addEventListener("executed", ({ detail }) => {
|
||||||
this.nodeOutputs[detail.node] = detail.output;
|
if (detail.batch_num === detail.total_batches) {
|
||||||
if (detail.output != null) {
|
this.nodeOutputs[detail.node] = detail.output;
|
||||||
this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt)
|
if (detail.output != null) {
|
||||||
}
|
this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt)
|
||||||
const node = this.graph.getNodeById(detail.node);
|
}
|
||||||
if (node) {
|
const node = this.graph.getNodeById(detail.node);
|
||||||
if (node.onExecuted)
|
if (node) {
|
||||||
node.onExecuted(detail.output);
|
if (node.onExecuted)
|
||||||
|
node.onExecuted(detail.output);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (this.batchProgress != null) {
|
if (this.batchProgress != null) {
|
||||||
this.batchProgress.value = detail.batch_num
|
this.batchProgress.value = detail.batch_num
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user