From 173bdd280b5241a6e0aec052f42f59bb7721a7d7 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Fri, 16 Jun 2023 22:16:09 +0900 Subject: [PATCH] loop support added --- comfy_extras/nodes_loop.py | 59 ++++++++++++++++++++++++++++++++++++++ execution.py | 20 +++++++++---- nodes.py | 1 + worklist_execution.py | 8 ++++++ 4 files changed, 82 insertions(+), 6 deletions(-) create mode 100644 comfy_extras/nodes_loop.py diff --git a/comfy_extras/nodes_loop.py b/comfy_extras/nodes_loop.py new file mode 100644 index 000000000..f4b37ea17 --- /dev/null +++ b/comfy_extras/nodes_loop.py @@ -0,0 +1,59 @@ +class LoopControl: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "loop_condition": ("LOOP_CONDITION", ), + "initial_input": ("*", ), + "loopback_input": ("*", ), + }, + } + + RETURN_TYPES = ("*", ) + FUNCTION = "doit" + + def doit(s, **kwargs): + if 'loopback_input' not in kwargs or kwargs['loopback_input'] is None: + current = kwargs['initial_input'] + else: + current = kwargs['loopback_input'] + + return (kwargs['loop_condition'].get_next(kwargs['initial_input'], current), ) + + +class CounterCondition: + def __init__(self, value): + self.max = value + self.current = 0 + + def get_next(self, initial_value, value): + print(f"CounterCondition: {self.current}/{self.max}") + + self.current += 1 + if self.current == 1: + return initial_value + elif self.current <= self.max: + return value + else: + return None + + +class LoopCounterCondition: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "count": ("INT", {"default": 1, "min": 0, "max": 9999999, "step": 1}), + "trigger": (["A", "B"], ) + }, + } + + RETURN_TYPES = ("LOOP_CONDITION", ) + FUNCTION = "doit" + + def doit(s, count, trigger): + return (CounterCondition(count), ) + + +NODE_CLASS_MAPPINGS = { + "LoopControl": LoopControl, + "LoopCounterCondition": LoopCounterCondition, +} diff --git a/execution.py b/execution.py index 145685db7..cab38b470 100644 --- a/execution.py +++ b/execution.py @@ -21,10 +21,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da if isinstance(input_data, list): input_unique_id = input_data[0] output_index = input_data[1] - if input_unique_id not in outputs or outputs[input_unique_id][input_data[1]] == [None]: - return None - obj = outputs[input_unique_id][output_index] - input_data_all[x] = obj + if class_def.__name__ != "LoopControl": + if input_unique_id not in outputs or outputs[input_unique_id][input_data[1]] == [None]: + return None + + if input_unique_id in outputs and outputs[input_unique_id][input_data[1]] != [None]: + obj = outputs[input_unique_id][output_index] + input_data_all[x] = obj else: if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): input_data_all[x] = [input_data] @@ -360,7 +363,12 @@ class PromptExecutor: -def validate_inputs(prompt, item, validated): +def validate_inputs(prompt, item, validated, visited=set()): + if item in visited: + return (True, [], item) + else: + visited.add(item) + unique_id = item if unique_id in validated: return validated[unique_id] @@ -426,7 +434,7 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue try: - r = validate_inputs(prompt, o_id, validated) + r = validate_inputs(prompt, o_id, validated, visited) if r[0] is False: # `r` will be set in `validated[o_id]` already valid = False diff --git a/nodes.py b/nodes.py index 658e32dad..39d692eec 100644 --- a/nodes.py +++ b/nodes.py @@ -1459,4 +1459,5 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_loop.py")) load_custom_nodes() diff --git a/worklist_execution.py b/worklist_execution.py index 3d4b4937b..58cc70a86 100644 --- a/worklist_execution.py +++ b/worklist_execution.py @@ -87,6 +87,12 @@ def is_incomplete_input_slots(class_def, inputs, outputs): if len(required_inputs - inputs.keys()) > 0: return True + if class_def.__name__ == "LoopControl": + inputs = { + 'loop_condition': inputs['loop_condition'], + 'initial_input': inputs['initial_input'], + } + for x in inputs: input_data = inputs[x] @@ -209,6 +215,8 @@ def worklist_execute(server, prompt, outputs, extra_data, prompt_id, outputs_ui, return result # error state else: if unique_id in next_nodes: + if class_def.__name__ == "LoopControl" and outputs[unique_id] == [[None]]: + continue for next_node in next_nodes[unique_id]: if next_node in to_execute: