From e4be8e06668ac1dcce24e4aa08a17b72c57e9427 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Mon, 29 May 2023 14:56:40 -0500 Subject: [PATCH] Allow specifying which inputs should be used as lists --- comfy_extras/nodes_rebatch.py | 9 +- execution.py | 445 +++++++++++++++++------------ nodes.py | 6 +- web/extensions/core/uploadImage.js | 3 - web/scripts/app.js | 21 +- web/scripts/widgets.js | 37 ++- 6 files changed, 296 insertions(+), 225 deletions(-) diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py index 0a9daf272..46c9c3105 100644 --- a/comfy_extras/nodes_rebatch.py +++ b/comfy_extras/nodes_rebatch.py @@ -3,11 +3,10 @@ import torch class LatentRebatch: @classmethod def INPUT_TYPES(s): - return {"required": { "latents": ("LATENT",), + return {"required": { "latents": ("LATENT", { "is_list": True }), "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), }} - RETURN_TYPES = ("LATENT",) - INPUT_IS_LIST = True + RETURN_TYPES = ("LATENT", ) OUTPUT_IS_LIST = (True, ) FUNCTION = "rebatch" @@ -54,8 +53,6 @@ class LatentRebatch: return result def rebatch(self, latents, batch_size): - batch_size = batch_size[0] - output_list = [] current_batch = (None, None, None) processed = 0 @@ -105,4 +102,4 @@ NODE_CLASS_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = { "RebatchLatents": "Rebatch Latents", -} \ No newline at end of file +} diff --git a/execution.py b/execution.py index 218a84c36..989c1a5b5 100644 --- a/execution.py +++ b/execution.py @@ -13,21 +13,44 @@ import nodes import comfy.model_management +def slice_lists_into_dict(d, i): + """ + get a slice of inputs, repeat last input when list isn't long enough + d={ "seed": [ 1, 2, 3 ], "steps": [ 4, 8 ] }, i=2 -> { "seed": 3, "steps": 8 } + """ + d_new = {} + for k, v in d.items(): + d_new[k] = v[i if len(v) > i else -1] + return d_new + def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} + required = valid_inputs.get("required", {}) + optional = valid_inputs.get("optional", {}) for x in inputs: input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] + input_type = input_data["type"] + + input_def = required.get(x) + if input_def is None: + input_def = optional.get(x) + + use_value_as_list = input_def is not None and len(input_def) > 1 and input_def[1].get("is_list", False) + + if input_type == "link": + input_unique_id = input_data["origin_id"] + output_index = input_data["origin_slot"] if input_unique_id not in outputs: return None obj = outputs[input_unique_id][output_index] + if use_value_as_list: + obj = [obj] input_data_all[x] = obj else: - if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): - input_data_all[x] = [input_data] + value = input_data["value"] + if input_def is not None: + input_data_all[x] = [value] if "hidden" in valid_inputs: h = valid_inputs["hidden"] @@ -39,37 +62,23 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = [extra_data['extra_pnginfo']] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] + return input_data_all def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): - # check if node wants the lists - intput_is_list = False - if hasattr(obj, "INPUT_IS_LIST"): - intput_is_list = obj.INPUT_IS_LIST - - max_len_input = max([len(x) for x in input_data_all.values()]) - - # get a slice of inputs, repeat last input when list isn't long enough - def slice_dict(d, i): - d_new = dict() - for k,v in d.items(): - d_new[k] = v[i if len(v) > i else -1] - return d_new - results = [] - if intput_is_list: + max_len_input = max([len(x) for x in input_data_all.values()]) + + for i in range(max_len_input): if allow_interrupt: nodes.before_node_execution() - results.append(getattr(obj, func)(**input_data_all)) - else: - for i in range(max_len_input): - if allow_interrupt: - nodes.before_node_execution() - results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + + args = slice_lists_into_dict(input_data_all, i) + results.append(getattr(obj, func)(**args)) + return results def get_output_data(obj, input_data_all): - results = [] uis = [] return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) @@ -120,10 +129,10 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute for x in inputs: input_data = inputs[x] + input_type = input_data["type"] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] + if input_type == "link": + input_unique_id = input_data["origin_id"] if input_unique_id not in outputs: result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) if result[0] is not True: @@ -192,9 +201,9 @@ def recursive_will_execute(prompt, outputs, current_item): for x in inputs: input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] + input_type = input_data["type"] + if input_type == "link": + input_unique_id = input_data["origin_id"] if input_unique_id not in outputs: will_execute += recursive_will_execute(prompt, outputs, input_unique_id) @@ -235,10 +244,10 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item elif inputs == old_prompt[unique_id]['inputs']: for x in inputs: input_data = inputs[x] + input_type = input_data["type"] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] + if input_type == "link": + input_unique_id = input_data["origin_id"] if input_unique_id in outputs: to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) else: @@ -366,6 +375,150 @@ class PromptExecutor: comfy.model_management.soft_empty_cache() +def validate_link(prompt, x, val, info, validated): + type_input = info[0] + + o_id = val.get("origin_id", None) + o_slot = val.get("origin_slot", None) + + if o_id is None or o_slot is None: + error = { + "type": "bad_linked_input", + "message": "Bad linked input, must be a dictionary like { type: 'link', origin_id: 1, origin_slot: 1 }", + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val + } + } + return (False, error) + + o_class_type = prompt[o_id]['class_type'] + r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES + if r[o_slot] != type_input: + received_type = r[val[1]] + details = f"{x}, {received_type} != {type_input}" + error = { + "type": "return_type_mismatch", + "message": "Return type mismatch between linked nodes", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_type": received_type, + "linked_node": val + } + } + return (False, error) + try: + r = validate_inputs(prompt, o_id, validated) + if r[0] is False: + # `r` will be set in `validated[o_id]` already + return (False, None) + except Exception as ex: + typ, _, tb = sys.exc_info() + exception_type = full_type_name(typ) + reasons = [{ + "type": "exception_during_inner_validation", + "message": "Exception when validating inner node", + "details": str(ex), + "extra_info": { + "input_name": x, + "input_config": info, + "exception_message": str(ex), + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + "linked_node": val, + "linked_node_inputs": prompt[o_id] + } + }] + validated[o_id] = (False, reasons, o_id) + return (False, None) + + return (True, val) + + +def validate_value(inputs, unique_id, x, val, info, obj_class): + type_input = info[0] + result_val = val + + try: + if type_input == "INT": + result_val = int(val) + if type_input == "FLOAT": + result_val = float(val) + if type_input == "STRING": + result_val = str(val) + except Exception as ex: + error = { + "type": "invalid_input_type", + "message": f"Failed to convert an input value to a {type_input} value", + "details": f"{x}, {val}, {ex}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + "exception_message": str(ex) + } + } + return (False, error) + + if len(info) > 1: + if "min" in info[1] and val < info[1]["min"]: + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + return (False, error) + if "max" in info[1] and val > info[1]["max"]: + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + return (False, error) + else: + # Validate combo widget + if isinstance(type_input, list): + if val not in type_input: + input_config = info + list_info = "" + + # Don't send back gigantic lists like if they're lots of + # scanned model filepaths + if len(type_input) > 20: + list_info = f"(list of length {len(type_input)})" + input_config = None + else: + list_info = str(type_input) + + error = { + "type": "value_not_in_list", + "message": "Value not in list", + "details": f"{x}: '{val}' not in {list_info}", + "extra_info": { + "input_name": x, + "input_config": input_config, + "received_value": val, + } + } + return (False, error) + + return (True, result_val) + + def validate_inputs(prompt, item, validated): unique_id = item if unique_id in validated: @@ -396,168 +549,84 @@ def validate_inputs(prompt, item, validated): val = inputs[x] info = required_inputs[x] - type_input = info[0] - if isinstance(val, list): - if len(val) != 2: - error = { - "type": "bad_linked_input", - "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]", - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val - } + + input_type = None + if isinstance(val, dict): + input_type = val.get("type", None) + + if input_type not in ["link", "value"]: + error = { + "type": "bad_input_format", + "message": "Bad input format, must be a dictionary with 'type' set to 'link' or 'value'", + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val } - errors.append(error) + } + errors.append(error) + continue + + if input_type == "link": + result = validate_link(prompt, x, val, info, validated) + if result[0] is False: + valid = False + if result[1] is not None: + errors.append(result[1]) continue - o_id = val[0] - o_class_type = prompt[o_id]['class_type'] - r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES - if r[val[1]] != type_input: - received_type = r[val[1]] - details = f"{x}, {received_type} != {type_input}" + inputs[x] = result[1] + + elif input_type == "value": + inner_val = val.get("value", None) + if inner_val is None: error = { - "type": "return_type_mismatch", - "message": "Return type mismatch between linked nodes", - "details": details, - "extra_info": { - "input_name": x, - "input_config": info, - "received_type": received_type, - "linked_node": val - } - } - errors.append(error) - continue - try: - r = validate_inputs(prompt, o_id, validated) - if r[0] is False: - # `r` will be set in `validated[o_id]` already - valid = False - continue - except Exception as ex: - typ, _, tb = sys.exc_info() - valid = False - exception_type = full_type_name(typ) - reasons = [{ - "type": "exception_during_inner_validation", - "message": "Exception when validating inner node", - "details": str(ex), - "extra_info": { - "input_name": x, - "input_config": info, - "exception_message": str(ex), - "exception_type": exception_type, - "traceback": traceback.format_tb(tb), - "linked_node": val - } - }] - validated[o_id] = (False, reasons, o_id) - continue - else: - try: - if type_input == "INT": - val = int(val) - inputs[x] = val - if type_input == "FLOAT": - val = float(val) - inputs[x] = val - if type_input == "STRING": - val = str(val) - inputs[x] = val - except Exception as ex: - error = { - "type": "invalid_input_type", - "message": f"Failed to convert an input value to a {type_input} value", - "details": f"{x}, {val}, {ex}", + "type": "bad_value_input", + "message": "Bad value input, must be a dictionary like { type: 'value', value: 42 }", + "details": f"{x}, {val}", "extra_info": { "input_name": x, "input_config": info, "received_value": val, - "exception_message": str(ex) + } + } + return (False, error) + + result = validate_value(inputs, unique_id, x, inner_val, info, obj_class) + + if result[0] is False: + errors.append(result[1]) + continue + + inputs[x] = { "type": "value", "value": result[1] } + + if hasattr(obj_class, "VALIDATE_INPUTS"): + input_data_all = get_input_data(inputs, obj_class, unique_id) + #ret = obj_class.VALIDATE_INPUTS(**input_data_all) + ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") + for i, r in enumerate(ret): + if r is not True: + details = "" + if r is not False: + details += str(r) + + input_data_formatted = {} + if input_data_all is not None: + input_data_formatted = {} + for name, inputList in input_data_all.items(): + input_data_formatted[name] = [format_value(x) for x in inputList] + + error = { + "type": "custom_validation_failed", + "message": "Custom validation failed for node", + "details": details, + "extra_info": { + "input_config": info, + "received_inputs": input_data_formatted, } } errors.append(error) - continue - - if len(info) > 1: - if "min" in info[1] and val < info[1]["min"]: - error = { - "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - if "max" in info[1] and val > info[1]["max"]: - error = { - "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - - if hasattr(obj_class, "VALIDATE_INPUTS"): - input_data_all = get_input_data(inputs, obj_class, unique_id) - #ret = obj_class.VALIDATE_INPUTS(**input_data_all) - ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for i, r in enumerate(ret): - if r is not True: - details = f"{x}" - if r is not False: - details += f" - {str(r)}" - - error = { - "type": "custom_validation_failed", - "message": "Custom validation failed for node", - "details": details, - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - else: - if isinstance(type_input, list): - if val not in type_input: - input_config = info - list_info = "" - - # Don't send back gigantic lists like if they're lots of - # scanned model filepaths - if len(type_input) > 20: - list_info = f"(list of length {len(type_input)})" - input_config = None - else: - list_info = str(type_input) - - error = { - "type": "value_not_in_list", - "message": "Value not in list", - "details": f"{x}: '{val}' not in {list_info}", - "extra_info": { - "input_name": x, - "input_config": input_config, - "received_value": val, - } - } - errors.append(error) - continue if len(errors) > 0 or valid is not True: ret = (False, errors, unique_id) @@ -644,7 +713,7 @@ def validate_prompt(prompt): node_errors[node_id]["dependent_outputs"].append(o) print("Output will be ignored") - if len(good_outputs) == 0: + if len(good_outputs) == 0 or node_errors: errors_list = [] for o, errors in errors: for error in errors: diff --git a/nodes.py b/nodes.py index 6f05e4b77..ae3c784bf 100644 --- a/nodes.py +++ b/nodes.py @@ -1085,7 +1085,7 @@ class LoadImage: input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": (sorted(files), { "forceInput": True })}, + {"image": (sorted(files), )}, } CATEGORY = "image" @@ -1127,7 +1127,7 @@ class LoadImageBatch: input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"images": (sorted(files), )}, + {"images": ("MULTIIMAGEUPLOAD", { "filepaths": sorted(files) } )}, } CATEGORY = "image" @@ -1135,7 +1135,6 @@ class LoadImageBatch: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_images" - INPUT_IS_LIST = True OUTPUT_IS_LIST = (True, True, ) def load_images(self, images): @@ -1437,6 +1436,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "PreviewImage": "Preview Image", "LoadImage": "Load Image", "LoadImageMask": "Load Image (as Mask)", + "LoadImageBatch": "Load Image Batch", "ImageScale": "Upscale Image", "ImageUpscaleWithModel": "Upscale Image (using Model)", "ImageInvert": "Invert Image", diff --git a/web/extensions/core/uploadImage.js b/web/extensions/core/uploadImage.js index e2ecfae86..3584364a3 100644 --- a/web/extensions/core/uploadImage.js +++ b/web/extensions/core/uploadImage.js @@ -10,9 +10,6 @@ app.registerExtension({ case "LoadImageMask": nodeData.input.required.upload = ["IMAGEUPLOAD"]; break; - case "LoadImageBatch": - nodeData.input.required.upload = ["MULTIIMAGEUPLOAD"]; - break; } }, }); diff --git a/web/scripts/app.js b/web/scripts/app.js index fd1186ab9..f329ab131 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1110,20 +1110,21 @@ export class ComfyApp { for (const inputName in inputs) { const inputData = inputs[inputName]; const type = inputData[0]; - const inputShape = nodeData["input_is_list"] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE; + const options = inputData[1] || {}; + const inputShape = options.is_list ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE; if(inputData[1]?.forceInput) { this.addInput(inputName, type, { shape: inputShape }); } else { if (Array.isArray(type)) { // Enums - Object.assign(config, widgets.COMBO(this, inputName, inputData, nodeData, app) || {}); + Object.assign(config, widgets.COMBO(this, inputName, inputData, app) || {}); } else if (`${type}:${inputName}` in widgets) { // Support custom widgets by Type:Name - Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, nodeData, app) || {}); + Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {}); } else if (type in widgets) { // Standard type widgets - Object.assign(config, widgets[type](this, inputName, inputData, nodeData, app) || {}); + Object.assign(config, widgets[type](this, inputName, inputData, app) || {}); } else { // Node connection inputs this.addInput(inputName, type, { shape: inputShape }); @@ -1313,7 +1314,8 @@ export class ComfyApp { for (const i in widgets) { const widget = widgets[i]; if (!widget.options || widget.options.serialize !== false) { - inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value; + const value = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value; + inputs[widget.name] = { type: "value", value } } } } @@ -1333,7 +1335,11 @@ export class ComfyApp { } if (link) { - inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)]; + inputs[node.inputs[i].name] = { + type: "link", + origin_id: String(link.origin_id), + origin_slot: parseInt(link.origin_slot) + }; } } } @@ -1377,6 +1383,9 @@ export class ComfyApp { message += "\n" + nodeError.class_type + ":" for (const errorReason of nodeError.errors) { message += "\n - " + errorReason.message + ": " + errorReason.details + if (errorReason.extra_info?.traceback) { + message += "\n" + errorReason.extra_info.traceback.join("") + } } } return message diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 784e2740d..76ca8c755 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -268,7 +268,7 @@ const INT = (node, inputName, inputData) => { }; } -const STRING = (node, inputName, inputData, nodeData, app) => { +const STRING = (node, inputName, inputData, app) => { const defaultVal = inputData[1].default || ""; const multiline = !!inputData[1].multiline; @@ -279,14 +279,15 @@ const STRING = (node, inputName, inputData, nodeData, app) => { } } -const COMBO = (node, inputName, inputData, nodeData) => { +const COMBO = (node, inputName, inputData) => { const type = inputData[0]; let defaultValue = type[0]; - if (inputData[1] && inputData[1].default) { - defaultValue = inputData[1].default; + let options = inputData[1] || {} + if (options.default) { + defaultValue = options.default } - if (nodeData["input_is_list"]) { + if (options.is_list) { defaultValue = [defaultValue] const widget = node.addWidget("text", inputName, defaultValue, () => {}, { values: type }) widget.disabled = true; @@ -297,7 +298,7 @@ const COMBO = (node, inputName, inputData, nodeData) => { } } -const IMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => { +const IMAGEUPLOAD = (node, inputName, inputData, app) => { const imageWidget = node.widgets.find((w) => w.name === "image"); let uploadWidget; @@ -412,8 +413,7 @@ const IMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => { uploadWidget = node.addWidget("button", "choose file to upload", "image", () => { fileInput.value = null; fileInput.click(); - }); - uploadWidget.serialize = false; + }, { serialize: false }); // Add handler to check if an image is being dragged over our node node.onDragOver = function (e) { @@ -442,8 +442,14 @@ const IMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => { return { widget: uploadWidget }; } -const MULTIIMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => { - const imagesWidget = node.widgets.find((w) => w.name === "images"); +const MULTIIMAGEUPLOAD = (node, inputName, inputData, app) => { + const imagesWidget = node.addWidget("text", inputName, inputData, () => {}) + + imagesWidget._filepaths = [] + if (inputData[1] && inputData[1].filepaths) { + imagesWidget._filepaths = inputData[1].filepaths + } + let uploadWidget; let clearWidget; @@ -534,11 +540,6 @@ const MULTIIMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => { if (resp.status === 200) { const data = await resp.json(); - // Add the file as an option and update the widget value - if (!imagesWidget.options.values.includes(data.name)) { - imagesWidget.options.values.push(data.name); - } - if (updateNode) { imagesWidget.value.push(data.name) } @@ -573,14 +574,12 @@ const MULTIIMAGEUPLOAD = (node, inputName, inputData, nodeData, app) => { uploadWidget = node.addWidget("button", "choose files to upload", "images", () => { fileInput.value = null; fileInput.click(); - }); - uploadWidget.serialize = false; + }, { serialize: false }); clearWidget = node.addWidget("button", "clear all uploads", "images", () => { imagesWidget.value = [] showImages(imagesWidget.value); - }); - clearWidget.serialize = false; + }, { serialize: false }); // Add handler to check if an image is being dragged over our node node.onDragOver = function (e) {