From 065e2a7e28248118158e9e1dec213c536690e34b Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sun, 14 May 2023 17:35:57 -0500 Subject: [PATCH] Improvement to range widgets --- execution.py | 39 +++++++++++------------------ web/extensions/core/widgetInputs.js | 37 ++++++++++++++++----------- web/scripts/app.js | 3 ++- 3 files changed, 39 insertions(+), 40 deletions(-) diff --git a/execution.py b/execution.py index a41f4eea6..5f2ce6c75 100644 --- a/execution.py +++ b/execution.py @@ -35,12 +35,7 @@ def get_input_data_batches(input_data_all): batches = [] - print("GET" + str(input_data_all)) - print("ALL " + str(index_to_values)) - print("INPS " + str(input_to_index)) - for combination in list(itertools.product(*index_to_values)): - print("COMBO " + str(combination)) batch = {} for input_name, value in input_data_all.items(): if isinstance(value, dict) and "combinatorial" in value: @@ -102,7 +97,6 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = [extra_data['extra_pnginfo']] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] - print("=== GetInputData: " + str(inputs)) input_data_all_batches = get_input_data_batches(input_data_all) @@ -139,11 +133,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): s.append(st) return "( " + ", ".join(s) + " )" - print("+++ Obj: " + str(obj)) - print("+++ Inputs: " + format_dict(input_data_all)) max_len_input = max(len(x) for x in input_data_all.values()) - print("MaxLen " + str(max_len_input)) - print("0 " + str(slice_lists_into_dict(input_data_all, 0))) results = [] if inputs_are_lists: @@ -161,10 +151,8 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): all_outputs = [] all_outputs_ui = [] total_batches = len(input_data_all_batches) - print("TOTAL: " + str(total_batches)) - for i, batch in enumerate(input_data_all_batches): - print("***** BATCH: " + str(i)) + for batch_num, batch in enumerate(input_data_all_batches): return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True) uis = [] @@ -193,21 +181,20 @@ def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): else: output.append([o[i] for o in results]) - output_ui = dict() + output_ui = None if len(uis) > 0: output_ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} # update the UI after each batch finishes - if len(output_ui) > 0: - if server.client_id is not None: - message = { - "node": unique_id, - "output": output_ui, - "prompt_id": prompt_id, - "batch": i, - "total_batches": total_batches - } - server.send_sync("executed", message, server.client_id) + if server.client_id is not None: + message = { + "node": unique_id, + "output": output_ui, + "prompt_id": prompt_id, + "batch_num": batch_num, + "total_batches": total_batches + } + server.send_sync("executed", message, server.client_id) all_outputs.append(output) all_outputs_ui.append(output_ui) @@ -234,7 +221,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id, "total_batches": len(input_data_all_batches) }, server.client_id) obj = class_def() output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id) @@ -413,6 +400,8 @@ def get_raw_inputs(raw_val): def clamp_input(val, info, class_type, obj_class, x): if is_combinatorial_input(val): + if len(val["values"]) == 0: + return (False, "Combinatorial input has no values in its list. {}, {}".format(class_type, x)) for i, val_choice in enumerate(val["values"]): r = clamp_input(val_choice, info, class_type, obj_class, x) if r[0] == False: diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 243ba32d2..4cdce4583 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -196,19 +196,17 @@ app.registerExtension({ this.serialize_widgets = true; this.isVirtualNode = true; this.properties ||= {} - this.properties.isRange = false; - this.properties.rangeMin = 0; - this.properties.rangeMax = 1024; + this.properties.enableRange = false; + this.properties.rangeStepBy = 64; this.properties.rangeSteps = 2; } - getRange(min, max, steps) { - const range = []; - const stepSize = (max - min) / (steps - 1); + getRange(min, stepBy, steps) { + let result = []; for (let i = 0; i < steps; i++) { - range.push(Math.round((min + i * stepSize) * 100) / 100); + result.push(min + i * stepBy); } - return range; + return result; } applyToGraph() { @@ -227,8 +225,9 @@ app.registerExtension({ if (widget.callback) { widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {}); } - if (widget.type === "number" && this.properties.isRange) { - const values = this.getRange(this.properties.rangeMin, this.properties.rangeMax, this.properties.rangeSteps); + const isNumberWidget = widget.type === "number" || widget.origType === "number"; + if (isNumberWidget && this.properties.enableRange) { + const values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps); widget.value = { __inputType__: "combinatorial", values: values } } } @@ -297,6 +296,15 @@ app.registerExtension({ this.#createWidget(widget.config, theirNode, widget.name); } + onPropertyChanged(property, value) { + if (property === "enableRange") { + if (this.stepByWidget) + this.stepByWidget.disabled = !value + if (this.stepsWidget) + this.stepsWidget.disabled = !value + } + } + #createWidget(inputData, node, widgetName) { let type = inputData[0]; @@ -320,10 +328,11 @@ app.registerExtension({ if (widget.type === "number") { addValueControlWidget(this, widget, "fixed"); - this.addWidget("toggle", "Enable Range", this.properties.isRange, "isRange"); - this.addWidget("number", "Range Min.", this.properties.rangeMin, "rangeMin"); - this.addWidget("number", "Range Max.", this.properties.rangeMax, "rangeMax"); - this.addWidget("number", "Range Steps", this.properties.rangeSteps, "rangeSteps", { min: 1, max: 128, step: 10 }); + this.addWidget("toggle", "Enable Range", this.properties.enableRange, "enableRange"); + this.stepByWidget = this.addWidget("number", "Range Step By", this.properties.rangeStepBy, "rangeStepBy"); + this.stepByWidget.disabled = !this.properties.enableRange; + this.stepsWidget = this.addWidget("number", "Range Steps", this.properties.rangeSteps, "rangeSteps", { min: 1, max: 128, step: 10 }); + this.stepsWidget.disabled = !this.properties.enableRange; } // When our value changes, update other widgets to reflect our changes diff --git a/web/scripts/app.js b/web/scripts/app.js index 6e3d66134..6f4c6bd30 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1266,7 +1266,8 @@ export class ComfyApp { for (let i = 0; i < batchCount; i++) { const result = await this.graphToPrompt(); - if (result.totalExecuted > 128 && !confirm("You are about to execute " + result.totalExecuted + " nodes total across " + result.totalCombinatorialNodes + " combinatorial axes. Are you sure you want to do this?")) { + const warnExecutedAmount = 256; + if (result.totalExecuted > warnExecutedAmount && !confirm("You are about to execute " + result.totalExecuted + " nodes total across " + result.totalCombinatorialNodes + " combinatorial axes. Are you sure you want to do this?")) { continue } const p = result.prompt;