loop support added

This commit is contained in:
Dr.Lt.Data 2023-06-16 22:16:09 +09:00
parent 067e2acfd2
commit 173bdd280b
4 changed files with 82 additions and 6 deletions

View File

@ -0,0 +1,59 @@
class LoopControl:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"loop_condition": ("LOOP_CONDITION", ),
"initial_input": ("*", ),
"loopback_input": ("*", ),
},
}
RETURN_TYPES = ("*", )
FUNCTION = "doit"
def doit(s, **kwargs):
if 'loopback_input' not in kwargs or kwargs['loopback_input'] is None:
current = kwargs['initial_input']
else:
current = kwargs['loopback_input']
return (kwargs['loop_condition'].get_next(kwargs['initial_input'], current), )
class CounterCondition:
def __init__(self, value):
self.max = value
self.current = 0
def get_next(self, initial_value, value):
print(f"CounterCondition: {self.current}/{self.max}")
self.current += 1
if self.current == 1:
return initial_value
elif self.current <= self.max:
return value
else:
return None
class LoopCounterCondition:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"count": ("INT", {"default": 1, "min": 0, "max": 9999999, "step": 1}),
"trigger": (["A", "B"], )
},
}
RETURN_TYPES = ("LOOP_CONDITION", )
FUNCTION = "doit"
def doit(s, count, trigger):
return (CounterCondition(count), )
NODE_CLASS_MAPPINGS = {
"LoopControl": LoopControl,
"LoopCounterCondition": LoopCounterCondition,
}

View File

@ -21,10 +21,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
if isinstance(input_data, list):
input_unique_id = input_data[0]
output_index = input_data[1]
if input_unique_id not in outputs or outputs[input_unique_id][input_data[1]] == [None]:
return None
obj = outputs[input_unique_id][output_index]
input_data_all[x] = obj
if class_def.__name__ != "LoopControl":
if input_unique_id not in outputs or outputs[input_unique_id][input_data[1]] == [None]:
return None
if input_unique_id in outputs and outputs[input_unique_id][input_data[1]] != [None]:
obj = outputs[input_unique_id][output_index]
input_data_all[x] = obj
else:
if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]):
input_data_all[x] = [input_data]
@ -360,7 +363,12 @@ class PromptExecutor:
def validate_inputs(prompt, item, validated):
def validate_inputs(prompt, item, validated, visited=set()):
if item in visited:
return (True, [], item)
else:
visited.add(item)
unique_id = item
if unique_id in validated:
return validated[unique_id]
@ -426,7 +434,7 @@ def validate_inputs(prompt, item, validated):
errors.append(error)
continue
try:
r = validate_inputs(prompt, o_id, validated)
r = validate_inputs(prompt, o_id, validated, visited)
if r[0] is False:
# `r` will be set in `validated[o_id]` already
valid = False

View File

@ -1459,4 +1459,5 @@ def init_custom_nodes():
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py"))
load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_loop.py"))
load_custom_nodes()

View File

@ -87,6 +87,12 @@ def is_incomplete_input_slots(class_def, inputs, outputs):
if len(required_inputs - inputs.keys()) > 0:
return True
if class_def.__name__ == "LoopControl":
inputs = {
'loop_condition': inputs['loop_condition'],
'initial_input': inputs['initial_input'],
}
for x in inputs:
input_data = inputs[x]
@ -209,6 +215,8 @@ def worklist_execute(server, prompt, outputs, extra_data, prompt_id, outputs_ui,
return result # error state
else:
if unique_id in next_nodes:
if class_def.__name__ == "LoopControl" and outputs[unique_id] == [[None]]:
continue
for next_node in next_nodes[unique_id]:
if next_node in to_execute: