Sort new inputs before prev outputs or order gets messed up in the frontend

This commit is contained in:
space-nuko 2023-06-09 21:29:39 -05:00
parent a8f3d7a872
commit a99d706d48
3 changed files with 48 additions and 19 deletions

View File

@ -12,6 +12,7 @@ import uuid
from typing import List, Dict
import dataclasses
from dataclasses import dataclass
from functools import cmp_to_key
import torch
import nodes
@ -35,6 +36,10 @@ def find(d, pred):
return None, None
def is_combinatorial_graph_input(value):
return isinstance(value, dict) and "combinatorial" in value
def get_input_data_batches(input_data_all):
"""Given input data that can contain combinatorial input values, returns all
the possible batches that can be made by combining the different input
@ -50,14 +55,34 @@ def get_input_data_batches(input_data_all):
inherit_id = True
axis_id = None
# Sort by input name first so the order which batch inputs are applied can
# be easily calculated (node execution order first, then alphabetical input
# name second)
sorted_input_names = sorted(input_data_all.keys())
# Sort so the images can be reassociated on the frontend.
# Primitive inputs before previous outputs from other nodes, then alphanumerically
def sort_order(a, b):
a_value = input_data_all[a]
b_value = input_data_all[b]
if not (is_combinatorial_graph_input(a_value) and is_combinatorial_graph_input(b_value)):
if is_combinatorial_graph_input(a_value):
return 1
elif is_combinatorial_graph_input(b_value):
return -1
else:
return 1 if a > b else -1
if a_value["order"] == b_value["order"]:
return 1 if a > b else -1
return 1 if a_value["order"] > b_value["order"] else -1
sorted_input_names = sorted(input_data_all.keys(), key=cmp_to_key(sort_order))
from pprint import pp
print("SORTED")
pp(sorted_input_names)
for input_name in sorted_input_names:
value = input_data_all[input_name]
if isinstance(value, dict) and "combinatorial" in value:
if is_combinatorial_graph_input(value):
if "axis_id" in value:
input_to_axis[input_name] = {
"axis_id": value["axis_id"],
@ -78,7 +103,7 @@ def get_input_data_batches(input_data_all):
for input_name in sorted_input_names:
value = input_data_all[input_name]
if isinstance(value, dict) and "combinatorial" in value:
if is_combinatorial_graph_input(value):
if "axis_id" in value:
if axis_id is None:
axis_id = value["axis_id"]
@ -110,7 +135,6 @@ def get_input_data_batches(input_data_all):
if not inherit_id or axis_id is None:
axis_id = str(uuid.uuid4())
from pprint import pp
pp(input_to_index)
pp(input_to_values)
pp(index_to_values)
@ -119,6 +143,7 @@ def get_input_data_batches(input_data_all):
combinations = list(itertools.product(*index_to_values))
pp(indices)
pp(combinations)
for i, indices_set in enumerate(indices):
combination = combinations[i]
@ -197,7 +222,8 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
# thus if a combinatorial set of outputs is detected, group
# them under the same axis so each of the outputs are
# updated in pairs/triplets/etc. instead of combinatorially
"axis_id": output_data["axis_id"]
"axis_id": output_data["axis_id"],
"order": output_data["execution_order"]
}
input_data_all[x] = input_values
print("--------------------")
@ -206,7 +232,9 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = {
"combinatorial": True,
"values": input_data["values"],
"axis_id": input_data.get("axis_id")
"axis_id": input_data.get("axis_id"),
"is_output": False,
"order": -1 # inputs go before outputs
}
else:
if required_or_optional:
@ -386,7 +414,7 @@ def format_value(x):
else:
return str(x)
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, exec_order):
unique_id = current_item
inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type']
@ -401,7 +429,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, exec_order + 1)
if result[0] is not True:
# Another node failed further upstream
return result
@ -429,7 +457,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
output_data_from_batches, output_ui_from_batches, output_axis_id = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id)
outputs[unique_id] = {
"batches": output_data_from_batches,
"axis_id": output_axis_id
"axis_id": output_axis_id,
"execution_order": exec_order
}
if any(output_ui_from_batches):
outputs_ui[unique_id] = output_ui_from_batches
@ -650,7 +679,7 @@ class PromptExecutor:
# This call shouldn't raise anything if there's an error deep in
# the actual SD code, instead it will report the node where the
# error was raised
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui)
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, 0)
if success is not True:
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
break

View File

@ -229,10 +229,9 @@ app.registerExtension({
}
let values;
let axisID = null;
let axisName = null;
let axisID = this.id;
let axisName = `${node.id}_${node.type}: ${widget.name}`;
if (this.properties.axisName != "") {
axisID = this.id;
axisName = this.properties.axisName
}
@ -251,7 +250,7 @@ app.registerExtension({
values: values,
axis_id: axisID,
axis_name: axisName,
join_axis: Boolean(axisName)
join_axis: true
}
break;
case "range":
@ -263,7 +262,7 @@ app.registerExtension({
values: values,
axis_id: axisID,
axis_name: axisName,
join_axis: Boolean(axisName)
join_axis: true
}
break;
}

View File

@ -1075,6 +1075,7 @@ export class ComfyApp {
seen.add(nodeID);
const promptInput = runningPrompt.output[nodeID];
const nodeClass = promptInput.class_type
console.warn("TRAVEL", nodeID, promptInput)
// Ensure input keys are sorted alphanumerically
// This is important for the plot to have the same order as
@ -1083,7 +1084,7 @@ export class ComfyApp {
sortedKeys.sort((a, b) => a.localeCompare(b));
// Then reverse the order since we're traversing the graph upstream,
// so execution order comes out backwards
// so application order of the inputs comes out backwards
sortedKeys = sortedKeys.reverse();
for (const inputName of sortedKeys) {