mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-09 21:12:36 +08:00
allow nodes to map over lists
This commit is contained in:
parent
aa57136dae
commit
87faf9a4a4
72
execution.py
72
execution.py
@ -26,20 +26,74 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
||||
input_data_all[x] = obj
|
||||
else:
|
||||
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
||||
input_data_all[x] = input_data
|
||||
input_data_all[x] = [input_data]
|
||||
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
for x in h:
|
||||
if h[x] == "PROMPT":
|
||||
input_data_all[x] = prompt
|
||||
input_data_all[x] = [prompt]
|
||||
if h[x] == "EXTRA_PNGINFO":
|
||||
if "extra_pnginfo" in extra_data:
|
||||
input_data_all[x] = extra_data['extra_pnginfo']
|
||||
input_data_all[x] = [extra_data['extra_pnginfo']]
|
||||
if h[x] == "UNIQUE_ID":
|
||||
input_data_all[x] = unique_id
|
||||
input_data_all[x] = [unique_id]
|
||||
return input_data_all
|
||||
|
||||
def get_output_data(obj, input_data_all):
|
||||
|
||||
# check if node wants the lists
|
||||
intput_is_list = False
|
||||
if hasattr(obj, "INPUT_IS_LIST"):
|
||||
intput_is_list = obj.INPUT_IS_LIST
|
||||
|
||||
max_len_input = max([len(x) for x in input_data_all.values()])
|
||||
results = []
|
||||
uis = []
|
||||
|
||||
# get a slice of inputs, repeat last input when list isn't long enough
|
||||
def slice_dict(d, i):
|
||||
d_new = dict()
|
||||
for k,v in d.items():
|
||||
d_new[k] = v[i if len(v) > i else -1]
|
||||
return d_new
|
||||
|
||||
def append_return(input):
|
||||
nodes.before_node_execution()
|
||||
node_result = getattr(obj, obj.FUNCTION)(**input)
|
||||
if isinstance(node_result, dict):
|
||||
if 'ui' in node_result:
|
||||
uis.append(node_result['ui'])
|
||||
if 'result' in node_result:
|
||||
results.append(node_result['result'])
|
||||
else:
|
||||
results.append(node_result)
|
||||
|
||||
if intput_is_list:
|
||||
append_return(**input_data_all)
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
append_return(slice_dict(input_data_all, i))
|
||||
|
||||
output = []
|
||||
if len(results) > 0:
|
||||
# check which outputs need concatenating
|
||||
output_is_list = [False] * len(results[0])
|
||||
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||
output_is_list = obj.OUTPUT_IS_LIST
|
||||
|
||||
# merge node execution results
|
||||
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||
if is_list:
|
||||
output.append([x for o in results for x in o[i]])
|
||||
else:
|
||||
output.append([o[i] for o in results])
|
||||
|
||||
ui = dict()
|
||||
if len(uis) > 0:
|
||||
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
|
||||
return output, ui
|
||||
|
||||
def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
@ -63,13 +117,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
|
||||
server.send_sync("executing", { "node": unique_id }, server.client_id)
|
||||
obj = class_def()
|
||||
|
||||
nodes.before_node_execution()
|
||||
outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all)
|
||||
if "ui" in outputs[unique_id]:
|
||||
output_data, output_ui = get_output_data(obj, input_data_all)
|
||||
outputs[unique_id] = output_data
|
||||
if len(output_ui) > 0:
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
|
||||
if "result" in outputs[unique_id]:
|
||||
outputs[unique_id] = outputs[unique_id]["result"]
|
||||
server.send_sync("executed", { "node": unique_id, "output": output_ui }, server.client_id)
|
||||
executed.add(unique_id)
|
||||
|
||||
def recursive_will_execute(prompt, outputs, current_item):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user