diff --git a/execution.py b/execution.py index dc2b67554..5bf62d678 100644 --- a/execution.py +++ b/execution.py @@ -40,17 +40,14 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = [unique_id] return input_data_all -def get_output_data(obj, input_data_all): - +def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): # check if node wants the lists intput_is_list = False if hasattr(obj, "INPUT_IS_LIST"): intput_is_list = obj.INPUT_IS_LIST max_len_input = max([len(x) for x in input_data_all.values()]) - results = [] - uis = [] - + # get a slice of inputs, repeat last input when list isn't long enough def slice_dict(d, i): d_new = dict() @@ -58,22 +55,32 @@ def get_output_data(obj, input_data_all): d_new[k] = v[i if len(v) > i else -1] return d_new - def append_return(input): - nodes.before_node_execution() - node_result = getattr(obj, obj.FUNCTION)(**input) - if isinstance(node_result, dict): - if 'ui' in node_result: - uis.append(node_result['ui']) - if 'result' in node_result: - results.append(node_result['result']) - else: - results.append(node_result) - + results = [] if intput_is_list: - append_return(**input_data_all) + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**input_data_all)) else: for i in range(max_len_input): - append_return(slice_dict(input_data_all, i)) + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + return results + +def get_output_data(obj, input_data_all): + + results = [] + uis = [] + return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) + + for r in return_values: + if isinstance(r, dict): + if 'ui' in r: + uis.append(r['ui']) + if 'result' in r: + results.append(r['result']) + else: + results.append(r) output = [] if len(results) > 0: @@ -157,7 +164,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item input_data_all = get_input_data(inputs, class_def, unique_id, outputs) if input_data_all is not None: try: - is_changed = class_def.IS_CHANGED(**input_data_all) + #is_changed = class_def.IS_CHANGED(**input_data_all) + is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") prompt[unique_id]['is_changed'] = is_changed except: to_delete = True @@ -312,9 +320,11 @@ def validate_inputs(prompt, item): if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) - ret = obj_class.VALIDATE_INPUTS(**input_data_all) - if ret != True: - return (False, "{}, {}".format(class_type, ret)) + #ret = obj_class.VALIDATE_INPUTS(**input_data_all) + ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") + for r in ret: + if r != True: + return (False, "{}, {}".format(class_type, r)) else: if isinstance(type_input, list): if val not in type_input: