mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 14:50:20 +08:00
loop support added
This commit is contained in:
parent
067e2acfd2
commit
173bdd280b
59
comfy_extras/nodes_loop.py
Normal file
59
comfy_extras/nodes_loop.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
class LoopControl:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"loop_condition": ("LOOP_CONDITION", ),
|
||||||
|
"initial_input": ("*", ),
|
||||||
|
"loopback_input": ("*", ),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("*", )
|
||||||
|
FUNCTION = "doit"
|
||||||
|
|
||||||
|
def doit(s, **kwargs):
|
||||||
|
if 'loopback_input' not in kwargs or kwargs['loopback_input'] is None:
|
||||||
|
current = kwargs['initial_input']
|
||||||
|
else:
|
||||||
|
current = kwargs['loopback_input']
|
||||||
|
|
||||||
|
return (kwargs['loop_condition'].get_next(kwargs['initial_input'], current), )
|
||||||
|
|
||||||
|
|
||||||
|
class CounterCondition:
|
||||||
|
def __init__(self, value):
|
||||||
|
self.max = value
|
||||||
|
self.current = 0
|
||||||
|
|
||||||
|
def get_next(self, initial_value, value):
|
||||||
|
print(f"CounterCondition: {self.current}/{self.max}")
|
||||||
|
|
||||||
|
self.current += 1
|
||||||
|
if self.current == 1:
|
||||||
|
return initial_value
|
||||||
|
elif self.current <= self.max:
|
||||||
|
return value
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class LoopCounterCondition:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"count": ("INT", {"default": 1, "min": 0, "max": 9999999, "step": 1}),
|
||||||
|
"trigger": (["A", "B"], )
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LOOP_CONDITION", )
|
||||||
|
FUNCTION = "doit"
|
||||||
|
|
||||||
|
def doit(s, count, trigger):
|
||||||
|
return (CounterCondition(count), )
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"LoopControl": LoopControl,
|
||||||
|
"LoopCounterCondition": LoopCounterCondition,
|
||||||
|
}
|
||||||
20
execution.py
20
execution.py
@ -21,10 +21,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
|
|||||||
if isinstance(input_data, list):
|
if isinstance(input_data, list):
|
||||||
input_unique_id = input_data[0]
|
input_unique_id = input_data[0]
|
||||||
output_index = input_data[1]
|
output_index = input_data[1]
|
||||||
if input_unique_id not in outputs or outputs[input_unique_id][input_data[1]] == [None]:
|
if class_def.__name__ != "LoopControl":
|
||||||
return None
|
if input_unique_id not in outputs or outputs[input_unique_id][input_data[1]] == [None]:
|
||||||
obj = outputs[input_unique_id][output_index]
|
return None
|
||||||
input_data_all[x] = obj
|
|
||||||
|
if input_unique_id in outputs and outputs[input_unique_id][input_data[1]] != [None]:
|
||||||
|
obj = outputs[input_unique_id][output_index]
|
||||||
|
input_data_all[x] = obj
|
||||||
else:
|
else:
|
||||||
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
|
||||||
input_data_all[x] = [input_data]
|
input_data_all[x] = [input_data]
|
||||||
@ -360,7 +363,12 @@ class PromptExecutor:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def validate_inputs(prompt, item, validated):
|
def validate_inputs(prompt, item, validated, visited=set()):
|
||||||
|
if item in visited:
|
||||||
|
return (True, [], item)
|
||||||
|
else:
|
||||||
|
visited.add(item)
|
||||||
|
|
||||||
unique_id = item
|
unique_id = item
|
||||||
if unique_id in validated:
|
if unique_id in validated:
|
||||||
return validated[unique_id]
|
return validated[unique_id]
|
||||||
@ -426,7 +434,7 @@ def validate_inputs(prompt, item, validated):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
r = validate_inputs(prompt, o_id, validated)
|
r = validate_inputs(prompt, o_id, validated, visited)
|
||||||
if r[0] is False:
|
if r[0] is False:
|
||||||
# `r` will be set in `validated[o_id]` already
|
# `r` will be set in `validated[o_id]` already
|
||||||
valid = False
|
valid = False
|
||||||
|
|||||||
1
nodes.py
1
nodes.py
@ -1459,4 +1459,5 @@ def init_custom_nodes():
|
|||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
|
||||||
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
|
||||||
|
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_loop.py"))
|
||||||
load_custom_nodes()
|
load_custom_nodes()
|
||||||
|
|||||||
@ -87,6 +87,12 @@ def is_incomplete_input_slots(class_def, inputs, outputs):
|
|||||||
if len(required_inputs - inputs.keys()) > 0:
|
if len(required_inputs - inputs.keys()) > 0:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
if class_def.__name__ == "LoopControl":
|
||||||
|
inputs = {
|
||||||
|
'loop_condition': inputs['loop_condition'],
|
||||||
|
'initial_input': inputs['initial_input'],
|
||||||
|
}
|
||||||
|
|
||||||
for x in inputs:
|
for x in inputs:
|
||||||
input_data = inputs[x]
|
input_data = inputs[x]
|
||||||
|
|
||||||
@ -209,6 +215,8 @@ def worklist_execute(server, prompt, outputs, extra_data, prompt_id, outputs_ui,
|
|||||||
return result # error state
|
return result # error state
|
||||||
else:
|
else:
|
||||||
if unique_id in next_nodes:
|
if unique_id in next_nodes:
|
||||||
|
if class_def.__name__ == "LoopControl" and outputs[unique_id] == [[None]]:
|
||||||
|
continue
|
||||||
|
|
||||||
for next_node in next_nodes[unique_id]:
|
for next_node in next_nodes[unique_id]:
|
||||||
if next_node in to_execute:
|
if next_node in to_execute:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user