From a99d706d48b23a90f2c6b45e57ad0b2e81ed488c Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 9 Jun 2023 21:29:39 -0500 Subject: [PATCH] Sort new inputs before prev outputs or order gets messed up in the frontend --- execution.py | 55 ++++++++++++++++++++++------- web/extensions/core/widgetInputs.js | 9 +++-- web/scripts/app.js | 3 +- 3 files changed, 48 insertions(+), 19 deletions(-) diff --git a/execution.py b/execution.py index 0fdd78f9a..9ee90adf4 100644 --- a/execution.py +++ b/execution.py @@ -12,6 +12,7 @@ import uuid from typing import List, Dict import dataclasses from dataclasses import dataclass +from functools import cmp_to_key import torch import nodes @@ -35,6 +36,10 @@ def find(d, pred): return None, None +def is_combinatorial_graph_input(value): + return isinstance(value, dict) and "combinatorial" in value + + def get_input_data_batches(input_data_all): """Given input data that can contain combinatorial input values, returns all the possible batches that can be made by combining the different input @@ -50,14 +55,34 @@ def get_input_data_batches(input_data_all): inherit_id = True axis_id = None - # Sort by input name first so the order which batch inputs are applied can - # be easily calculated (node execution order first, then alphabetical input - # name second) - sorted_input_names = sorted(input_data_all.keys()) + # Sort so the images can be reassociated on the frontend. + # Primitive inputs before previous outputs from other nodes, then alphanumerically + def sort_order(a, b): + a_value = input_data_all[a] + b_value = input_data_all[b] + + if not (is_combinatorial_graph_input(a_value) and is_combinatorial_graph_input(b_value)): + if is_combinatorial_graph_input(a_value): + return 1 + elif is_combinatorial_graph_input(b_value): + return -1 + else: + return 1 if a > b else -1 + + if a_value["order"] == b_value["order"]: + return 1 if a > b else -1 + + return 1 if a_value["order"] > b_value["order"] else -1 + + sorted_input_names = sorted(input_data_all.keys(), key=cmp_to_key(sort_order)) + + from pprint import pp + print("SORTED") + pp(sorted_input_names) for input_name in sorted_input_names: value = input_data_all[input_name] - if isinstance(value, dict) and "combinatorial" in value: + if is_combinatorial_graph_input(value): if "axis_id" in value: input_to_axis[input_name] = { "axis_id": value["axis_id"], @@ -78,7 +103,7 @@ def get_input_data_batches(input_data_all): for input_name in sorted_input_names: value = input_data_all[input_name] - if isinstance(value, dict) and "combinatorial" in value: + if is_combinatorial_graph_input(value): if "axis_id" in value: if axis_id is None: axis_id = value["axis_id"] @@ -110,7 +135,6 @@ def get_input_data_batches(input_data_all): if not inherit_id or axis_id is None: axis_id = str(uuid.uuid4()) - from pprint import pp pp(input_to_index) pp(input_to_values) pp(index_to_values) @@ -119,6 +143,7 @@ def get_input_data_batches(input_data_all): combinations = list(itertools.product(*index_to_values)) pp(indices) + pp(combinations) for i, indices_set in enumerate(indices): combination = combinations[i] @@ -197,7 +222,8 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da # thus if a combinatorial set of outputs is detected, group # them under the same axis so each of the outputs are # updated in pairs/triplets/etc. instead of combinatorially - "axis_id": output_data["axis_id"] + "axis_id": output_data["axis_id"], + "order": output_data["execution_order"] } input_data_all[x] = input_values print("--------------------") @@ -206,7 +232,9 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = { "combinatorial": True, "values": input_data["values"], - "axis_id": input_data.get("axis_id") + "axis_id": input_data.get("axis_id"), + "is_output": False, + "order": -1 # inputs go before outputs } else: if required_or_optional: @@ -386,7 +414,7 @@ def format_value(x): else: return str(x) -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui, exec_order): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] @@ -401,7 +429,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui, exec_order + 1) if result[0] is not True: # Another node failed further upstream return result @@ -429,7 +457,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute output_data_from_batches, output_ui_from_batches, output_axis_id = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id) outputs[unique_id] = { "batches": output_data_from_batches, - "axis_id": output_axis_id + "axis_id": output_axis_id, + "execution_order": exec_order } if any(output_ui_from_batches): outputs_ui[unique_id] = output_ui_from_batches @@ -650,7 +679,7 @@ class PromptExecutor: # This call shouldn't raise anything if there's an error deep in # the actual SD code, instead it will report the node where the # error was raised - success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) + success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, 0) if success is not True: self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) break diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 6205dd40c..ef78f6886 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -229,10 +229,9 @@ app.registerExtension({ } let values; - let axisID = null; - let axisName = null; + let axisID = this.id; + let axisName = `${node.id}_${node.type}: ${widget.name}`; if (this.properties.axisName != "") { - axisID = this.id; axisName = this.properties.axisName } @@ -251,7 +250,7 @@ app.registerExtension({ values: values, axis_id: axisID, axis_name: axisName, - join_axis: Boolean(axisName) + join_axis: true } break; case "range": @@ -263,7 +262,7 @@ app.registerExtension({ values: values, axis_id: axisID, axis_name: axisName, - join_axis: Boolean(axisName) + join_axis: true } break; } diff --git a/web/scripts/app.js b/web/scripts/app.js index 7ac4e708f..d1423b6f4 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1075,6 +1075,7 @@ export class ComfyApp { seen.add(nodeID); const promptInput = runningPrompt.output[nodeID]; const nodeClass = promptInput.class_type + console.warn("TRAVEL", nodeID, promptInput) // Ensure input keys are sorted alphanumerically // This is important for the plot to have the same order as @@ -1083,7 +1084,7 @@ export class ComfyApp { sortedKeys.sort((a, b) => a.localeCompare(b)); // Then reverse the order since we're traversing the graph upstream, - // so execution order comes out backwards + // so application order of the inputs comes out backwards sortedKeys = sortedKeys.reverse(); for (const inputName of sortedKeys) {