diff --git a/execution.py b/execution.py index 35b542b28..3ceae54eb 100644 --- a/execution.py +++ b/execution.py @@ -20,6 +20,28 @@ class ExecutionResult(Enum): FAILURE = 1 SLEEPING = 2 +def get_input_info(class_def, input_name): + valid_inputs = class_def.INPUT_TYPES() + input_info = None + input_category = None + if input_name in valid_inputs["required"]: + input_category = "required" + input_info = valid_inputs["required"][input_name] + elif input_name in valid_inputs["optional"]: + input_category = "optional" + input_info = valid_inputs["optional"][input_name] + elif input_name in valid_inputs["hidden"]: + input_category = "hidden" + input_info = valid_inputs["hidden"][input_name] + if input_info is None: + return None, None, None + input_type = input_info[0] + if len(input_info) > 1: + extra_info = input_info[1] + else: + extra_info = {} + return input_type, input_category, extra_info + # ExecutionList implements a topological dissolve of the graph. After a node is staged for execution, # it can still be returned to the graph after having further dependencies added. class ExecutionList: @@ -34,25 +56,7 @@ class ExecutionList: def get_input_info(self, unique_id, input_name): class_type = self.dynprompt.get_node(unique_id)["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - valid_inputs = class_def.INPUT_TYPES() - input_info = None - input_category = None - if input_name in valid_inputs["required"]: - input_category = "required" - input_info = valid_inputs["required"][input_name] - elif input_name in valid_inputs["optional"]: - input_category = "optional" - input_info = valid_inputs["optional"][input_name] - elif input_name in valid_inputs["hidden"]: - input_category = "hidden" - input_info = valid_inputs["hidden"][input_name] - if input_info is None: - return None, None, None - input_type = input_info[0] - extra_info = None - if len(input_info) > 1: - extra_info = input_info[1] - return input_type, input_category, extra_info + return get_input_info(class_def, input_name) def make_input_strong_link(self, to_node_id, to_input): inputs = self.dynprompt.get_node(to_node_id)["inputs"] @@ -87,7 +91,7 @@ class ExecutionList: if isinstance(value, list): from_node_id, from_socket = value input_type, input_category, input_info = self.get_input_info(unique_id, input_name) - if input_info is None or "lazy" not in input_info or not input_info["lazy"]: + if "lazy" not in input_info or not input_info["lazy"]: self.add_strong_link(from_node_id, from_socket, unique_id) def stage_node_execution(self): @@ -155,16 +159,16 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp input_data_all = {} for x in inputs: input_data = inputs[x] - if isinstance(input_data, list): + input_type, input_category, input_info = get_input_info(class_def, x) + if isinstance(input_data, list) and not input_info.get("raw_link", False): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: continue # This might be a lazily-evaluated input obj = outputs[input_unique_id][output_index] input_data_all[x] = obj - else: - if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): - input_data_all[x] = [input_data] + elif input_category is not None: + input_data_all[x] = [input_data] if "hidden" in valid_inputs: h = valid_inputs["hidden"] @@ -571,27 +575,28 @@ def validate_inputs(prompt, item, validated): obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] class_inputs = obj_class.INPUT_TYPES() - required_inputs = class_inputs['required'] + valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) errors = [] valid = True - for x in required_inputs: + for x in valid_inputs: + type_input, input_category, extra_info = get_input_info(obj_class, x) if x not in inputs: - error = { - "type": "required_input_missing", - "message": "Required input is missing", - "details": f"{x}", - "extra_info": { - "input_name": x + if input_category == "required": + error = { + "type": "required_input_missing", + "message": "Required input is missing", + "details": f"{x}", + "extra_info": { + "input_name": x + } } - } - errors.append(error) + errors.append(error) continue val = inputs[x] - info = required_inputs[x] - type_input = info[0] + info = (type_input, extra_info) if isinstance(val, list): if len(val) != 2: error = { @@ -677,33 +682,32 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if len(info) > 1: - if "min" in info[1] and val < info[1]["min"]: - error = { - "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } + if "min" in extra_info and val < extra_info["min"]: + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, extra_info["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, } - errors.append(error) - continue - if "max" in info[1] and val > info[1]["max"]: - error = { - "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } + } + errors.append(error) + continue + if "max" in extra_info and val > extra_info["max"]: + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, extra_info["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, } - errors.append(error) - continue + } + errors.append(error) + continue if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id)