From a51490b68b633364072c4605bf5c8687b30655e5 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 27 Mar 2026 01:52:08 +0200 Subject: [PATCH] Add looping nodes --- comfy_extras/nodes_looping.py | 767 ++++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 768 insertions(+) create mode 100644 comfy_extras/nodes_looping.py diff --git a/comfy_extras/nodes_looping.py b/comfy_extras/nodes_looping.py new file mode 100644 index 000000000..3d4d2551d --- /dev/null +++ b/comfy_extras/nodes_looping.py @@ -0,0 +1,767 @@ +import torch +from comfy_execution.graph_utils import GraphBuilder, is_link +from comfy_api.latest import ComfyExtension, io +from comfy.nested_tensor import NestedTensor +import comfy.utils + + +NUM_FLOW_SOCKETS = 5 + +def _is_nested(x): + return isinstance(x, NestedTensor) + +def _temporal_frame_count(samples): + """Count temporal frames from a samples tensor or NestedTensor.""" + if _is_nested(samples): + return samples.tensors[0].shape[2] # count from first sub-tensor (video) + if samples.ndim == 5: + return samples.shape[2] # (B,C,T,H,W) + return samples.shape[0] # (B,C,H,W) — batch count + +def _accum_count(accum): + """Count items in an accumulation, handling tensors (Image/Mask) and dicts (Latent/Video Latent).""" + if not isinstance(accum, dict) or "accum" not in accum: + return 0 + total = 0 + for item in accum["accum"]: + if isinstance(item, dict): + total += _temporal_frame_count(item["samples"]) + else: + total += item.shape[0] # IMAGE/MASK: count batch items + return total + + +class TensorLoopOpen(io.ComfyNode): + """ + Opens a loop that collects outputs. Supports two modes: + - iterations: runs a fixed number of iterations + - total_frames: runs until the target number of frames is accumulated + Wire flow_control → TensorLoopClose, use `previous_value` as input to your + generation, and connect the generated output → TensorLoopClose.processed. + Supports IMAGE, MASK, and LATENT types. + """ + MATCHTYPE = io.MatchType.Template("data", allowed_types=[io.Image, io.Mask, io.Latent]) + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="TensorLoopOpen", + display_name="Tensor Loop Open", + category="looping/accumulation", + inputs=[ + io.DynamicCombo.Input("mode", tooltip="Loop termination mode.", options=[ + io.DynamicCombo.Option("iterations", [ + io.Int.Input("iterations", default=4, min=0, + tooltip="Number of loop iterations. Set to 0 to skip looping and pass through initial_value."), + ]), + io.DynamicCombo.Option("total_frames", [ + io.Int.Input("total_frames", default=100, min=1, + tooltip="Target number of output frames. The loop continues until this many frames are accumulated."), + ]), + ]), + io.MatchType.Input("initial_value", template=cls.MATCHTYPE, optional=True, + tooltip="Optional value to use as `previous_value` on the first iteration."), + ], + outputs=[ + io.FlowControl.Output("flow_control"), + io.MatchType.Output(cls.MATCHTYPE, id="previous_value", + tooltip="The value from the previous iteration (or initial_value on first pass)."), + io.Int.Output("accumulated_count", tooltip="Number of items collected so far (0 on first iteration)."), + io.Int.Output("current_iteration", tooltip="Current iteration index (1-based)."), + ], + hidden=[io.Hidden.unique_id], + accept_all_inputs=True, + ) + + + @classmethod + def execute(cls, mode: dict, initial_value=None, **kwargs) -> io.NodeOutput: + unique_id = cls.hidden.unique_id + state = kwargs.get("initial_value0") # packed state dict or None on first pass + if state is not None: + count = state["count"] + total_frames_val = state["total_frames"] + remaining = state["remaining"] + accum = state["accum"] + previous_value = state["previous_value"] + open_node_id = state["open_node_id"] + else: + selected_mode = mode.get("mode", "iterations") + count = mode.get("iterations", 4) if selected_mode == "iterations" else 0 + total_frames_val = mode.get("total_frames", 100) if selected_mode == "total_frames" else 0 + remaining = count + accum = None + previous_value = initial_value + open_node_id = unique_id + + accumulated_count = _accum_count(accum) + # In total_frames mode, count=0 and remaining goes negative each iteration. + # The math still produces correct 1-based iteration: 0-0+1=1, 0-(-1)+1=2, etc. + current_iteration = count - remaining + 1 + if total_frames_val > 0: + comfy.utils.ProgressBar(total_frames_val, node_id=open_node_id).update_absolute(accumulated_count) + elif count > 0: + comfy.utils.ProgressBar(count, node_id=open_node_id).update_absolute(count - remaining) + loop_state = {"remaining": remaining, "accum": accum, "previous_value": previous_value, "count": count, "open_node_id": open_node_id, "total_frames": total_frames_val} + return io.NodeOutput(loop_state, previous_value, accumulated_count, current_iteration) + + +class TensorLoopClose(io.ComfyNode): + """ + Closes the loop started by TensorLoopOpen. + Connect: + - flow_control from TensorLoopOpen + - processed: the output generated this iteration + Supports IMAGE, MASK, and LATENT types. + """ + MATCHTYPE = io.MatchType.Template("data", allowed_types=[io.Image, io.Mask, io.Latent]) + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="TensorLoopClose", + display_name="Tensor Loop Close", + category="looping/accumulation", + inputs=[ + io.FlowControl.Input("flow_control", raw_link=True), + io.MatchType.Input("processed", template=cls.MATCHTYPE, raw_link=True, tooltip="Output generated this iteration."), + io.Boolean.Input("accumulate", default=True, + tooltip="When enabled, collects all iterations into a batch. When disabled, only outputs the final iteration's result."), + io.DynamicCombo.Input("overlap", tooltip="Remove or blend duplicate frames where consecutive iterations overlap.", options=[ + io.DynamicCombo.Option("disabled", []), + io.DynamicCombo.Option("start", [ + io.Int.Input("overlap_frames", default=8, min=1, + tooltip="Number of frames to trim. Use when the model re-generates context frames at the start of its output (most common for video continuation)."), + ]), + io.DynamicCombo.Option("end", [ + io.Int.Input("overlap_frames", default=8, min=1, + tooltip="Number of frames to trim. Use when the model generates look-ahead frames at the end of its output."), + ]), + io.DynamicCombo.Option("fade_linear", [ + io.Int.Input("overlap_frames", default=8, min=1, + tooltip="Number of frames to crossfade with a linear blend between consecutive iterations."), + ]), + io.DynamicCombo.Option("fade_smooth", [ + io.Int.Input("overlap_frames", default=8, min=1, + tooltip="Number of frames to crossfade with a smoothstep (ease in/out) blend between consecutive iterations."), + ]), + ]), + io.Boolean.Input("stop", optional=True, default=False, raw_link=True, force_input=True, + tooltip="Optional early stop signal from inside the loop body. When True, the loop stops after the current iteration regardless of remaining iterations or total_frames target."), + ], + outputs=[ + io.MatchType.Output(cls.MATCHTYPE, id="output", tooltip="Accumulated batch or final iteration result, depending on 'accumulate' setting."), + ], + enable_expand=True, + ) + + + @classmethod + def execute(cls, flow_control, processed, accumulate=True, overlap=None, stop=False) -> io.NodeOutput: + graph = GraphBuilder() + open_id = flow_control[0] + unpack = graph.node("_ImageAccumStateUnpack", loop_state=[open_id, 0]) + # unpack: 0=remaining, 1=accum, 2=previous_value, 3=accumulated_count, 4=count, 5=open_node_id, 6=total_frames + sub = graph.node("_IntOperations", operation="subtract", a=unpack.out(0), b=1) + + overlap_mode = "disabled" + overlap_frames = 0 + if isinstance(overlap, dict): + overlap_mode = overlap.get("overlap", "disabled") + overlap_frames = overlap.get("overlap_frames", 0) + + accum_out = unpack.out(1) + if accumulate: + to_accum = processed + if overlap_frames > 0 and overlap_mode != "disabled": + is_first = graph.node("_IntOperations", a=unpack.out(3), b=0, operation="==") + trimmed_start = graph.node("_BatchOps", batch=processed, operation="trim_start", amount=overlap_frames).out(0) + + if overlap_mode == "start": + to_accum = trimmed_start + elif overlap_mode == "end": + trimmed_end = graph.node("_BatchOps", batch=processed, operation="trim_end", amount=overlap_frames).out(0) + trimmed_both = graph.node("_BatchOps", batch=trimmed_end, operation="trim_start", amount=overlap_frames).out(0) + to_accum = graph.node("_ConditionalSelect", + condition=is_first.out(1), + value_if_true=trimmed_both, + value_if_false=trimmed_end, + ).out(0) + else: + # Fade: trim start on iter 1 only, keep full on subsequent for post-loop blend + to_accum = graph.node("_ConditionalSelect", + condition=is_first.out(1), + value_if_true=trimmed_start, + value_if_false=processed, + ).out(0) + + accum_out = graph.node("_AccumulateNode", to_add=to_accum, accumulation=accum_out).out(0) + + # Disable total_frames when not accumulating to avoid infinite loops + pack = graph.node("_ImageAccumStatePack", + remaining=sub.out(0), + accum=accum_out, + previous_value=processed, + count=unpack.out(4), + open_node_id=unpack.out(5), + total_frames=unpack.out(6) if accumulate else 0, + prev_accumulated_count=unpack.out(3), + ) + + # Optional early stop from loop body + if is_link(stop): + condition = graph.node("_ConditionalSelect", + condition=stop, + value_if_true=False, + value_if_false=pack.out(1), + ).out(0) + else: + condition = pack.out(1) + + while_close = graph.node( + "_WhileLoopClose", + flow_control=flow_control, + condition=condition, + initial_value0=pack.out(0), + ) + + final_unpack = graph.node("_ImageAccumStateUnpack", loop_state=while_close.out(0)) + if accumulate: + if overlap_mode in ("fade_linear", "fade_smooth") and overlap_frames > 0: + result = graph.node("_AccumulationToImageBatch", + accumulation=final_unpack.out(1), + overlap_frames=overlap_frames, + overlap_mode=overlap_mode, + ).out(0) + else: + result = graph.node("_AccumulationToImageBatch", accumulation=final_unpack.out(1)).out(0) + + # End mode: re-append the tail that was trimmed from the last iteration + if overlap_mode == "end" and overlap_frames > 0: + tail = graph.node("_BatchOps", batch=final_unpack.out(2), operation="keep_end", amount=overlap_frames).out(0) + result = graph.node("_BatchOps", batch=result, operation="concat", batch_b=tail).out(0) + + result = graph.node("_BatchOps", batch=result, operation="max_count", amount=final_unpack.out(6)).out(0) + else: + result = final_unpack.out(2) + + # Bypass: when iterations==0 AND total_frames==0, return initial_value directly + count_is_zero = graph.node("_IntOperations", a=unpack.out(4), b=0, operation="==") + tf_is_zero = graph.node("_IntOperations", a=unpack.out(6), b=0, operation="==") + both_zero = graph.node("_IntOperations", a=count_is_zero.out(0), b=tf_is_zero.out(0), operation="multiply") + result = graph.node("_ConditionalSelect", + condition=both_zero.out(1), + value_if_true=unpack.out(2), + value_if_false=result, + ).out(0) + + return io.NodeOutput(result, expand=graph.finalize()) + +# Internal helper nodes — dev only, hidden from the node menu. + +class _AccumulateNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="_AccumulateNode", + display_name="Accumulate", + category="looping/accumulation", + is_dev_only=True, + inputs=[ + io.AnyType.Input("to_add"), + io.Accumulation.Input("accumulation", optional=True), + ], + outputs=[ + io.Accumulation.Output(), + ], + ) + + @classmethod + def execute(cls, to_add, accumulation=None) -> io.NodeOutput: + if accumulation is None: + value = [to_add] + else: + value = accumulation["accum"] + [to_add] + return io.NodeOutput({"accum": value}) + +class _WhileLoopOpen(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="_WhileLoopOpen", + display_name="While Loop Open", + category="looping", + is_dev_only=True, + inputs=[ + io.Boolean.Input("condition", default=True), + *[io.AnyType.Input(f"initial_value{i}", optional=True) for i in range(NUM_FLOW_SOCKETS)], + ], + outputs=[ + io.FlowControl.Output("flow_control", display_name="FLOW_CONTROL"), + *[io.AnyType.Output(f"value{i}") for i in range(NUM_FLOW_SOCKETS)], + ], + accept_all_inputs=True, + ) + + @classmethod + def execute(cls, condition: bool, **kwargs) -> io.NodeOutput: + values = [kwargs.get(f"initial_value{i}", None) for i in range(NUM_FLOW_SOCKETS)] + return io.NodeOutput("stub", *values) + +class _WhileLoopClose(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="_WhileLoopClose", + display_name="While Loop Close", + category="looping", + is_dev_only=True, + inputs=[ + io.FlowControl.Input("flow_control", raw_link=True), + io.Boolean.Input("condition", force_input=True), + *[io.AnyType.Input(f"initial_value{i}", optional=True) for i in range(NUM_FLOW_SOCKETS)], + ], + outputs=[ + *[io.AnyType.Output(f"value{i}") for i in range(NUM_FLOW_SOCKETS)], + ], + hidden=[io.Hidden.dynprompt, io.Hidden.unique_id], + enable_expand=True, + accept_all_inputs=True, + ) + + @staticmethod + def _explore_dependencies(node_id, dynprompt, upstream): + node_info = dynprompt.get_node(node_id) + if "inputs" not in node_info: + return + for k, v in node_info["inputs"].items(): + if is_link(v): + parent_id = v[0] + if parent_id not in upstream: + upstream[parent_id] = [] + _WhileLoopClose._explore_dependencies(parent_id, dynprompt, upstream) + upstream[parent_id].append(node_id) + + @staticmethod + def _collect_contained(node_id, upstream, contained): + if node_id not in upstream: + return + for child_id in upstream[node_id]: + if child_id not in contained: + contained[child_id] = True + _WhileLoopClose._collect_contained(child_id, upstream, contained) + + @classmethod + def execute(cls, flow_control, condition: bool, **kwargs) -> io.NodeOutput: + dynprompt = cls.hidden.dynprompt + unique_id = cls.hidden.unique_id + values = [kwargs.get(f"initial_value{i}", None) for i in range(NUM_FLOW_SOCKETS)] + + if not condition: # Done with the loop — return current values + return io.NodeOutput(*values) + + # Build the graph expansion for the next loop iteration + upstream = {} + cls._explore_dependencies(unique_id, dynprompt, upstream) + + contained = {} + open_node = flow_control[0] + cls._collect_contained(open_node, upstream, contained) + contained[unique_id] = True + contained[open_node] = True + + # Use "Recurse" for this node's clone to avoid exponential name growth + graph = GraphBuilder() + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id) + node.set_override_display_id(node_id) + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.lookup_node("Recurse" if node_id == unique_id else node_id) + for k, v in original_node["inputs"].items(): + if is_link(v) and v[0] in contained: + parent = graph.lookup_node(v[0]) + node.set_input(k, parent.out(v[1])) + else: + node.set_input(k, v) + + new_open = graph.lookup_node(open_node) + for i in range(NUM_FLOW_SOCKETS): + new_open.set_input(f"initial_value{i}", values[i]) + + my_clone = graph.lookup_node("Recurse") + result = tuple(my_clone.out(x) for x in range(NUM_FLOW_SOCKETS)) + return io.NodeOutput(*result, expand=graph.finalize()) + + +class _IntOperations(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="_IntOperations", + display_name="Int Operations", + category="looping/logic", + is_dev_only=True, + inputs=[ + io.Int.Input("a", default=0, min=-0xffffffffffffffff, max=0xffffffffffffffff, step=1), + io.Int.Input("b", default=0, min=-0xffffffffffffffff, max=0xffffffffffffffff, step=1), + io.Combo.Input("operation", options=[ + "add", "subtract", "multiply", "divide", "modulo", "power", + "==", "!=", "<", ">", "<=", ">=", + ]), + ], + outputs=[ + io.Int.Output(), + io.Boolean.Output(), + ], + ) + + OPS = { + "add": lambda a, b: a + b, "subtract": lambda a, b: a - b, + "multiply": lambda a, b: a * b, "divide": lambda a, b: a // b if b else 0, + "modulo": lambda a, b: a % b if b else 0, "power": lambda a, b: a ** b, + "==": lambda a, b: a == b, "!=": lambda a, b: a != b, + "<": lambda a, b: a < b, ">": lambda a, b: a > b, + "<=": lambda a, b: a <= b, ">=": lambda a, b: a >= b, + } + + @classmethod + def execute(cls, a: int, b: int, operation: str) -> io.NodeOutput: + result = cls.OPS[operation](a, b) + return io.NodeOutput(int(result), bool(result)) + +def _get_frame_count(x, dim): + """Return the size along the given dimension.""" + if isinstance(x, dict) and "samples" in x: + s = x["samples"] + return s.tensors[0].shape[dim] if _is_nested(s) else s.shape[dim] + return x.shape[dim] + +def _get_cat_dim(x): + """Return the concatenation dimension for a tensor: dim 2 for 5D video, dim 0 for everything else.""" + if isinstance(x, dict): + s = x["samples"] + if _is_nested(s): + return 2 # NestedTensor sub-tensors use temporal dim + return 2 if s.ndim == 5 else 0 + return 0 # IMAGE/MASK: batch dim + +def _slice_tensor(x, dim, start=None, end=None): + """Slice a tensor, latent dict, or NestedTensor along the given dim.""" + if isinstance(x, dict) and "samples" in x: + result = x.copy() + samples = x["samples"] + if _is_nested(samples): + sliced = [] + for t in samples.tensors: + sliced.append(t[(slice(None),) * dim + (slice(start, end),)]) + result["samples"] = NestedTensor(sliced) + mask = x.get("noise_mask") + if mask is not None and _is_nested(mask): + mask_sliced = [] + for t in mask.tensors: + mask_sliced.append(t[(slice(None),) * dim + (slice(start, end),)]) + result["noise_mask"] = NestedTensor(mask_sliced) + else: + result["samples"] = samples[(slice(None),) * dim + (slice(start, end),)] + mask = x.get("noise_mask") + if mask is not None and mask.ndim == samples.ndim: + result["noise_mask"] = mask[(slice(None),) * dim + (slice(start, end),)] + return result + else: + return x[(slice(None),) * dim + (slice(start, end),)] + +def _concat_tensor(a, b, dim): + """Concatenate two tensors, latent dicts, or NestedTensors along the given dim.""" + if isinstance(a, dict) and "samples" in a: + result = a.copy() + sa, sb = a["samples"], b["samples"] + if _is_nested(sa): + result["samples"] = NestedTensor([torch.cat([ta, tb], dim=dim) for ta, tb in zip(sa.tensors, sb.tensors)]) + ma, mb = a.get("noise_mask"), b.get("noise_mask") + if ma is not None and mb is not None and _is_nested(ma): + result["noise_mask"] = NestedTensor([torch.cat([ta, tb], dim=dim) for ta, tb in zip(ma.tensors, mb.tensors)]) + else: + result["samples"] = torch.cat([sa, sb], dim=dim) + ma, mb = a.get("noise_mask"), b.get("noise_mask") + if ma is not None and mb is not None and ma.ndim == sa.ndim: + result["noise_mask"] = torch.cat([ma, mb], dim=dim) + return result + else: + return torch.cat([a, b], dim=dim) + +def _blend_overlap(items, overlap_frames, mode): + """Concatenate items with crossfade blending in the overlap regions.""" + if not items: + return None + dim = _get_cat_dim(items[0]) + # Clamp overlap to the smallest item size to avoid slicing beyond bounds + min_frames = min(_get_frame_count(item, dim) for item in items) + overlap_frames = min(overlap_frames, min_frames - 1) if min_frames > 1 else 0 + if overlap_frames <= 0: + # Nothing to blend — fall through to simple concat + return _concat_tensor(items[0], items[1], dim) if len(items) == 2 else items[0] + t = torch.linspace(0, 1, overlap_frames) + if mode == "fade_smooth": + t = t * t * (3 - 2 * t) + + result = items[0] + for i in range(1, len(items)): + prev_tail = _slice_tensor(result, dim, start=-overlap_frames) + curr_head = _slice_tensor(items[i], dim, end=overlap_frames) + curr_rest = _slice_tensor(items[i], dim, start=overlap_frames) + result_base = _slice_tensor(result, dim, end=-overlap_frames) + + # Lerp: blended = prev_tail * (1-w) + curr_head * w + if isinstance(prev_tail, dict): + blended = prev_tail.copy() + ps, cs = prev_tail["samples"], curr_head["samples"] + if _is_nested(ps): + blended_tensors = [] + for pt, ch in zip(ps.tensors, cs.tensors): + w = t.to(pt.device).reshape([1] * dim + [overlap_frames] + [1] * (pt.ndim - dim - 1)) + blended_tensors.append(pt * (1 - w) + ch * w) + blended["samples"] = NestedTensor(blended_tensors) + else: + w = t.to(ps.device).reshape([1] * dim + [overlap_frames] + [1] * (ps.ndim - dim - 1)) + blended["samples"] = ps * (1 - w) + cs * w + else: + w = t.to(prev_tail.device).reshape([1] * dim + [overlap_frames] + [1] * (prev_tail.ndim - dim - 1)) + blended = prev_tail * (1 - w) + curr_head * w + + result = _concat_tensor(_concat_tensor(result_base, blended, dim), curr_rest, dim) + return result + + +class _AccumulationToImageBatch(io.ComfyNode): + """Concatenates an ACCUMULATION of IMAGE/MASK tensors or LATENT dicts into a single batch.""" + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="_AccumulationToImageBatch", + display_name="Accumulation to Batch", + category="looping/accumulation", + is_dev_only=True, + inputs=[ + io.Accumulation.Input("accumulation"), + io.Int.Input("overlap_frames", default=0, min=0), + io.Combo.Input("overlap_mode", options=["disabled", "fade_linear", "fade_smooth"], default="disabled"), + ], + outputs=[io.AnyType.Output("result")], + ) + + @classmethod + def execute(cls, accumulation, overlap_frames=0, overlap_mode="disabled") -> io.NodeOutput: + items = accumulation["accum"] + if not items: + return io.NodeOutput(None) + + # Fade modes: blend overlap regions between consecutive items + if overlap_mode in ("fade_linear", "fade_smooth") and overlap_frames > 0 and len(items) > 1: + return io.NodeOutput(_blend_overlap(items, overlap_frames, overlap_mode)) + + # Standard concatenation (no overlap or start/end trim was handled per-iteration) + if isinstance(items[0], dict): + samples = items[0]["samples"] + if _is_nested(samples): + num_sub = len(samples.tensors) + catted = [] + for i in range(num_sub): + catted.append(torch.cat([item["samples"].tensors[i] for item in items], dim=2)) + result = items[0].copy() + result["samples"] = NestedTensor(catted) + result.pop("noise_mask", None) + masks = [item.get("noise_mask") for item in items if item.get("noise_mask") is not None] + if masks and _is_nested(masks[0]): + mask_catted = [] + for i in range(len(masks[0].tensors)): + mask_catted.append(torch.cat([m.tensors[i] for m in masks], dim=2)) + result["noise_mask"] = NestedTensor(mask_catted) + return io.NodeOutput(result) + elif samples.ndim == 5: + result = items[0].copy() + result["samples"] = torch.cat([item["samples"] for item in items], dim=2) + result.pop("noise_mask", None) + masks = [item.get("noise_mask") for item in items if item.get("noise_mask") is not None] + if masks and masks[0].ndim == 5: + result["noise_mask"] = torch.cat(masks, dim=2) + return io.NodeOutput(result) + else: + # Image latent — batch along dim 0 + result = items[0].copy() + result["samples"] = torch.cat([item["samples"] for item in items], dim=0) + return io.NodeOutput(result) + else: + return io.NodeOutput(torch.cat(items, dim=0)) + + +class _ConditionalSelect(io.ComfyNode): + """Returns value_if_true when condition is True, value_if_false otherwise. + Uses lazy evaluation so only the selected branch is resolved.""" + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="_ConditionalSelect", + display_name="Conditional Select", + category="looping/logic", + is_dev_only=True, + inputs=[ + io.Boolean.Input("condition"), + io.AnyType.Input("value_if_true", lazy=True), + io.AnyType.Input("value_if_false", lazy=True), + ], + outputs=[io.AnyType.Output("result")], + ) + + @classmethod + def check_lazy_status(cls, condition, value_if_true=None, value_if_false=None) -> list[str]: + # Unevaluated lazy inputs arrive as None + if condition and value_if_true is None: + return ["value_if_true"] + if not condition and value_if_false is None: + return ["value_if_false"] + return [] + + @classmethod + def execute(cls, condition, **kwargs) -> io.NodeOutput: + selected = kwargs.get("value_if_true") if condition else kwargs.get("value_if_false") + return io.NodeOutput(selected) + + +class _ImageAccumStatePack(io.ComfyNode): + """Packs loop state into a single dict for TensorLoopOpen's initial_value0.""" + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="_ImageAccumStatePack", + display_name="Image Accum State Pack", + category="looping/accumulation", + is_dev_only=True, + inputs=[ + io.AnyType.Input("remaining"), + io.Accumulation.Input("accum", optional=True), + io.AnyType.Input("previous_value"), + io.AnyType.Input("count"), + io.AnyType.Input("open_node_id"), + io.AnyType.Input("total_frames"), + io.Int.Input("prev_accumulated_count", default=0), + ], + outputs=[ + io.AnyType.Output("loop_state"), + io.Boolean.Output("should_continue"), + ], + ) + + @classmethod + def execute(cls, remaining, accum, previous_value, count, open_node_id, total_frames, prev_accumulated_count=0) -> io.NodeOutput: + accumulated_count = _accum_count(accum) + + if total_frames > 0: + should_continue = accumulated_count < total_frames + # Bail if the last iteration added nothing — the loop would never reach the target + if accumulated_count == prev_accumulated_count: + should_continue = False + comfy.utils.ProgressBar(total_frames, node_id=open_node_id).update_absolute(accumulated_count) + else: + should_continue = remaining > 0 + current_iteration = count - remaining + comfy.utils.ProgressBar(count, node_id=open_node_id).update_absolute(current_iteration) + + return io.NodeOutput( + {"remaining": remaining, "accum": accum, "previous_value": previous_value, "count": count, "open_node_id": open_node_id, "total_frames": total_frames}, + should_continue, + ) + + +class _ImageAccumStateUnpack(io.ComfyNode): + """Unpacks loop_state from TensorLoopOpen.""" + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="_ImageAccumStateUnpack", + display_name="Image Accum State Unpack", + category="looping/accumulation", + is_dev_only=True, + inputs=[io.AnyType.Input("loop_state")], + outputs=[ + io.Int.Output("remaining"), + io.Accumulation.Output("accumulation"), + io.AnyType.Output("previous_value"), + io.Int.Output("accumulated_count"), + io.Int.Output("count"), + io.AnyType.Output("open_node_id"), + io.Int.Output("total_frames"), + ], + ) + + @classmethod + def execute(cls, loop_state) -> io.NodeOutput: + remaining = loop_state["remaining"] + accum = loop_state["accum"] + previous_value = loop_state["previous_value"] + count = loop_state.get("count", 0) + open_node_id = loop_state.get("open_node_id") + total_frames = loop_state.get("total_frames", 0) + accumulated_count = _accum_count(accum) + return io.NodeOutput(remaining, accum, previous_value, accumulated_count, count, open_node_id, total_frames) + + + +class _BatchOps(io.ComfyNode): + """Batch slicing, trimming, and concatenation operations.""" + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="_BatchOps", + display_name="Batch Ops", + category="looping/accumulation", + is_dev_only=True, + inputs=[ + io.AnyType.Input("batch"), + io.Combo.Input("operation", options=["max_count", "trim_start", "trim_end", "keep_end", "concat"]), + io.Int.Input("amount", default=0, min=0), + io.AnyType.Input("batch_b", optional=True), + ], + outputs=[io.AnyType.Output("output")], + ) + + @classmethod + def execute(cls, batch, operation, amount=0, batch_b=None) -> io.NodeOutput: + dim = _get_cat_dim(batch) + if operation == "concat" and batch_b is not None: + return io.NodeOutput(_concat_tensor(batch, batch_b, dim)) + if amount <= 0: + return io.NodeOutput(batch) + # Clamp to avoid slicing beyond the tensor size + total = _get_frame_count(batch, dim) + amount = min(amount, total) + if operation == "max_count": + return io.NodeOutput(_slice_tensor(batch, dim, end=amount)) + elif operation == "trim_start": + return io.NodeOutput(_slice_tensor(batch, dim, start=amount)) + elif operation == "trim_end": + return io.NodeOutput(_slice_tensor(batch, dim, end=-amount) if amount < total else batch) + elif operation == "keep_end": + return io.NodeOutput(_slice_tensor(batch, dim, start=-amount)) + return io.NodeOutput(batch) + + +class LoopExtension(ComfyExtension): + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TensorLoopOpen, + TensorLoopClose, + _WhileLoopOpen, + _WhileLoopClose, + _AccumulateNode, + _IntOperations, + _AccumulationToImageBatch, + _ConditionalSelect, + _ImageAccumStateUnpack, + _ImageAccumStatePack, + _BatchOps, + ] + +def comfy_entrypoint(): + return LoopExtension() diff --git a/nodes.py b/nodes.py index 37ceac2fc..074c47543 100644 --- a/nodes.py +++ b/nodes.py @@ -2457,6 +2457,7 @@ async def init_builtin_extra_nodes(): "nodes_number_convert.py", "nodes_painter.py", "nodes_curve.py", + "nodes_looping.py", ] import_failed = []