Validate optional inputs

Previously, nodes that were connected to output nodes only via optional
inputs were never validated, but were still run. This would cause a
Python error and stop execution wherever we happen to be in the
execution list. This bug exists on master, but may be more noticeable in
this branch because execution order is non-deterministic.

Other minor change this commit introduces: `raw_link` can be specified
as an option on an input to receive the raw link (in the standard form
of [node_id, output_index]) rather than a resolved value.
This commit is contained in:
Jacob Segal 2023-07-19 23:19:58 -07:00
parent d8fdbb5168
commit 83eb7e5057

View File

@ -20,6 +20,28 @@ class ExecutionResult(Enum):
FAILURE = 1 FAILURE = 1
SLEEPING = 2 SLEEPING = 2
def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES()
input_info = None
input_category = None
if input_name in valid_inputs["required"]:
input_category = "required"
input_info = valid_inputs["required"][input_name]
elif input_name in valid_inputs["optional"]:
input_category = "optional"
input_info = valid_inputs["optional"][input_name]
elif input_name in valid_inputs["hidden"]:
input_category = "hidden"
input_info = valid_inputs["hidden"][input_name]
if input_info is None:
return None, None, None
input_type = input_info[0]
if len(input_info) > 1:
extra_info = input_info[1]
else:
extra_info = {}
return input_type, input_category, extra_info
# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution, # ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
# it can still be returned to the graph after having further dependencies added. # it can still be returned to the graph after having further dependencies added.
class ExecutionList: class ExecutionList:
@ -34,25 +56,7 @@ class ExecutionList:
def get_input_info(self, unique_id, input_name): def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"] class_type = self.dynprompt.get_node(unique_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
valid_inputs = class_def.INPUT_TYPES() return get_input_info(class_def, input_name)
input_info = None
input_category = None
if input_name in valid_inputs["required"]:
input_category = "required"
input_info = valid_inputs["required"][input_name]
elif input_name in valid_inputs["optional"]:
input_category = "optional"
input_info = valid_inputs["optional"][input_name]
elif input_name in valid_inputs["hidden"]:
input_category = "hidden"
input_info = valid_inputs["hidden"][input_name]
if input_info is None:
return None, None, None
input_type = input_info[0]
extra_info = None
if len(input_info) > 1:
extra_info = input_info[1]
return input_type, input_category, extra_info
def make_input_strong_link(self, to_node_id, to_input): def make_input_strong_link(self, to_node_id, to_input):
inputs = self.dynprompt.get_node(to_node_id)["inputs"] inputs = self.dynprompt.get_node(to_node_id)["inputs"]
@ -87,7 +91,7 @@ class ExecutionList:
if isinstance(value, list): if isinstance(value, list):
from_node_id, from_socket = value from_node_id, from_socket = value
input_type, input_category, input_info = self.get_input_info(unique_id, input_name) input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
if input_info is None or "lazy" not in input_info or not input_info["lazy"]: if "lazy" not in input_info or not input_info["lazy"]:
self.add_strong_link(from_node_id, from_socket, unique_id) self.add_strong_link(from_node_id, from_socket, unique_id)
def stage_node_execution(self): def stage_node_execution(self):
@ -155,16 +159,16 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp
input_data_all = {} input_data_all = {}
for x in inputs: for x in inputs:
input_data = inputs[x] input_data = inputs[x]
if isinstance(input_data, list): input_type, input_category, input_info = get_input_info(class_def, x)
if isinstance(input_data, list) and not input_info.get("raw_link", False):
input_unique_id = input_data[0] input_unique_id = input_data[0]
output_index = input_data[1] output_index = input_data[1]
if input_unique_id not in outputs: if input_unique_id not in outputs:
continue # This might be a lazily-evaluated input continue # This might be a lazily-evaluated input
obj = outputs[input_unique_id][output_index] obj = outputs[input_unique_id][output_index]
input_data_all[x] = obj input_data_all[x] = obj
else: elif input_category is not None:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): input_data_all[x] = [input_data]
input_data_all[x] = [input_data]
if "hidden" in valid_inputs: if "hidden" in valid_inputs:
h = valid_inputs["hidden"] h = valid_inputs["hidden"]
@ -571,27 +575,28 @@ def validate_inputs(prompt, item, validated):
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
class_inputs = obj_class.INPUT_TYPES() class_inputs = obj_class.INPUT_TYPES()
required_inputs = class_inputs['required'] valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
errors = [] errors = []
valid = True valid = True
for x in required_inputs: for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x)
if x not in inputs: if x not in inputs:
error = { if input_category == "required":
"type": "required_input_missing", error = {
"message": "Required input is missing", "type": "required_input_missing",
"details": f"{x}", "message": "Required input is missing",
"extra_info": { "details": f"{x}",
"input_name": x "extra_info": {
"input_name": x
}
} }
} errors.append(error)
errors.append(error)
continue continue
val = inputs[x] val = inputs[x]
info = required_inputs[x] info = (type_input, extra_info)
type_input = info[0]
if isinstance(val, list): if isinstance(val, list):
if len(val) != 2: if len(val) != 2:
error = { error = {
@ -677,33 +682,32 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if len(info) > 1: if "min" in extra_info and val < extra_info["min"]:
if "min" in info[1] and val < info[1]["min"]: error = {
error = { "type": "value_smaller_than_min",
"type": "value_smaller_than_min", "message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]), "details": f"{x}",
"details": f"{x}", "extra_info": {
"extra_info": { "input_name": x,
"input_name": x, "input_config": info,
"input_config": info, "received_value": val,
"received_value": val,
}
} }
errors.append(error) }
continue errors.append(error)
if "max" in info[1] and val > info[1]["max"]: continue
error = { if "max" in extra_info and val > extra_info["max"]:
"type": "value_bigger_than_max", error = {
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]), "type": "value_bigger_than_max",
"details": f"{x}", "message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
"extra_info": { "details": f"{x}",
"input_name": x, "extra_info": {
"input_config": info, "input_name": x,
"received_value": val, "input_config": info,
} "received_value": val,
} }
errors.append(error) }
continue errors.append(error)
continue
if hasattr(obj_class, "VALIDATE_INPUTS"): if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id) input_data_all = get_input_data(inputs, obj_class, unique_id)