mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 07:10:52 +08:00
Sort new inputs before prev outputs or order gets messed up in the frontend
This commit is contained in:
parent
a8f3d7a872
commit
a99d706d48
55
execution.py
55
execution.py
@ -12,6 +12,7 @@ import uuid
|
||||
from typing import List, Dict
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from functools import cmp_to_key
|
||||
|
||||
import torch
|
||||
import nodes
|
||||
@ -35,6 +36,10 @@ def find(d, pred):
|
||||
return None, None
|
||||
|
||||
|
||||
def is_combinatorial_graph_input(value):
|
||||
return isinstance(value, dict) and "combinatorial" in value
|
||||
|
||||
|
||||
def get_input_data_batches(input_data_all):
|
||||
"""Given input data that can contain combinatorial input values, returns all
|
||||
the possible batches that can be made by combining the different input
|
||||
@ -50,14 +55,34 @@ def get_input_data_batches(input_data_all):
|
||||
inherit_id = True
|
||||
axis_id = None
|
||||
|
||||
# Sort by input name first so the order which batch inputs are applied can
|
||||
# be easily calculated (node execution order first, then alphabetical input
|
||||
# name second)
|
||||
sorted_input_names = sorted(input_data_all.keys())
|
||||
# Sort so the images can be reassociated on the frontend.
|
||||
# Primitive inputs before previous outputs from other nodes, then alphanumerically
|
||||
def sort_order(a, b):
|
||||
a_value = input_data_all[a]
|
||||
b_value = input_data_all[b]
|
||||
|
||||
if not (is_combinatorial_graph_input(a_value) and is_combinatorial_graph_input(b_value)):
|
||||
if is_combinatorial_graph_input(a_value):
|
||||
return 1
|
||||
elif is_combinatorial_graph_input(b_value):
|
||||
return -1
|
||||
else:
|
||||
return 1 if a > b else -1
|
||||
|
||||
if a_value["order"] == b_value["order"]:
|
||||
return 1 if a > b else -1
|
||||
|
||||
return 1 if a_value["order"] > b_value["order"] else -1
|
||||
|
||||
sorted_input_names = sorted(input_data_all.keys(), key=cmp_to_key(sort_order))
|
||||
|
||||
from pprint import pp
|
||||
print("SORTED")
|
||||
pp(sorted_input_names)
|
||||
|
||||
for input_name in sorted_input_names:
|
||||
value = input_data_all[input_name]
|
||||
if isinstance(value, dict) and "combinatorial" in value:
|
||||
if is_combinatorial_graph_input(value):
|
||||
if "axis_id" in value:
|
||||
input_to_axis[input_name] = {
|
||||
"axis_id": value["axis_id"],
|
||||
@ -78,7 +103,7 @@ def get_input_data_batches(input_data_all):
|
||||
|
||||
for input_name in sorted_input_names:
|
||||
value = input_data_all[input_name]
|
||||
if isinstance(value, dict) and "combinatorial" in value:
|
||||
if is_combinatorial_graph_input(value):
|
||||
if "axis_id" in value:
|
||||
if axis_id is None:
|
||||
axis_id = value["axis_id"]
|
||||
@ -110,7 +135,6 @@ def get_input_data_batches(input_data_all):
|
||||
if not inherit_id or axis_id is None:
|
||||
axis_id = str(uuid.uuid4())
|
||||
|
||||
from pprint import pp
|
||||
pp(input_to_index)
|
||||
pp(input_to_values)
|
||||
pp(index_to_values)
|
||||
@ -119,6 +143,7 @@ def get_input_data_batches(input_data_all):
|
||||
combinations = list(itertools.product(*index_to_values))
|
||||
|
||||
pp(indices)
|
||||
pp(combinations)
|
||||
|
||||
for i, indices_set in enumerate(indices):
|
||||
combination = combinations[i]
|
||||
@ -197,7 +222,8 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
# thus if a combinatorial set of outputs is detected, group
|
||||
# them under the same axis so each of the outputs are
|
||||
# updated in pairs/triplets/etc. instead of combinatorially
|
||||
"axis_id": output_data["axis_id"]
|
||||
"axis_id": output_data["axis_id"],
|
||||
"order": output_data["execution_order"]
|
||||
}
|
||||
input_data_all[x] = input_values
|
||||
print("--------------------")
|
||||
@ -206,7 +232,9 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
input_data_all[x] = {
|
||||
"combinatorial": True,
|
||||
"values": input_data["values"],
|
||||
"axis_id": input_data.get("axis_id")
|
||||
"axis_id": input_data.get("axis_id"),
|
||||
"is_output": False,
|
||||
"order": -1 # inputs go before outputs
|
||||
}
|
||||
else:
|
||||
if required_or_optional:
|
||||
@ -386,7 +414,7 @@ def format_value(x):
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui):
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, exec_order):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
class_type = prompt[unique_id]['class_type']
|
||||
@ -401,7 +429,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if input_unique_id not in outputs:
|
||||
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui)
|
||||
result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, exec_order + 1)
|
||||
if result[0] is not True:
|
||||
# Another node failed further upstream
|
||||
return result
|
||||
@ -429,7 +457,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
output_data_from_batches, output_ui_from_batches, output_axis_id = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id)
|
||||
outputs[unique_id] = {
|
||||
"batches": output_data_from_batches,
|
||||
"axis_id": output_axis_id
|
||||
"axis_id": output_axis_id,
|
||||
"execution_order": exec_order
|
||||
}
|
||||
if any(output_ui_from_batches):
|
||||
outputs_ui[unique_id] = output_ui_from_batches
|
||||
@ -650,7 +679,7 @@ class PromptExecutor:
|
||||
# This call shouldn't raise anything if there's an error deep in
|
||||
# the actual SD code, instead it will report the node where the
|
||||
# error was raised
|
||||
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui)
|
||||
success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, 0)
|
||||
if success is not True:
|
||||
self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
|
||||
@ -229,10 +229,9 @@ app.registerExtension({
|
||||
}
|
||||
|
||||
let values;
|
||||
let axisID = null;
|
||||
let axisName = null;
|
||||
let axisID = this.id;
|
||||
let axisName = `${node.id}_${node.type}: ${widget.name}`;
|
||||
if (this.properties.axisName != "") {
|
||||
axisID = this.id;
|
||||
axisName = this.properties.axisName
|
||||
}
|
||||
|
||||
@ -251,7 +250,7 @@ app.registerExtension({
|
||||
values: values,
|
||||
axis_id: axisID,
|
||||
axis_name: axisName,
|
||||
join_axis: Boolean(axisName)
|
||||
join_axis: true
|
||||
}
|
||||
break;
|
||||
case "range":
|
||||
@ -263,7 +262,7 @@ app.registerExtension({
|
||||
values: values,
|
||||
axis_id: axisID,
|
||||
axis_name: axisName,
|
||||
join_axis: Boolean(axisName)
|
||||
join_axis: true
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@ -1075,6 +1075,7 @@ export class ComfyApp {
|
||||
seen.add(nodeID);
|
||||
const promptInput = runningPrompt.output[nodeID];
|
||||
const nodeClass = promptInput.class_type
|
||||
console.warn("TRAVEL", nodeID, promptInput)
|
||||
|
||||
// Ensure input keys are sorted alphanumerically
|
||||
// This is important for the plot to have the same order as
|
||||
@ -1083,7 +1084,7 @@ export class ComfyApp {
|
||||
sortedKeys.sort((a, b) => a.localeCompare(b));
|
||||
|
||||
// Then reverse the order since we're traversing the graph upstream,
|
||||
// so execution order comes out backwards
|
||||
// so application order of the inputs comes out backwards
|
||||
sortedKeys = sortedKeys.reverse();
|
||||
|
||||
for (const inputName of sortedKeys) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user