From 83eb7e505785781577aac2cfcc8f3aa64e4ad79c Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Wed, 19 Jul 2023 23:19:58 -0700 Subject: [PATCH] Validate optional inputs Previously, nodes that were connected to output nodes only via optional inputs were never validated, but were still run. This would cause a Python error and stop execution wherever we happen to be in the execution list. This bug exists on master, but may be more noticeable in this branch because execution order is non-deterministic. Other minor change this commit introduces: `raw_link` can be specified as an option on an input to receive the raw link (in the standard form of [node_id, output_index]) rather than a resolved value. --- execution.py | 126 ++++++++++++++++++++++++++------------------------- 1 file changed, 65 insertions(+), 61 deletions(-) diff --git a/execution.py b/execution.py index 35b542b28..3ceae54eb 100644 --- a/execution.py +++ b/execution.py @@ -20,6 +20,28 @@ class ExecutionResult(Enum): FAILURE = 1 SLEEPING = 2 +def get_input_info(class_def, input_name): + valid_inputs = class_def.INPUT_TYPES() + input_info = None + input_category = None + if input_name in valid_inputs["required"]: + input_category = "required" + input_info = valid_inputs["required"][input_name] + elif input_name in valid_inputs["optional"]: + input_category = "optional" + input_info = valid_inputs["optional"][input_name] + elif input_name in valid_inputs["hidden"]: + input_category = "hidden" + input_info = valid_inputs["hidden"][input_name] + if input_info is None: + return None, None, None + input_type = input_info[0] + if len(input_info) > 1: + extra_info = input_info[1] + else: + extra_info = {} + return input_type, input_category, extra_info + # ExecutionList implements a topological dissolve of the graph. After a node is staged for execution, # it can still be returned to the graph after having further dependencies added. class ExecutionList: @@ -34,25 +56,7 @@ class ExecutionList: def get_input_info(self, unique_id, input_name): class_type = self.dynprompt.get_node(unique_id)["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - valid_inputs = class_def.INPUT_TYPES() - input_info = None - input_category = None - if input_name in valid_inputs["required"]: - input_category = "required" - input_info = valid_inputs["required"][input_name] - elif input_name in valid_inputs["optional"]: - input_category = "optional" - input_info = valid_inputs["optional"][input_name] - elif input_name in valid_inputs["hidden"]: - input_category = "hidden" - input_info = valid_inputs["hidden"][input_name] - if input_info is None: - return None, None, None - input_type = input_info[0] - extra_info = None - if len(input_info) > 1: - extra_info = input_info[1] - return input_type, input_category, extra_info + return get_input_info(class_def, input_name) def make_input_strong_link(self, to_node_id, to_input): inputs = self.dynprompt.get_node(to_node_id)["inputs"] @@ -87,7 +91,7 @@ class ExecutionList: if isinstance(value, list): from_node_id, from_socket = value input_type, input_category, input_info = self.get_input_info(unique_id, input_name) - if input_info is None or "lazy" not in input_info or not input_info["lazy"]: + if "lazy" not in input_info or not input_info["lazy"]: self.add_strong_link(from_node_id, from_socket, unique_id) def stage_node_execution(self): @@ -155,16 +159,16 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp input_data_all = {} for x in inputs: input_data = inputs[x] - if isinstance(input_data, list): + input_type, input_category, input_info = get_input_info(class_def, x) + if isinstance(input_data, list) and not input_info.get("raw_link", False): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: continue # This might be a lazily-evaluated input obj = outputs[input_unique_id][output_index] input_data_all[x] = obj - else: - if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): - input_data_all[x] = [input_data] + elif input_category is not None: + input_data_all[x] = [input_data] if "hidden" in valid_inputs: h = valid_inputs["hidden"] @@ -571,27 +575,28 @@ def validate_inputs(prompt, item, validated): obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] class_inputs = obj_class.INPUT_TYPES() - required_inputs = class_inputs['required'] + valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) errors = [] valid = True - for x in required_inputs: + for x in valid_inputs: + type_input, input_category, extra_info = get_input_info(obj_class, x) if x not in inputs: - error = { - "type": "required_input_missing", - "message": "Required input is missing", - "details": f"{x}", - "extra_info": { - "input_name": x + if input_category == "required": + error = { + "type": "required_input_missing", + "message": "Required input is missing", + "details": f"{x}", + "extra_info": { + "input_name": x + } } - } - errors.append(error) + errors.append(error) continue val = inputs[x] - info = required_inputs[x] - type_input = info[0] + info = (type_input, extra_info) if isinstance(val, list): if len(val) != 2: error = { @@ -677,33 +682,32 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if len(info) > 1: - if "min" in info[1] and val < info[1]["min"]: - error = { - "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } + if "min" in extra_info and val < extra_info["min"]: + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, extra_info["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, } - errors.append(error) - continue - if "max" in info[1] and val > info[1]["max"]: - error = { - "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } + } + errors.append(error) + continue + if "max" in extra_info and val > extra_info["max"]: + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, extra_info["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, } - errors.append(error) - continue + } + errors.append(error) + continue if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id)