diff --git a/execution.py b/execution.py index 5f5d6c738..132d346bf 100644 --- a/execution.py +++ b/execution.py @@ -384,7 +384,7 @@ class PromptExecutor: -def validate_inputs(prompt, item, validated): +def validate_inputs(prompt, item, validated, stack=[]): unique_id = item if unique_id in validated: return validated[unique_id] @@ -399,6 +399,20 @@ def validate_inputs(prompt, item, validated): errors = [] valid = True + if unique_id in stack: + error = { + "type": "infinite_loop", + "message": "loop detected in workflow validation", + "details": f"detected at {unique_id}", + "extra_info": {"stack": f"{stack}"}, + } + errors.append(error) + ret = (False, errors, unique_id) + validated[unique_id] = ret + # don't continue, because we're already here further up the stack + return ret + stack.append(unique_id) + for x in required_inputs: if x not in inputs: error = { @@ -450,7 +464,7 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue try: - r = validate_inputs(prompt, o_id, validated) + r = validate_inputs(prompt, o_id, validated, stack) if r[0] is False: # `r` will be set in `validated[o_id]` already valid = False @@ -582,8 +596,10 @@ def validate_inputs(prompt, item, validated): else: ret = (True, [], unique_id) - validated[unique_id] = ret - return ret +# if we had a loop, unique_id will have been marked invalid further down the tree + if unique_id not in validated: + validated[unique_id] = ret + return validated[unique_id] def full_type_name(klass): module = klass.__module__ @@ -615,7 +631,7 @@ def validate_prompt(prompt): valid = False reasons = [] try: - m = validate_inputs(prompt, o, validated) + m = validate_inputs(prompt, o, validated, []) valid = m[0] reasons = m[1] except Exception as ex: @@ -664,16 +680,20 @@ def validate_prompt(prompt): if len(good_outputs) == 0: errors_list = [] + extra_info = {} for o, errors in errors: - for error in errors: - errors_list.append(f"{error['message']}: {error['details']}") + if errors: + extra_info[o] = [] + for error in errors: + errors_list.append(f"{error['message']}: {error['details']}") + extra_info[o].append(error.get('extra_info',"")) errors_list = "\n".join(errors_list) error = { "type": "prompt_outputs_failed_validation", "message": "Prompt outputs failed validation", "details": errors_list, - "extra_info": {} + "extra_info": extra_info, } return (False, error, list(good_outputs), node_errors)