Validate optional inputs

Previously, nodes that were connected to output nodes only via optional
inputs were never validated, but were still run. This would cause a
Python error and stop execution wherever we happen to be in the
execution list. This bug exists on master, but may be more noticeable in
this branch because execution order is non-deterministic.

Other minor change this commit introduces: `raw_link` can be specified
as an option on an input to receive the raw link (in the standard form
of [node_id, output_index]) rather than a resolved value.
This commit is contained in:
Jacob Segal 2023-07-19 23:19:58 -07:00
parent d8fdbb5168
commit 83eb7e5057

View File

@ -20,6 +20,28 @@ class ExecutionResult(Enum):
FAILURE = 1
SLEEPING = 2
def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES()
input_info = None
input_category = None
if input_name in valid_inputs["required"]:
input_category = "required"
input_info = valid_inputs["required"][input_name]
elif input_name in valid_inputs["optional"]:
input_category = "optional"
input_info = valid_inputs["optional"][input_name]
elif input_name in valid_inputs["hidden"]:
input_category = "hidden"
input_info = valid_inputs["hidden"][input_name]
if input_info is None:
return None, None, None
input_type = input_info[0]
if len(input_info) > 1:
extra_info = input_info[1]
else:
extra_info = {}
return input_type, input_category, extra_info
# ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
# it can still be returned to the graph after having further dependencies added.
class ExecutionList:
@ -34,25 +56,7 @@ class ExecutionList:
def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
valid_inputs = class_def.INPUT_TYPES()
input_info = None
input_category = None
if input_name in valid_inputs["required"]:
input_category = "required"
input_info = valid_inputs["required"][input_name]
elif input_name in valid_inputs["optional"]:
input_category = "optional"
input_info = valid_inputs["optional"][input_name]
elif input_name in valid_inputs["hidden"]:
input_category = "hidden"
input_info = valid_inputs["hidden"][input_name]
if input_info is None:
return None, None, None
input_type = input_info[0]
extra_info = None
if len(input_info) > 1:
extra_info = input_info[1]
return input_type, input_category, extra_info
return get_input_info(class_def, input_name)
def make_input_strong_link(self, to_node_id, to_input):
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
@ -87,7 +91,7 @@ class ExecutionList:
if isinstance(value, list):
from_node_id, from_socket = value
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
if input_info is None or "lazy" not in input_info or not input_info["lazy"]:
if "lazy" not in input_info or not input_info["lazy"]:
self.add_strong_link(from_node_id, from_socket, unique_id)
def stage_node_execution(self):
@ -155,16 +159,16 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, dynpromp
input_data_all = {}
for x in inputs:
input_data = inputs[x]
if isinstance(input_data, list):
input_type, input_category, input_info = get_input_info(class_def, x)
if isinstance(input_data, list) and not input_info.get("raw_link", False):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs:
continue # This might be a lazily-evaluated input
obj = outputs[input_unique_id][output_index]
input_data_all[x] = obj
else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = [input_data]
elif input_category is not None:
input_data_all[x] = [input_data]
if "hidden" in valid_inputs:
h = valid_inputs["hidden"]
@ -571,27 +575,28 @@ def validate_inputs(prompt, item, validated):
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
class_inputs = obj_class.INPUT_TYPES()
required_inputs = class_inputs['required']
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
errors = []
valid = True
for x in required_inputs:
for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x)
if x not in inputs:
error = {
"type": "required_input_missing",
"message": "Required input is missing",
"details": f"{x}",
"extra_info": {
"input_name": x
if input_category == "required":
error = {
"type": "required_input_missing",
"message": "Required input is missing",
"details": f"{x}",
"extra_info": {
"input_name": x
}
}
}
errors.append(error)
errors.append(error)
continue
val = inputs[x]
info = required_inputs[x]
type_input = info[0]
info = (type_input, extra_info)
if isinstance(val, list):
if len(val) != 2:
error = {
@ -677,33 +682,32 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
if len(info) > 1:
if "min" in info[1] and val < info[1]["min"]:
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, info[1]["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
if "min" in extra_info and val < extra_info["min"]:
error = {
"type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
errors.append(error)
continue
if "max" in info[1] and val > info[1]["max"]:
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, info[1]["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if "max" in extra_info and val > extra_info["max"]:
error = {
"type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
errors.append(error)
continue
}
errors.append(error)
continue
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id)