diff --git a/execution.py b/execution.py index b55adb87b..0fdd78f9a 100644 --- a/execution.py +++ b/execution.py @@ -8,6 +8,7 @@ import traceback import gc import time import itertools +import uuid from typing import List, Dict import dataclasses from dataclasses import dataclass @@ -40,52 +41,104 @@ def get_input_data_batches(input_data_all): values together.""" input_to_index = {} + input_to_values = {} index_to_values = [] - index_to_axis = {} + input_to_axis = {} index_to_coords = [] + # Axis ID to inherit + inherit_id = True + axis_id = None + # Sort by input name first so the order which batch inputs are applied can # be easily calculated (node execution order first, then alphabetical input # name second) sorted_input_names = sorted(input_data_all.keys()) - i = 0 for input_name in sorted_input_names: value = input_data_all[input_name] if isinstance(value, dict) and "combinatorial" in value: if "axis_id" in value: - found_i = next((k for k, v in index_to_axis.items() if v == value["axis_id"]), None) - else: - found_i = None + input_to_axis[input_name] = { + "axis_id": value["axis_id"], + "join_axis": value.get("join_axis", False) + } - if found_i is not None: - input_to_index[input_name] = found_i + i = 0 + + def add_index(input_name): + nonlocal i, input_data_all, input_to_index, index_to_coords + value = input_data_all[input_name] + input_to_index[input_name] = i + index_to_values.append(value["values"]) + index_to_coords.append(list(range(len(value["values"])))) + ret = i + i += 1 + return ret + + for input_name in sorted_input_names: + value = input_data_all[input_name] + if isinstance(value, dict) and "combinatorial" in value: + if "axis_id" in value: + if axis_id is None: + axis_id = value["axis_id"] + elif axis_id != value["axis_id"]: + inherit_id = False + + found_name = next((k for k, v in input_to_axis.items() if v["axis_id"] == value["axis_id"]), None) else: - input_to_index[input_name] = i - index_to_values.append(value["values"]) - index_to_coords.append(list(range(len(value["values"])))) - if "axis_id" in value: - index_to_axis[i] = value["axis_id"] - i += 1 + inherit_id = False + found_name = None + + if found_name is not None: + join = input_to_axis[found_name]["join_axis"] + found_i = input_to_index.get(found_name) + if found_i is None: + found_i = add_index(found_name) + input_to_index[input_name] = found_i + if not join: + input_to_values[input_name] = value["values"] + else: + add_index(input_name) if len(index_to_values) == 0: # No combinatorial options. - return CombinatorialBatches([input_data_all], input_to_index, index_to_values, None, None) + return CombinatorialBatches([{ "inputs": input_data_all }], input_to_index, index_to_values, None, None) batches = [] + if not inherit_id or axis_id is None: + axis_id = str(uuid.uuid4()) + + from pprint import pp + pp(input_to_index) + pp(input_to_values) + pp(index_to_values) + indices = list(itertools.product(*index_to_coords)) combinations = list(itertools.product(*index_to_values)) - for combination in combinations: + + pp(indices) + + for i, indices_set in enumerate(indices): + combination = combinations[i] batch = {} for input_name, value in input_data_all.items(): if isinstance(value, dict) and "combinatorial" in value: combination_index = input_to_index[input_name] - batch[input_name] = [combination[combination_index]] + index = indices_set[combination_index] + if input_name in input_to_values: + value = input_to_values[input_name][index] + else: + value = combination[combination_index] + batch[input_name] = [value] else: # already made into a list by get_input_data batch[input_name] = value - batches.append(batch) + batches.append({ + "inputs": batch, + "axis_id": axis_id + }) return CombinatorialBatches(batches, input_to_index, index_to_values, indices, combinations) @@ -103,9 +156,11 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da if input_unique_id not in outputs: return None + output_data = outputs[input_unique_id] + # This is a list of outputs for each batch of combinatorial inputs. # Without any combinatorial inputs, it's a list of length 1. - outputs_for_all_batches = outputs[input_unique_id] + outputs_for_all_batches = output_data["batches"] def flatten(list_of_lists): return list(itertools.chain.from_iterable(list_of_lists)) @@ -114,11 +169,38 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da # Single batch, no combinatorial stuff input_data_all[x] = outputs_for_all_batches[0][output_index] else: + from pprint import pp + print("GETINPUTDATA") + print(x) + print(input_unique_id) # Make the outputs into a list for map-over-list use # (they are themselves lists so flatten them afterwards) input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches] - input_values = { "combinatorial": True, "values": flatten(input_values) } + input_values = { + "combinatorial": True, + "values": flatten(input_values), + + # always treat multiple outputs from a node as belonging to + # the same grid "axis". situation this is supposed to prevent: + # + # LoraLoader outputs both a modified CLIP and MODEL. to + # ensure the outputs are enumerated combinatorially with + # others, they should be marked combinatorial. + # + # however, this does *not* mean the executor should + # enumerate every combination of CLIP and MODEL that can + # possibly be output *from the same node*. as in, the CLIP + # from one set of LoRA weights being combined with the MODEL + # from a different set of weights, as you'd never encounter + # that combination with regular use of the LoraLoader node. + # + # thus if a combinatorial set of outputs is detected, group + # them under the same axis so each of the outputs are + # updated in pairs/triplets/etc. instead of combinatorially + "axis_id": output_data["axis_id"] + } input_data_all[x] = input_values + print("--------------------") elif is_combinatorial_input(input_data): if required_or_optional: input_data_all[x] = { @@ -151,18 +233,27 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da st += f"list[len: {len(v)}][" i = [] for v2 in v: - i.append(v2.__class__.__name__) + if isinstance(v2, (int, float, bool)): + i.append(str(v2)) + else: + i.append(v2.__class__.__name__) st += ",".join(i) + "]" else: - st += str(type(v)) + if isinstance(v, (int, float, bool)): + st += str(v) + else: + st += str(type(v)) s.append(st) return "( " + ", ".join(s) + " )" - print("---------------------------------") from pprint import pp for batch in input_data_all_batches.batches: - print(format_dict(batch)); + print(format_dict(batch["inputs"])) + # pp(input_data_all) + # pp(input_data_all_batches.batches) + print(input_data_all_batches.input_to_index) + # print(input_data_all_batches.index_to_values) print("---------------------------------") return input_data_all_batches @@ -202,11 +293,12 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, callbac def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): all_outputs = [] all_outputs_ui = [] + axis_id = None total_batches = len(input_data_all_batches.batches) total_inner_batches = 0 for batch in input_data_all_batches.batches: - total_inner_batches += max(len(x) for x in batch.values()) + total_inner_batches += max(len(x) for x in batch["inputs"].values()) inner_totals = 0 @@ -226,9 +318,13 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): def cb(inner_num, inner_total): send_batch_progress(inner_num) - return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True, callback=cb) + batch_inputs = batch["inputs"] + return_values = map_node_over_list(obj, batch_inputs, obj.FUNCTION, allow_interrupt=True, callback=cb) - inner_totals += max(len(x) for x in batch.values()) + if axis_id is None and "axis_id" in batch: + axis_id = batch["axis_id"] + + inner_totals += max(len(x) for x in batch_inputs.values()) uis = [] results = [] @@ -280,7 +376,7 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): message["indices"] = input_data_all_batches.indices[batch_num] server.send_sync("executed", message, server.client_id) - return all_outputs, all_outputs_ui + return all_outputs, all_outputs_ui, axis_id def format_value(x): if x is None: @@ -330,8 +426,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute obj = class_def() - output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id) - outputs[unique_id] = output_data_from_batches + output_data_from_batches, output_ui_from_batches, output_axis_id = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id) + outputs[unique_id] = { + "batches": output_data_from_batches, + "axis_id": output_axis_id + } if any(output_ui_from_batches): outputs_ui[unique_id] = output_ui_from_batches elif unique_id in outputs_ui: @@ -356,7 +455,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute if input_data_all_batches is not None: d = {} for batch in input_data_all_batches.batches: - for name, inputs in batch.items(): + for name, inputs in batch["inputs"].items(): d[name] = [format_value(x) for x in inputs] input_data_formatted.append(d) @@ -416,7 +515,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item try: #is_changed = class_def.IS_CHANGED(**input_data_all) for batch in input_data_all_batches.batches: - if map_node_over_list(class_def, batch, "IS_CHANGED"): + if map_node_over_list(class_def, batch["inputs"], "IS_CHANGED"): is_changed = True break prompt[unique_id]['is_changed'] = is_changed @@ -757,7 +856,7 @@ def validate_inputs(prompt, item, validated): input_data_all_batches = get_input_data(inputs, obj_class, unique_id) #ret = obj_class.VALIDATE_INPUTS(**input_data_all) for batch in input_data_all_batches.batches: - ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS") + ret = map_node_over_list(obj_class, batch["inputs"], "VALIDATE_INPUTS") for r in ret: if r != True: details = f"{x}" diff --git a/web/extensions/core/showGrid.js b/web/extensions/core/showGrid.js index ea178e4db..761052c1d 100644 --- a/web/extensions/core/showGrid.js +++ b/web/extensions/core/showGrid.js @@ -23,7 +23,7 @@ app.registerExtension({ nodeType.prototype.onNodeCreated = function () { const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; - this.addWidget("button", "Show Grid", "Show Grid", () => { + this.showGridWidget = this.addWidget("button", "Show Grid", "Show Grid", () => { const grid = app.nodeGrids[this.id]; if (grid == null) { console.warn("No grid to show!"); @@ -282,6 +282,14 @@ app.registerExtension({ document.body.appendChild(this._gridPanel); }) + + this.showGridWidget.disabled = true; + } + + const onExecuted = nodeType.prototype.onExecuted; + nodeType.prototype.onExecuted = function (output) { + const r = onExecuted ? onExecuted.apply(this, arguments) : undefined; + this.showGridWidget.disabled = app.nodeGrids[this.id] == null; } } }) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index cedf7fdd9..6205dd40c 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -246,13 +246,25 @@ app.registerExtension({ else if (inputType === "FLOAT") { values = values.map(v => parseFloat(v)) } - widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName } + widget.value = { + __inputType__: "combinatorial", + values: values, + axis_id: axisID, + axis_name: axisName, + join_axis: Boolean(axisName) + } break; case "range": const isNumberWidget = widget.type === "number" || widget.origType === "number"; if (isNumberWidget) { values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps); - widget.value = { __inputType__: "combinatorial", values: values, axis_id: axisID, axis_name: axisName } + widget.value = { + __inputType__: "combinatorial", + values: values, + axis_id: axisID, + axis_name: axisName, + join_axis: Boolean(axisName) + } break; } case "single": diff --git a/web/scripts/app.js b/web/scripts/app.js index 555fe70c5..7ac4e708f 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -999,14 +999,16 @@ export class ComfyApp { }); api.addEventListener("executed", ({ detail }) => { - this.nodeOutputs[detail.node] = detail.output; - if (detail.output != null) { - this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt) - } - const node = this.graph.getNodeById(detail.node); - if (node) { - if (node.onExecuted) - node.onExecuted(detail.output); + if (detail.batch_num === detail.total_batches) { + this.nodeOutputs[detail.node] = detail.output; + if (detail.output != null) { + this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt) + } + const node = this.graph.getNodeById(detail.node); + if (node) { + if (node.onExecuted) + node.onExecuted(detail.output); + } } if (this.batchProgress != null) { this.batchProgress.value = detail.batch_num