find loops and preserve extra_info in errors

This commit is contained in:
Chris 2023-10-02 13:03:39 +11:00
parent 2ef459b1d4
commit 7499011289

View File

@ -384,7 +384,7 @@ class PromptExecutor:
def validate_inputs(prompt, item, validated):
def validate_inputs(prompt, item, validated, stack=[]):
unique_id = item
if unique_id in validated:
return validated[unique_id]
@ -399,6 +399,20 @@ def validate_inputs(prompt, item, validated):
errors = []
valid = True
if unique_id in stack:
error = {
"type": "infinite_loop",
"message": "loop detected in workflow validation",
"details": f"detected at {unique_id}",
"extra_info": {"stack": f"{stack}"},
}
errors.append(error)
ret = (False, errors, unique_id)
validated[unique_id] = ret
# don't continue, because we're already here further up the stack
return ret
stack.append(unique_id)
for x in required_inputs:
if x not in inputs:
error = {
@ -450,7 +464,7 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
try:
r = validate_inputs(prompt, o_id, validated)
r = validate_inputs(prompt, o_id, validated, stack)
if r[0] is False:
# `r` will be set in `validated[o_id]` already
valid = False
@ -582,8 +596,10 @@ def validate_inputs(prompt, item, validated):
else:
ret = (True, [], unique_id)
validated[unique_id] = ret
return ret
# if we had a loop, unique_id will have been marked invalid further down the tree
if unique_id not in validated:
validated[unique_id] = ret
return validated[unique_id]
def full_type_name(klass):
module = klass.__module__
@ -615,7 +631,7 @@ def validate_prompt(prompt):
valid = False
reasons = []
try:
m = validate_inputs(prompt, o, validated)
m = validate_inputs(prompt, o, validated, [])
valid = m[0]
reasons = m[1]
except Exception as ex:
@ -664,16 +680,20 @@ def validate_prompt(prompt):
if len(good_outputs) == 0:
errors_list = []
extra_info = {}
for o, errors in errors:
for error in errors:
errors_list.append(f"{error['message']}: {error['details']}")
if errors:
extra_info[o] = []
for error in errors:
errors_list.append(f"{error['message']}: {error['details']}")
extra_info[o].append(error.get('extra_info',""))
errors_list = "\n".join(errors_list)
error = {
"type": "prompt_outputs_failed_validation",
"message": "Prompt outputs failed validation",
"details": errors_list,
"extra_info": {}
"extra_info": extra_info,
}
return (False, error, list(good_outputs), node_errors)