mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 13:32:36 +08:00
make work with IS_CHANGED and VALIDATE_INPUTS
This commit is contained in:
parent
805e043bce
commit
896b654296
52
execution.py
52
execution.py
@ -40,16 +40,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
|||||||
input_data_all[x] = [unique_id]
|
input_data_all[x] = [unique_id]
|
||||||
return input_data_all
|
return input_data_all
|
||||||
|
|
||||||
def get_output_data(obj, input_data_all):
|
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||||
|
|
||||||
# check if node wants the lists
|
# check if node wants the lists
|
||||||
intput_is_list = False
|
intput_is_list = False
|
||||||
if hasattr(obj, "INPUT_IS_LIST"):
|
if hasattr(obj, "INPUT_IS_LIST"):
|
||||||
intput_is_list = obj.INPUT_IS_LIST
|
intput_is_list = obj.INPUT_IS_LIST
|
||||||
|
|
||||||
max_len_input = max([len(x) for x in input_data_all.values()])
|
max_len_input = max([len(x) for x in input_data_all.values()])
|
||||||
results = []
|
|
||||||
uis = []
|
|
||||||
|
|
||||||
# get a slice of inputs, repeat last input when list isn't long enough
|
# get a slice of inputs, repeat last input when list isn't long enough
|
||||||
def slice_dict(d, i):
|
def slice_dict(d, i):
|
||||||
@ -58,22 +55,32 @@ def get_output_data(obj, input_data_all):
|
|||||||
d_new[k] = v[i if len(v) > i else -1]
|
d_new[k] = v[i if len(v) > i else -1]
|
||||||
return d_new
|
return d_new
|
||||||
|
|
||||||
def append_return(input):
|
results = []
|
||||||
nodes.before_node_execution()
|
|
||||||
node_result = getattr(obj, obj.FUNCTION)(**input)
|
|
||||||
if isinstance(node_result, dict):
|
|
||||||
if 'ui' in node_result:
|
|
||||||
uis.append(node_result['ui'])
|
|
||||||
if 'result' in node_result:
|
|
||||||
results.append(node_result['result'])
|
|
||||||
else:
|
|
||||||
results.append(node_result)
|
|
||||||
|
|
||||||
if intput_is_list:
|
if intput_is_list:
|
||||||
append_return(**input_data_all)
|
if allow_interrupt:
|
||||||
|
nodes.before_node_execution()
|
||||||
|
results.append(getattr(obj, func)(**input_data_all))
|
||||||
else:
|
else:
|
||||||
for i in range(max_len_input):
|
for i in range(max_len_input):
|
||||||
append_return(slice_dict(input_data_all, i))
|
if allow_interrupt:
|
||||||
|
nodes.before_node_execution()
|
||||||
|
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_output_data(obj, input_data_all):
|
||||||
|
|
||||||
|
results = []
|
||||||
|
uis = []
|
||||||
|
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
|
||||||
|
|
||||||
|
for r in return_values:
|
||||||
|
if isinstance(r, dict):
|
||||||
|
if 'ui' in r:
|
||||||
|
uis.append(r['ui'])
|
||||||
|
if 'result' in r:
|
||||||
|
results.append(r['result'])
|
||||||
|
else:
|
||||||
|
results.append(r)
|
||||||
|
|
||||||
output = []
|
output = []
|
||||||
if len(results) > 0:
|
if len(results) > 0:
|
||||||
@ -157,7 +164,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
|
|||||||
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
|
||||||
if input_data_all is not None:
|
if input_data_all is not None:
|
||||||
try:
|
try:
|
||||||
is_changed = class_def.IS_CHANGED(**input_data_all)
|
#is_changed = class_def.IS_CHANGED(**input_data_all)
|
||||||
|
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||||
prompt[unique_id]['is_changed'] = is_changed
|
prompt[unique_id]['is_changed'] = is_changed
|
||||||
except:
|
except:
|
||||||
to_delete = True
|
to_delete = True
|
||||||
@ -312,9 +320,11 @@ def validate_inputs(prompt, item):
|
|||||||
|
|
||||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||||
ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
#ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||||
if ret != True:
|
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||||
return (False, "{}, {}".format(class_type, ret))
|
for r in ret:
|
||||||
|
if r != True:
|
||||||
|
return (False, "{}, {}".format(class_type, r))
|
||||||
else:
|
else:
|
||||||
if isinstance(type_input, list):
|
if isinstance(type_input, list):
|
||||||
if val not in type_input:
|
if val not in type_input:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user