From a3b16a6edc6d80d0d2f025453d3a56bb0378ce86 Mon Sep 17 00:00:00 2001 From: hnmr293 Date: Fri, 24 Mar 2023 16:20:15 +0900 Subject: [PATCH] implement optional --- execution.py | 84 ++++++++++++++++++++++++++-------------------- web/scripts/app.js | 44 +++++++++++++----------- 2 files changed, 71 insertions(+), 57 deletions(-) diff --git a/execution.py b/execution.py index 757e0d9f9..8c9fc954d 100644 --- a/execution.py +++ b/execution.py @@ -207,45 +207,55 @@ def validate_inputs(prompt, item): obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] class_inputs = obj_class.INPUT_TYPES() - required_inputs = class_inputs['required'] - for x in required_inputs: - if x not in inputs: - return (False, "Required input is missing. {}, {}".format(class_type, x)) - val = inputs[x] - info = required_inputs[x] - type_input = info[0] - if isinstance(val, list): - if len(val) != 2: - return (False, "Bad Input. {}, {}".format(class_type, x)) - o_id = val[0] - o_class_type = prompt[o_id]['class_type'] - r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES - if r[val[1]] != type_input: - return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) - r = validate_inputs(prompt, o_id) - if r[0] == False: - return r - else: - if type_input == "INT": - val = int(val) - inputs[x] = val - if type_input == "FLOAT": - val = float(val) - inputs[x] = val - if type_input == "STRING": - val = str(val) - inputs[x] = val + + def validate(current_inputs, is_optional=False): + for x in current_inputs: + if x not in inputs: + if is_optional: + return (True, "") + else: + return (False, "Required input is missing. {}, {}".format(class_type, x)) + val = inputs[x] + info = current_inputs[x] + type_input = info[0] + if isinstance(val, list): + if len(val) != 2: + return (False, "Bad Input. {}, {}".format(class_type, x)) + o_id = val[0] + o_class_type = prompt[o_id]['class_type'] + r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES + if r[val[1]] != type_input: + return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) + r = validate_inputs(prompt, o_id) + if r[0] == False: + return r + else: + if type_input == "INT": + val = int(val) + inputs[x] = val + if type_input == "FLOAT": + val = float(val) + inputs[x] = val + if type_input == "STRING": + val = str(val) + inputs[x] = val - if len(info) > 1: - if "min" in info[1] and val < info[1]["min"]: - return (False, "Value smaller than min. {}, {}".format(class_type, x)) - if "max" in info[1] and val > info[1]["max"]: - return (False, "Value bigger than max. {}, {}".format(class_type, x)) + if len(info) > 1: + if "min" in info[1] and val < info[1]["min"]: + return (False, "Value smaller than min. {}, {}".format(class_type, x)) + if "max" in info[1] and val > info[1]["max"]: + return (False, "Value bigger than max. {}, {}".format(class_type, x)) - if isinstance(type_input, list): - if val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) - return (True, "") + if isinstance(type_input, list): + if val not in type_input: + return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) + return (True, "") + + result, error = validate(class_inputs['required']) + if result: + result, error = validate(class_inputs.get('optional', {}), is_optional=True) + + return result, error def validate_prompt(prompt): outputs = set() diff --git a/web/scripts/app.js b/web/scripts/app.js index fd410cd30..cd9378837 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -565,31 +565,35 @@ class ComfyApp { const nodeData = defs[nodeId]; const node = Object.assign( function ComfyNode() { - const inputs = nodeData["input"]["required"]; - const config = { minWidth: 1, minHeight: 1 }; - for (const inputName in inputs) { - const inputData = inputs[inputName]; - const type = inputData[0]; + function addInputs(self, inputs, config) { + for (const inputName in inputs) { + const inputData = inputs[inputName]; + const type = inputData[0]; - if (Array.isArray(type)) { - // Enums e.g. latent rotation - let defaultValue = type[0]; - if (inputData[1] && inputData[1].default) { - defaultValue = inputData[1].default; + if (Array.isArray(type)) { + // Enums e.g. latent rotation + let defaultValue = type[0]; + if (inputData[1] && inputData[1].default) { + defaultValue = inputData[1].default; + } + self.addWidget("combo", inputName, defaultValue, () => {}, { values: type }); + } else if (`${type}:${inputName}` in widgets) { + // Support custom widgets by Type:Name + Object.assign(config, widgets[`${type}:${inputName}`](self, inputName, inputData, app) || {}); + } else if (type in widgets) { + // Standard type widgets + Object.assign(config, widgets[type](self, inputName, inputData, app) || {}); + } else { + // Node connection inputs + self.addInput(inputName, type); } - this.addWidget("combo", inputName, defaultValue, () => {}, { values: type }); - } else if (`${type}:${inputName}` in widgets) { - // Support custom widgets by Type:Name - Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {}); - } else if (type in widgets) { - // Standard type widgets - Object.assign(config, widgets[type](this, inputName, inputData, app) || {}); - } else { - // Node connection inputs - this.addInput(inputName, type); } } + const config = { minWidth: 1, minHeight: 1 }; + addInputs(this, nodeData["input"]["required"], config); + addInputs(this, nodeData["input"]["optional"], config); + for (const output of nodeData["output"]) { this.addOutput(output, output); }