diff --git a/execution.py b/execution.py index b55adb87b..9cefe7959 100644 --- a/execution.py +++ b/execution.py @@ -87,8 +87,39 @@ def get_input_data_batches(input_data_all): batch[input_name] = value batches.append(batch) + print("------------------=+++++++++++++++++") + for batch in batches: + print(format_dict(batch)) + print(format_dict(input_to_index)) + print(format_dict({ "v": index_to_values })) + print(index_to_coords) + print("------------------=+++++++++++++++++") + return CombinatorialBatches(batches, input_to_index, index_to_values, indices, combinations) + +def format_dict(d): + s = [] + for k,v in d.items(): + st = f"{k}: " + if isinstance(v, list): + st += f"list[len: {len(v)}][" + i = [] + for v2 in v: + if isinstance(v2, (int, float, bool)): + i.append(str(v2)) + else: + i.append(v2.__class__.__name__) + st += ",".join(i) + "]" + else: + if isinstance(v, (int, float, bool)): + st += str(v) + else: + st += str(type(v)) + s.append(st) + return "( " + ", ".join(s) + " )" + + def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): """Given input data from the prompt, returns a list of input data dicts for each combinatorial batch.""" @@ -117,7 +148,15 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da # Make the outputs into a list for map-over-list use # (they are themselves lists so flatten them afterwards) input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches] - input_values = { "combinatorial": True, "values": flatten(input_values) } + print("COMB") + print(str(input_unique_id)) + print(str(output_index)) + print(format_dict({ "values": input_values })) + input_values = { + "combinatorial": True, + "values": flatten(input_values), + "axis_id": prompt[input_unique_id].get("axis_id") + } input_data_all[x] = input_values elif is_combinatorial_input(input_data): if required_or_optional: @@ -143,22 +182,6 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all_batches = get_input_data_batches(input_data_all) - def format_dict(d): - s = [] - for k,v in d.items(): - st = f"{k}: " - if isinstance(v, list): - st += f"list[len: {len(v)}][" - i = [] - for v2 in v: - i.append(v2.__class__.__name__) - st += ",".join(i) + "]" - else: - st += str(type(v)) - s.append(st) - return "( " + ", ".join(s) + " )" - - print("---------------------------------") from pprint import pp for batch in input_data_all_batches.batches: @@ -274,7 +297,7 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): "output": outputs_ui_to_send, "prompt_id": prompt_id, "batch_num": inner_totals, - "total_batches": total_inner_batches + "total_batches": total_inner_batches, } if input_data_all_batches.indices: message["indices"] = input_data_all_batches.indices[batch_num] @@ -411,7 +434,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: is_changed_old = old_prompt[unique_id]['is_changed'] if 'is_changed' not in prompt[unique_id]: - input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs) + input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs, prompt) if input_data_all_batches is not None: try: #is_changed = class_def.IS_CHANGED(**input_data_all) @@ -754,7 +777,7 @@ def validate_inputs(prompt, item, validated): inputs[x] = r[1] if hasattr(obj_class, "VALIDATE_INPUTS"): - input_data_all_batches = get_input_data(inputs, obj_class, unique_id) + input_data_all_batches = get_input_data(inputs, obj_class, unique_id, {}, prompt) #ret = obj_class.VALIDATE_INPUTS(**input_data_all) for batch in input_data_all_batches.batches: ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS") diff --git a/web/extensions/core/showGrid.js b/web/extensions/core/showGrid.js index ea178e4db..761052c1d 100644 --- a/web/extensions/core/showGrid.js +++ b/web/extensions/core/showGrid.js @@ -23,7 +23,7 @@ app.registerExtension({ nodeType.prototype.onNodeCreated = function () { const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; - this.addWidget("button", "Show Grid", "Show Grid", () => { + this.showGridWidget = this.addWidget("button", "Show Grid", "Show Grid", () => { const grid = app.nodeGrids[this.id]; if (grid == null) { console.warn("No grid to show!"); @@ -282,6 +282,14 @@ app.registerExtension({ document.body.appendChild(this._gridPanel); }) + + this.showGridWidget.disabled = true; + } + + const onExecuted = nodeType.prototype.onExecuted; + nodeType.prototype.onExecuted = function (output) { + const r = onExecuted ? onExecuted.apply(this, arguments) : undefined; + this.showGridWidget.disabled = app.nodeGrids[this.id] == null; } } }) diff --git a/web/scripts/app.js b/web/scripts/app.js index 555fe70c5..2aa4e3393 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -999,14 +999,16 @@ export class ComfyApp { }); api.addEventListener("executed", ({ detail }) => { - this.nodeOutputs[detail.node] = detail.output; - if (detail.output != null) { - this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt) - } - const node = this.graph.getNodeById(detail.node); - if (node) { - if (node.onExecuted) - node.onExecuted(detail.output); + if (detail.batch_num === detail.total_batches) { + this.nodeOutputs[detail.node] = detail.output; + if (detail.output != null) { + this.nodeGrids[detail.node] = this.#resolveGrid(detail.node, detail.output, this.runningPrompt) + } + const node = this.graph.getNodeById(detail.node); + if (node) { + if (node.onExecuted) + node.onExecuted(detail.output); + } } if (this.batchProgress != null) { this.batchProgress.value = detail.batch_num @@ -1473,6 +1475,7 @@ export class ComfyApp { const inputs = {}; const widgets = node.widgets; + let axis_id = null; // Store all widget values if (widgets) { @@ -1484,6 +1487,13 @@ export class ComfyApp { if (typeof widgetValue === "object" && widgetValue.__inputType__) { totalCombinatorialNodes += 1; executionFactor *= widgetValue.values.length; + + if (widgetValue.axis_id != null) { + if (axis_id != null && axis_id != widgetValue.axis_id) { + throw new RuntimeError("Each node's outputs can only belong to one axis at a time"); + } + axis_id = widgetValue.axis_id; + } } totalExecuted += executionFactor; } @@ -1513,6 +1523,7 @@ export class ComfyApp { output[String(node.id)] = { inputs, class_type: node.comfyClass, + axis_id, }; }