Sort new inputs before prev outputs or order gets messed up in the frontend

This commit is contained in:
space-nuko 2023-06-09 21:29:39 -05:00
parent a8f3d7a872
commit a99d706d48
3 changed files with 48 additions and 19 deletions

View File

@ -12,6 +12,7 @@ import uuid
from typing import List, Dict from typing import List, Dict
import dataclasses import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
from functools import cmp_to_key
import torch import torch
import nodes import nodes
@ -35,6 +36,10 @@ def find(d, pred):
return None, None return None, None
def is_combinatorial_graph_input(value):
return isinstance(value, dict) and "combinatorial" in value
def get_input_data_batches(input_data_all): def get_input_data_batches(input_data_all):
"""Given input data that can contain combinatorial input values, returns all """Given input data that can contain combinatorial input values, returns all
the possible batches that can be made by combining the different input the possible batches that can be made by combining the different input
@ -50,14 +55,34 @@ def get_input_data_batches(input_data_all):
inherit_id = True inherit_id = True
axis_id = None axis_id = None
# Sort by input name first so the order which batch inputs are applied can # Sort so the images can be reassociated on the frontend.
# be easily calculated (node execution order first, then alphabetical input # Primitive inputs before previous outputs from other nodes, then alphanumerically
# name second) def sort_order(a, b):
sorted_input_names = sorted(input_data_all.keys()) a_value = input_data_all[a]
b_value = input_data_all[b]
if not (is_combinatorial_graph_input(a_value) and is_combinatorial_graph_input(b_value)):
if is_combinatorial_graph_input(a_value):
return 1
elif is_combinatorial_graph_input(b_value):
return -1
else:
return 1 if a > b else -1
if a_value["order"] == b_value["order"]:
return 1 if a > b else -1
return 1 if a_value["order"] > b_value["order"] else -1
sorted_input_names = sorted(input_data_all.keys(), key=cmp_to_key(sort_order))
from pprint import pp
print("SORTED")
pp(sorted_input_names)
for input_name in sorted_input_names: for input_name in sorted_input_names:
value = input_data_all[input_name] value = input_data_all[input_name]
if isinstance(value, dict) and "combinatorial" in value: if is_combinatorial_graph_input(value):
if "axis_id" in value: if "axis_id" in value:
input_to_axis[input_name] = { input_to_axis[input_name] = {
"axis_id": value["axis_id"], "axis_id": value["axis_id"],
@ -78,7 +103,7 @@ def get_input_data_batches(input_data_all):
for input_name in sorted_input_names: for input_name in sorted_input_names:
value = input_data_all[input_name] value = input_data_all[input_name]
if isinstance(value, dict) and "combinatorial" in value: if is_combinatorial_graph_input(value):
if "axis_id" in value: if "axis_id" in value:
if axis_id is None: if axis_id is None:
axis_id = value["axis_id"] axis_id = value["axis_id"]
@ -110,7 +135,6 @@ def get_input_data_batches(input_data_all):
if not inherit_id or axis_id is None: if not inherit_id or axis_id is None:
axis_id = str(uuid.uuid4()) axis_id = str(uuid.uuid4())
from pprint import pp
pp(input_to_index) pp(input_to_index)
pp(input_to_values) pp(input_to_values)
pp(index_to_values) pp(index_to_values)
@ -119,6 +143,7 @@ def get_input_data_batches(input_data_all):
combinations = list(itertools.product(*index_to_values)) combinations = list(itertools.product(*index_to_values))
pp(indices) pp(indices)
pp(combinations)
for i, indices_set in enumerate(indices): for i, indices_set in enumerate(indices):
combination = combinations[i] combination = combinations[i]
@ -197,7 +222,8 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
# thus if a combinatorial set of outputs is detected, group # thus if a combinatorial set of outputs is detected, group
# them under the same axis so each of the outputs are # them under the same axis so each of the outputs are
# updated in pairs/triplets/etc. instead of combinatorially # updated in pairs/triplets/etc. instead of combinatorially
"axis_id": output_data["axis_id"] "axis_id": output_data["axis_id"],
"order": output_data["execution_order"]
} }
input_data_all[x] = input_values input_data_all[x] = input_values
print("--------------------") print("--------------------")
@ -206,7 +232,9 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = { input_data_all[x] = {
"combinatorial": True, "combinatorial": True,
"values": input_data["values"], "values": input_data["values"],
"axis_id": input_data.get("axis_id") "axis_id": input_data.get("axis_id"),
"is_output": False,
"order": -1 # inputs go before outputs
} }
else: else:
if required_or_optional: if required_or_optional:
@ -386,7 +414,7 @@ def format_value(x):
else: else:
return str(x) return str(x)
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, exec_order):
unique_id = current_item unique_id = current_item
inputs = prompt[unique_id]['inputs'] inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type'] class_type = prompt[unique_id]['class_type']
@ -401,7 +429,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_unique_id = input_data[0] input_unique_id = input_data[0]
output_index = input_data[1] output_index = input_data[1]
if input_unique_id not in outputs: if input_unique_id not in outputs:
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, exec_order + 1)
if result[0] is not True: if result[0] is not True:
# Another node failed further upstream # Another node failed further upstream
return result return result
@ -429,7 +457,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
output_data_from_batches, output_ui_from_batches, output_axis_id = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id) output_data_from_batches, output_ui_from_batches, output_axis_id = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id)
outputs[unique_id] = { outputs[unique_id] = {
"batches": output_data_from_batches, "batches": output_data_from_batches,
"axis_id": output_axis_id "axis_id": output_axis_id,
"execution_order": exec_order
} }
if any(output_ui_from_batches): if any(output_ui_from_batches):
outputs_ui[unique_id] = output_ui_from_batches outputs_ui[unique_id] = output_ui_from_batches
@ -650,7 +679,7 @@ class PromptExecutor:
# This call shouldn't raise anything if there's an error deep in # This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the # the actual SD code, instead it will report the node where the
# error was raised # error was raised
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, 0)
if success is not True: if success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
break break

