diff --git a/execution.py b/execution.py index 218a84c36..5f3f011d5 100644 --- a/execution.py +++ b/execution.py @@ -7,26 +7,89 @@ import heapq import traceback import gc import time +import itertools import torch import nodes import comfy.model_management + +def get_input_data_batches(input_data_all): + """Given input data that can contain combinatorial input values, returns all + the possible batches that can be made by combining the different input + values together.""" + + input_to_index = {} + index_to_values = [] + + # Sort by input name first so the order which batch inputs are applied can + # be easily calculated (node execution order first, then alphabetical input + # name second) + sorted_input_names = sorted(input_data_all.keys()) + + i = 0 + for input_name in sorted_input_names: + value = input_data_all[input_name] + if isinstance(value, dict) and "combinatorial" in value: + input_to_index[input_name] = i + index_to_values.append(value["values"]) + i += 1 + + if len(index_to_values) == 0: + # No combinatorial options. + return [input_data_all] + + batches = [] + + for combination in list(itertools.product(*index_to_values)): + batch = {} + for input_name, value in input_data_all.items(): + if isinstance(value, dict) and "combinatorial" in value: + combination_index = input_to_index[input_name] + batch[input_name] = [combination[combination_index]] + else: + # already made into a list by get_input_data + batch[input_name] = value + batches.append(batch) + + return batches + def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): + """Given input data from the prompt, returns a list of input data dicts for + each combinatorial batch.""" valid_inputs = class_def.INPUT_TYPES() input_data_all = {} for x in inputs: input_data = inputs[x] + required_or_optional = ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]) if isinstance(input_data, list): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: return None - obj = outputs[input_unique_id][output_index] - input_data_all[x] = obj + + # This is a list of outputs for each batch of combinatorial inputs. + # Without any combinatorial inputs, it's a list of length 1. + outputs_for_all_batches = outputs[input_unique_id] + + def flatten(list_of_lists): + return list(itertools.chain.from_iterable(list_of_lists)) + + if len(outputs_for_all_batches) == 1: + # Single batch, no combinatorial stuff + input_data_all[x] = outputs_for_all_batches[0][output_index] + else: + # Make the outputs into a list for map-over-list use + # (they are themselves lists so flatten them afterwards) + input_values = [batch_output[output_index] for batch_output in outputs_for_all_batches] + input_values = flatten(input_values) + input_data_all[x] = input_values + elif is_combinatorial_input(input_data): + if required_or_optional: + input_data_all[x] = { "combinatorial": True, "values": input_data["values"] } else: - if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): + if required_or_optional: input_data_all[x] = [input_data] if "hidden" in valid_inputs: @@ -39,7 +102,20 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = [extra_data['extra_pnginfo']] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] - return input_data_all + + input_data_all_batches = get_input_data_batches(input_data_all) + + return input_data_all_batches + +def slice_lists_into_dict(d, i): + """ + get a slice of inputs, repeat last input when list isn't long enough + d={ "seed": [ 1, 2, 3 ], "steps": [ 4, 8 ] }, i=2 -> { "seed": 3, "steps": 8 } + """ + d_new = {} + for k, v in d.items(): + d_new[k] = v[i if len(v) > i else -1] + return d_new def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): # check if node wants the lists @@ -49,13 +125,23 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): max_len_input = max([len(x) for x in input_data_all.values()]) - # get a slice of inputs, repeat last input when list isn't long enough - def slice_dict(d, i): - d_new = dict() + def format_dict(d): + s = [] for k,v in d.items(): - d_new[k] = v[i if len(v) > i else -1] - return d_new - + st = f"{k}: " + if isinstance(v, list): + st += f"list[len: {len(v)}][" + i = [] + for v2 in v: + i.append(v2.__class__.__name__) + st += ",".join(i) + "]" + else: + st += str(type(v)) + s.append(st) + return "( " + ", ".join(s) + " )" + + max_len_input = max(len(x) for x in input_data_all.values()) + results = [] if intput_is_list: if allow_interrupt: @@ -65,42 +151,66 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): for i in range(max_len_input): if allow_interrupt: nodes.before_node_execution() - results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + results.append(getattr(obj, func)(**slice_lists_into_dict(input_data_all, i))) return results -def get_output_data(obj, input_data_all): - - results = [] - uis = [] - return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) +def get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id): + all_outputs = [] + all_outputs_ui = [] + total_batches = len(input_data_all_batches) - for r in return_values: - if isinstance(r, dict): - if 'ui' in r: - uis.append(r['ui']) - if 'result' in r: - results.append(r['result']) - else: - results.append(r) - - output = [] - if len(results) > 0: - # check which outputs need concatenating - output_is_list = [False] * len(results[0]) - if hasattr(obj, "OUTPUT_IS_LIST"): - output_is_list = obj.OUTPUT_IS_LIST + for batch_num, batch in enumerate(input_data_all_batches): + return_values = map_node_over_list(obj, batch, obj.FUNCTION, allow_interrupt=True) - # merge node execution results - for i, is_list in zip(range(len(results[0])), output_is_list): - if is_list: - output.append([x for o in results for x in o[i]]) + uis = [] + results = [] + + for r in return_values: + if isinstance(r, dict): + if 'ui' in r: + uis.append(r['ui']) + if 'result' in r: + results.append(r['result']) else: - output.append([o[i] for o in results]) + results.append(r) - ui = dict() - if len(uis) > 0: - ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} - return output, ui + output = [] + if len(results) > 0: + # check which outputs need concatenating + output_is_list = [False] * len(results[0]) + if hasattr(obj, "OUTPUT_IS_LIST"): + output_is_list = obj.OUTPUT_IS_LIST + + # merge node execution results + for i, is_list in zip(range(len(results[0])), output_is_list): + if is_list: + output.append([x for o in results for x in o[i]]) + else: + output.append([o[i] for o in results]) + + output_ui = None + if len(uis) > 0: + output_ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} + + all_outputs.append(output) + all_outputs_ui.append(output_ui) + + outputs_ui_to_send = None + if any(all_outputs_ui): + outputs_ui_to_send = all_outputs_ui + + # update the UI after each batch finishes + if server.client_id is not None: + message = { + "node": unique_id, + "output": outputs_ui_to_send, + "prompt_id": prompt_id, + "batch_num": batch_num, + "total_batches": total_batches + } + server.send_sync("executed", message, server.client_id) + + return all_outputs, all_outputs_ui def format_value(x): if x is None: @@ -132,18 +242,18 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_data_all = None try: - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) + input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id, "total_batches": len(input_data_all_batches) }, server.client_id) obj = class_def() - output_data, output_ui = get_output_data(obj, input_data_all) - outputs[unique_id] = output_data - if len(output_ui) > 0: - outputs_ui[unique_id] = output_ui - if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + output_data_from_batches, output_ui_from_batches = get_output_data(obj, input_data_all_batches, server, unique_id, prompt_id) + outputs[unique_id] = output_data_from_batches + if any(output_ui_from_batches): + outputs_ui[unique_id] = output_ui_from_batches + elif unique_id in outputs_ui: + outputs_ui.pop(unique_id) except comfy.model_management.InterruptProcessingException as iex: print("Processing interrupted") @@ -213,13 +323,16 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: is_changed_old = old_prompt[unique_id]['is_changed'] if 'is_changed' not in prompt[unique_id]: - input_data_all = get_input_data(inputs, class_def, unique_id, outputs) - if input_data_all is not None: - try: + input_data_all_batches = get_input_data(inputs, class_def, unique_id, outputs) + if input_data_all_batches is not None: + try: #is_changed = class_def.IS_CHANGED(**input_data_all) - is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") + for batch in input_data_all_batches: + if map_node_over_list(class_def, batch, "IS_CHANGED"): + is_changed = True + break prompt[unique_id]['is_changed'] = is_changed - except: + except: to_delete = True else: is_changed = prompt[unique_id]['is_changed'] @@ -366,6 +479,94 @@ class PromptExecutor: comfy.model_management.soft_empty_cache() +def is_combinatorial_input(val): + return isinstance(val, dict) and "__inputType__" in val + + +def get_raw_inputs(raw_val): + if isinstance(raw_val, list): + # link to another node + return [raw_val] + elif is_combinatorial_input(raw_val): + return raw_val["values"] + return [raw_val] + + +def clamp_input(val, info, class_type, obj_class, x): + errors = [] + + if is_combinatorial_input(val): + if len(val["values"]) == 0: + error = { + "type": "combinatorial_input_missing_values", + "message": f"Combinatorial input has no values in its list.", + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + return (False, None, error) + for i, val_choice in enumerate(val["values"]): + r = clamp_input(val_choice, info, class_type, obj_class, x) + if r[0] == False: + return r + val["values"][i] = r[1] + return (True, val, None) + + type_input = info[0] + + try: + if type_input == "INT": + val = int(val) + if type_input == "FLOAT": + val = float(val) + if type_input == "STRING": + val = str(val) + except Exception as ex: + error = { + "type": "invalid_input_type", + "message": f"Failed to convert an input value to a {type_input} value", + "details": f"{x}, {val}, {ex}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + "exception_message": str(ex) + } + } + return (False, None, error) + + if len(info) > 1: + if "min" in info[1] and val < info[1]["min"]: + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + return (False, None, error) + if "max" in info[1] and val > info[1]["max"]: + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + return (False, None, error) + + return (True, val, None) + + def validate_inputs(prompt, item, validated): unique_id = item if unique_id in validated: @@ -457,107 +658,66 @@ def validate_inputs(prompt, item, validated): validated[o_id] = (False, reasons, o_id) continue else: - try: - if type_input == "INT": - val = int(val) - inputs[x] = val - if type_input == "FLOAT": - val = float(val) - inputs[x] = val - if type_input == "STRING": - val = str(val) - inputs[x] = val - except Exception as ex: - error = { - "type": "invalid_input_type", - "message": f"Failed to convert an input value to a {type_input} value", - "details": f"{x}, {val}, {ex}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - "exception_message": str(ex) - } - } - errors.append(error) + r = clamp_input(val, info, class_type, obj_class, x) + if r[0] == False: + errors.append(r[2]) continue - - if len(info) > 1: - if "min" in info[1] and val < info[1]["min"]: - error = { - "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - if "max" in info[1] and val > info[1]["max"]: - error = { - "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue + else: + inputs[x] = r[1] if hasattr(obj_class, "VALIDATE_INPUTS"): - input_data_all = get_input_data(inputs, obj_class, unique_id) + input_data_all_batches = get_input_data(inputs, obj_class, unique_id) #ret = obj_class.VALIDATE_INPUTS(**input_data_all) ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for i, r in enumerate(ret): - if r is not True: - details = f"{x}" - if r is not False: - details += f" - {str(r)}" + for batch in input_data_all_batches: + ret = map_node_over_list(obj_class, batch, "VALIDATE_INPUTS") + for r in ret: + if r != True: + details = f"{x}" + if r is not False: + details += f" - {str(r)}" - error = { - "type": "custom_validation_failed", - "message": "Custom validation failed for node", - "details": details, - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, + error = { + "type": "custom_validation_failed", + "message": "Custom validation failed for node", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } } - } - errors.append(error) - continue + errors.append(error) + continue else: if isinstance(type_input, list): - if val not in type_input: - input_config = info - list_info = "" + # Account for more than one combinatorial value + raw_vals = get_raw_inputs(val) + for raw_val in raw_vals: + if raw_val not in type_input: + input_config = info + list_info = "" - # Don't send back gigantic lists like if they're lots of - # scanned model filepaths - if len(type_input) > 20: - list_info = f"(list of length {len(type_input)})" - input_config = None - else: - list_info = str(type_input) + # Don't send back gigantic lists like if they're lots of + # scanned model filepaths + if len(type_input) > 20: + list_info = f"(list of length {len(type_input)})" + input_config = None + else: + list_info = str(type_input) - error = { - "type": "value_not_in_list", - "message": "Value not in list", - "details": f"{x}: '{val}' not in {list_info}", - "extra_info": { - "input_name": x, - "input_config": input_config, - "received_value": val, + error = { + "type": "value_not_in_list", + "message": "Value not in list", + "details": f"{x}: '{raw_val}' not in {list_info}", + "extra_info": { + "input_name": x, + "input_config": input_config, + "received_value": raw_val, + } } - } - errors.append(error) - continue + errors.append(error) + continue if len(errors) > 0 or valid is not True: ret = (False, errors, unique_id) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 4fe0a6013..4113dc9f9 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -195,6 +195,19 @@ app.registerExtension({ this.addOutput("connect to widget input", "*"); this.serialize_widgets = true; this.isVirtualNode = true; + this.properties ||= {} + this.properties.valuesType = "single"; + this.properties.listValue = ""; + this.properties.rangeStepBy = 64; + this.properties.rangeSteps = 2; + } + + getRange(min, stepBy, steps) { + let result = []; + for (let i = 0; i < steps; i++) { + result.push(min + i * stepBy); + } + return result; } applyToGraph() { @@ -209,15 +222,55 @@ app.registerExtension({ if (widgetName) { const widget = node.widgets.find((w) => w.name === widgetName); if (widget) { - widget.value = this.widgets[0].value; + widget.value = this.mainWidget.value; if (widget.callback) { widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {}); } + + let values; + + switch (this.properties.valuesType) { + case "list": + values = this.listWidget.value.split(","); + const inputType = this.outputs[0].widget.config[0] + if (inputType === "INT") { + values = values.map(v => parseInt(v)) + } + else if (inputType === "FLOAT") { + values = values.map(v => parseFloat(v)) + } + widget.value = { __inputType__: "combinatorial", values: values } + break; + case "range": + const isNumberWidget = widget.type === "number" || widget.origType === "number"; + if (isNumberWidget) { + values = this.getRange(widget.value, this.properties.rangeStepBy, this.properties.rangeSteps); + widget.value = { __inputType__: "combinatorial", values: values } + break; + } + case "single": + default: + break; + } } } } } + onPropertyChanged(property, value) { + if (property === "valuesType") { + const isList = value === "list" + if (this.listWidget) + this.listWidget.disabled = !isList + + const isRange = value === "range" + if (this.stepByWidget) + this.stepByWidget.disabled = !isRange + if (this.stepsWidget) + this.stepsWidget.disabled = !isRange + } + } + onConnectionsChange(_, index, connected) { if (connected) { if (this.outputs[0].links?.length) { @@ -227,7 +280,7 @@ app.registerExtension({ if (!this.widgets?.length && this.outputs[0].widget) { // On first load it often cant recreate the widget as the other node doesnt exist yet // Manually recreate it from the output info - this.#createWidget(this.outputs[0].widget.config); + this.mainWidget = this.#createWidget(this.outputs[0].widget.config); } } } else if (!this.outputs[0].links?.length) { @@ -276,7 +329,7 @@ app.registerExtension({ this.outputs[0].name = type; this.outputs[0].widget = widget; - this.#createWidget(widget.config, theirNode, widget.name); + this.mainWidget = this.#createWidget(widget.config, theirNode, widget.name); } #createWidget(inputData, node, widgetName) { @@ -304,6 +357,23 @@ app.registerExtension({ addValueControlWidget(this, widget, "fixed"); } + const valuesTypeChoices = ["single", "list"]; + if (widget.type === "number") { + valuesTypeChoices.push("range"); + } + + this.valuesTypeWidget = this.addWidget("combo", "Values type", this.properties.valuesType, "valuesType", { values: valuesTypeChoices }); + + this.listWidget = this.addWidget("text", "Choices", this.properties.listValue, "listValue"); + this.listWidget.disabled = this.properties.valuesType !== "list"; + + if (widget.type === "number") { + this.stepByWidget = this.addWidget("number", "Range Step By", this.properties.rangeStepBy, "rangeStepBy"); + this.stepByWidget.disabled = this.properties.valuesType !== "range"; + this.stepsWidget = this.addWidget("number", "Range Steps", this.properties.rangeSteps, "rangeSteps", { min: 1, max: 128, step: 10 }); + this.stepsWidget.disabled = this.properties.valuesType !== "range"; + } + // When our value changes, update other widgets to reflect our changes // e.g. so LoadImage shows correct image const callback = widget.callback; @@ -328,6 +398,8 @@ app.registerExtension({ this.onResize(this.size); } }); + + return widget; } #isValidConnection(input) {