View File

@ -229,10 +229,9 @@ app.registerExtension({
} }
let values; let values;
let axisID = null; let axisID = this.id;
let axisName = null; let axisName = `${node.id}_${node.type}: ${widget.name}`;
if (this.properties.axisName != "") { if (this.properties.axisName != "") {
axisID = this.id;
axisName = this.properties.axisName axisName = this.properties.axisName
} }
@ -251,7 +250,7 @@ app.registerExtension({
values: values, values: values,
axis_id: axisID, axis_id: axisID,
axis_name: axisName, axis_name: axisName,
join_axis: Boolean(axisName) join_axis: true
} }
break; break;
case "range": case "range":
@ -263,7 +262,7 @@ app.registerExtension({
values: values, values: values,
axis_id: axisID, axis_id: axisID,
axis_name: axisName, axis_name: axisName,
join_axis: Boolean(axisName) join_axis: true
} }
break; break;
} }

View File

@ -1075,6 +1075,7 @@ export class ComfyApp {
seen.add(nodeID); seen.add(nodeID);
const promptInput = runningPrompt.output[nodeID]; const promptInput = runningPrompt.output[nodeID];
const nodeClass = promptInput.class_type const nodeClass = promptInput.class_type
console.warn("TRAVEL", nodeID, promptInput)
// Ensure input keys are sorted alphanumerically // Ensure input keys are sorted alphanumerically
// This is important for the plot to have the same order as // This is important for the plot to have the same order as
@ -1083,7 +1084,7 @@ export class ComfyApp {
sortedKeys.sort((a, b) => a.localeCompare(b)); sortedKeys.sort((a, b) => a.localeCompare(b));
// Then reverse the order since we're traversing the graph upstream, // Then reverse the order since we're traversing the graph upstream,
// so execution order comes out backwards // so application order of the inputs comes out backwards
sortedKeys = sortedKeys.reverse(); sortedKeys = sortedKeys.reverse();
for (const inputName of sortedKeys) { for (const inputName of sortedKeys) {