import torch
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_api.latest import ComfyExtension, io
from comfy.nested_tensor import NestedTensor
import comfy.utils

# Number of generic value sockets carried by the internal while-loop nodes.
NUM_FLOW_SOCKETS = 5


def _is_nested(x):
    """Return True when ``x`` is a comfy ``NestedTensor``."""
    return isinstance(x, NestedTensor)


def _temporal_frame_count(samples):
    """Count temporal frames from a samples tensor or NestedTensor."""
    if _is_nested(samples):
        return samples.tensors[0].shape[2]  # count from first sub-tensor (video)
    if samples.ndim == 5:
        return samples.shape[2]  # (B,C,T,H,W)
    return samples.shape[0]  # (B,C,H,W) — batch count


def _accum_count(accum):
    """Count items in an accumulation, handling tensors (Image/Mask) and dicts (Latent/Video Latent).

    Returns 0 when ``accum`` is not an accumulation dict (e.g. ``None`` on the first pass).
    """
    if not isinstance(accum, dict) or "accum" not in accum:
        return 0
    total = 0
    for item in accum["accum"]:
        if isinstance(item, dict):
            total += _temporal_frame_count(item["samples"])
        else:
            total += item.shape[0]  # IMAGE/MASK: count batch items
    return total


class TensorLoopOpen(io.ComfyNode):
    """
    Opens a loop that collects outputs.

    Supports two modes:
    - iterations: runs a fixed number of iterations
    - total_frames: runs until the target number of frames is accumulated

    Wire flow_control → TensorLoopClose, use `previous_value` as input to your
    generation, and connect the generated output → TensorLoopClose.processed.
    Supports IMAGE, MASK, and LATENT types.
    """

    MATCHTYPE = io.MatchType.Template("data", allowed_types=[io.Image, io.Mask, io.Latent])

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="TensorLoopOpen",
            display_name="Tensor Loop Open",
            category="looping/accumulation",
            inputs=[
                io.DynamicCombo.Input("mode", tooltip="Loop termination mode.", options=[
                    io.DynamicCombo.Option("iterations", [
                        io.Int.Input("iterations", default=4, min=0,
                                     tooltip="Number of loop iterations. Set to 0 to skip looping and pass through initial_value."),
                    ]),
                    io.DynamicCombo.Option("total_frames", [
                        io.Int.Input("total_frames", default=100, min=1,
                                     tooltip="Target number of output frames. The loop continues until this many frames are accumulated."),
                    ]),
                ]),
                io.MatchType.Input("initial_value", template=cls.MATCHTYPE, optional=True,
                                   tooltip="Optional value to use as `previous_value` on the first iteration."),
            ],
            outputs=[
                io.FlowControl.Output("flow_control"),
                io.MatchType.Output(cls.MATCHTYPE, id="previous_value",
                                    tooltip="The value from the previous iteration (or initial_value on first pass)."),
                io.Int.Output("accumulated_count", tooltip="Number of items collected so far (0 on first iteration)."),
                io.Int.Output("current_iteration", tooltip="Current iteration index (1-based)."),
            ],
            hidden=[io.Hidden.unique_id],
            accept_all_inputs=True,
        )

    @classmethod
    def execute(cls, mode: dict, initial_value=None, **kwargs) -> io.NodeOutput:
        """Emit the loop state, previous value, accumulated count and 1-based iteration index.

        On re-entry the while-loop machinery feeds the packed state dict back in
        through the hidden ``initial_value0`` kwarg; on the first pass that kwarg
        is absent and the state is initialised from ``mode`` / ``initial_value``.
        """
        unique_id = cls.hidden.unique_id
        state = kwargs.get("initial_value0")  # packed state dict or None on first pass

        if state is not None:
            count = state["count"]
            total_frames_val = state["total_frames"]
            remaining = state["remaining"]
            accum = state["accum"]
            previous_value = state["previous_value"]
            open_node_id = state["open_node_id"]
        else:
            selected_mode = mode.get("mode", "iterations")
            count = mode.get("iterations", 4) if selected_mode == "iterations" else 0
            total_frames_val = mode.get("total_frames", 100) if selected_mode == "total_frames" else 0
            remaining = count
            accum = None
            previous_value = initial_value
            open_node_id = unique_id

        accumulated_count = _accum_count(accum)
        # In total_frames mode, count=0 and remaining goes negative each iteration.
        # The math still produces correct 1-based iteration: 0-0+1=1, 0-(-1)+1=2, etc.
        current_iteration = count - remaining + 1

        # Progress is always reported against the *original* open node so the bar
        # stays on the user-visible node across recursive expansions.
        if total_frames_val > 0:
            comfy.utils.ProgressBar(total_frames_val, node_id=open_node_id).update_absolute(accumulated_count)
        elif count > 0:
            comfy.utils.ProgressBar(count, node_id=open_node_id).update_absolute(count - remaining)

        loop_state = {
            "remaining": remaining,
            "accum": accum,
            "previous_value": previous_value,
            "count": count,
            "open_node_id": open_node_id,
            "total_frames": total_frames_val,
        }
        return io.NodeOutput(loop_state, previous_value, accumulated_count, current_iteration)


class TensorLoopClose(io.ComfyNode):
    """
    Closes the loop started by TensorLoopOpen.

    Connect:
    - flow_control from TensorLoopOpen
    - processed: the output generated this iteration

    Supports IMAGE, MASK, and LATENT types.
    """

    MATCHTYPE = io.MatchType.Template("data", allowed_types=[io.Image, io.Mask, io.Latent])

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="TensorLoopClose",
            display_name="Tensor Loop Close",
            category="looping/accumulation",
            inputs=[
                io.FlowControl.Input("flow_control", raw_link=True),
                io.MatchType.Input("processed", template=cls.MATCHTYPE, raw_link=True,
                                   tooltip="Output generated this iteration."),
                io.Boolean.Input("accumulate", default=True,
                                 tooltip="When enabled, collects all iterations into a batch. When disabled, only outputs the final iteration's result."),
                io.DynamicCombo.Input("overlap", tooltip="Remove or blend duplicate frames where consecutive iterations overlap.", options=[
                    io.DynamicCombo.Option("disabled", []),
                    io.DynamicCombo.Option("start", [
                        io.Int.Input("overlap_frames", default=8, min=1,
                                     tooltip="Number of frames to trim. Use when the model re-generates context frames at the start of its output (most common for video continuation)."),
                    ]),
                    io.DynamicCombo.Option("end", [
                        io.Int.Input("overlap_frames", default=8, min=1,
                                     tooltip="Number of frames to trim. Use when the model generates look-ahead frames at the end of its output."),
                    ]),
                    io.DynamicCombo.Option("fade_linear", [
                        io.Int.Input("overlap_frames", default=8, min=1,
                                     tooltip="Number of frames to crossfade with a linear blend between consecutive iterations."),
                    ]),
                    io.DynamicCombo.Option("fade_smooth", [
                        io.Int.Input("overlap_frames", default=8, min=1,
                                     tooltip="Number of frames to crossfade with a smoothstep (ease in/out) blend between consecutive iterations."),
                    ]),
                ]),
                io.Boolean.Input("stop", optional=True, default=False, raw_link=True, force_input=True,
                                 tooltip="Optional early stop signal from inside the loop body. When True, the loop stops after the current iteration regardless of remaining iterations or total_frames target."),
            ],
            outputs=[
                io.MatchType.Output(cls.MATCHTYPE, id="output",
                                    tooltip="Accumulated batch or final iteration result, depending on 'accumulate' setting."),
            ],
            enable_expand=True,
        )

    @classmethod
    def execute(cls, flow_control, processed, accumulate=True, overlap=None, stop=False) -> io.NodeOutput:
        """Expand into the internal pack / while-close / unpack subgraph for one iteration.

        ``flow_control``, ``processed`` and ``stop`` arrive as raw links and are only
        wired into the expansion; no tensor work happens here directly.
        """
        graph = GraphBuilder()
        open_id = flow_control[0]
        unpack = graph.node("_ImageAccumStateUnpack", loop_state=[open_id, 0])
        # unpack: 0=remaining, 1=accum, 2=previous_value, 3=accumulated_count, 4=count, 5=open_node_id, 6=total_frames
        sub = graph.node("_IntOperations", operation="subtract", a=unpack.out(0), b=1)

        overlap_mode = "disabled"
        overlap_frames = 0
        if isinstance(overlap, dict):
            overlap_mode = overlap.get("overlap", "disabled")
            overlap_frames = overlap.get("overlap_frames", 0)

        accum_out = unpack.out(1)
        if accumulate:
            to_accum = processed
            if overlap_frames > 0 and overlap_mode != "disabled":
                # "first iteration" == nothing accumulated yet
                is_first = graph.node("_IntOperations", a=unpack.out(3), b=0, operation="==")
                trimmed_start = graph.node("_BatchOps", batch=processed, operation="trim_start",
                                           amount=overlap_frames).out(0)
                if overlap_mode == "start":
                    to_accum = trimmed_start
                elif overlap_mode == "end":
                    trimmed_end = graph.node("_BatchOps", batch=processed, operation="trim_end",
                                             amount=overlap_frames).out(0)
                    trimmed_both = graph.node("_BatchOps", batch=trimmed_end, operation="trim_start",
                                              amount=overlap_frames).out(0)
                    to_accum = graph.node(
                        "_ConditionalSelect",
                        condition=is_first.out(1),
                        value_if_true=trimmed_both,
                        value_if_false=trimmed_end,
                    ).out(0)
                else:
                    # Fade: trim start on iter 1 only, keep full on subsequent for post-loop blend
                    to_accum = graph.node(
                        "_ConditionalSelect",
                        condition=is_first.out(1),
                        value_if_true=trimmed_start,
                        value_if_false=processed,
                    ).out(0)
            accum_out = graph.node("_AccumulateNode", to_add=to_accum, accumulation=accum_out).out(0)

        # Disable total_frames when not accumulating to avoid infinite loops
        pack = graph.node(
            "_ImageAccumStatePack",
            remaining=sub.out(0),
            accum=accum_out,
            previous_value=processed,
            count=unpack.out(4),
            open_node_id=unpack.out(5),
            total_frames=unpack.out(6) if accumulate else 0,
            prev_accumulated_count=unpack.out(3),
        )

        # Optional early stop from loop body
        if is_link(stop):
            condition = graph.node(
                "_ConditionalSelect",
                condition=stop,
                value_if_true=False,
                value_if_false=pack.out(1),
            ).out(0)
        else:
            condition = pack.out(1)

        while_close = graph.node(
            "_WhileLoopClose",
            flow_control=flow_control,
            condition=condition,
            initial_value0=pack.out(0),
        )
        final_unpack = graph.node("_ImageAccumStateUnpack", loop_state=while_close.out(0))

        if accumulate:
            if overlap_mode in ("fade_linear", "fade_smooth") and overlap_frames > 0:
                result = graph.node(
                    "_AccumulationToImageBatch",
                    accumulation=final_unpack.out(1),
                    overlap_frames=overlap_frames,
                    overlap_mode=overlap_mode,
                ).out(0)
            else:
                result = graph.node("_AccumulationToImageBatch", accumulation=final_unpack.out(1)).out(0)
            # End mode: re-append the tail that was trimmed from the last iteration
            if overlap_mode == "end" and overlap_frames > 0:
                tail = graph.node("_BatchOps", batch=final_unpack.out(2), operation="keep_end",
                                  amount=overlap_frames).out(0)
                result = graph.node("_BatchOps", batch=result, operation="concat", batch_b=tail).out(0)
            # Cap to total_frames (a 0 amount is a no-op in _BatchOps, i.e. iterations mode)
            result = graph.node("_BatchOps", batch=result, operation="max_count",
                                amount=final_unpack.out(6)).out(0)
        else:
            result = final_unpack.out(2)

        # Bypass: when iterations==0 AND total_frames==0, return initial_value directly
        count_is_zero = graph.node("_IntOperations", a=unpack.out(4), b=0, operation="==")
        tf_is_zero = graph.node("_IntOperations", a=unpack.out(6), b=0, operation="==")
        both_zero = graph.node("_IntOperations", a=count_is_zero.out(0), b=tf_is_zero.out(0), operation="multiply")
        result = graph.node(
            "_ConditionalSelect",
            condition=both_zero.out(1),
            value_if_true=unpack.out(2),
            value_if_false=result,
        ).out(0)

        return io.NodeOutput(result, expand=graph.finalize())


# Internal helper nodes — dev only, hidden from the node menu.

class _AccumulateNode(io.ComfyNode):
    """Append ``to_add`` to an accumulation dict ``{"accum": [...]}`` (new list, input untouched)."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_AccumulateNode",
            display_name="Accumulate",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[
                io.AnyType.Input("to_add"),
                io.Accumulation.Input("accumulation", optional=True),
            ],
            outputs=[
                io.Accumulation.Output(),
            ],
        )

    @classmethod
    def execute(cls, to_add, accumulation=None) -> io.NodeOutput:
        if accumulation is None:
            value = [to_add]
        else:
            value = accumulation["accum"] + [to_add]
        return io.NodeOutput({"accum": value})


class _WhileLoopOpen(io.ComfyNode):
    """Head of the generic while loop: passes its optional ``initial_value*`` inputs through."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_WhileLoopOpen",
            display_name="While Loop Open",
            category="looping",
            is_dev_only=True,
            inputs=[
                io.Boolean.Input("condition", default=True),
                *[io.AnyType.Input(f"initial_value{i}", optional=True) for i in range(NUM_FLOW_SOCKETS)],
            ],
            outputs=[
                io.FlowControl.Output("flow_control", display_name="FLOW_CONTROL"),
                *[io.AnyType.Output(f"value{i}") for i in range(NUM_FLOW_SOCKETS)],
            ],
            accept_all_inputs=True,
        )

    @classmethod
    def execute(cls, condition: bool, **kwargs) -> io.NodeOutput:
        values = [kwargs.get(f"initial_value{i}", None) for i in range(NUM_FLOW_SOCKETS)]
        return io.NodeOutput("stub", *values)


class _WhileLoopClose(io.ComfyNode):
    """Tail of the generic while loop.

    When ``condition`` is false it returns its values. Otherwise it re-expands
    every node between the loop's open node and itself (the loop body) with the
    current values fed into the open node, and recurses.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_WhileLoopClose",
            display_name="While Loop Close",
            category="looping",
            is_dev_only=True,
            inputs=[
                io.FlowControl.Input("flow_control", raw_link=True),
                io.Boolean.Input("condition", force_input=True),
                *[io.AnyType.Input(f"initial_value{i}", optional=True) for i in range(NUM_FLOW_SOCKETS)],
            ],
            outputs=[
                *[io.AnyType.Output(f"value{i}") for i in range(NUM_FLOW_SOCKETS)],
            ],
            hidden=[io.Hidden.dynprompt, io.Hidden.unique_id],
            enable_expand=True,
            accept_all_inputs=True,
        )

    @staticmethod
    def _explore_dependencies(node_id, dynprompt, upstream):
        """Fill ``upstream`` with parent -> [children] edges for every ancestor of ``node_id``."""
        node_info = dynprompt.get_node(node_id)
        if "inputs" not in node_info:
            return
        for k, v in node_info["inputs"].items():
            if is_link(v):
                parent_id = v[0]
                if parent_id not in upstream:
                    upstream[parent_id] = []
                    _WhileLoopClose._explore_dependencies(parent_id, dynprompt, upstream)
                upstream[parent_id].append(node_id)

    @staticmethod
    def _collect_contained(node_id, upstream, contained):
        """Mark in ``contained`` every descendant of ``node_id`` reachable via ``upstream`` edges."""
        if node_id not in upstream:
            return
        for child_id in upstream[node_id]:
            if child_id not in contained:
                contained[child_id] = True
                _WhileLoopClose._collect_contained(child_id, upstream, contained)

    @classmethod
    def execute(cls, flow_control, condition: bool, **kwargs) -> io.NodeOutput:
        dynprompt = cls.hidden.dynprompt
        unique_id = cls.hidden.unique_id
        values = [kwargs.get(f"initial_value{i}", None) for i in range(NUM_FLOW_SOCKETS)]

        if not condition:
            # Done with the loop — return current values
            return io.NodeOutput(*values)

        # Build the graph expansion for the next loop iteration
        upstream = {}
        cls._explore_dependencies(unique_id, dynprompt, upstream)

        contained = {}
        open_node = flow_control[0]
        cls._collect_contained(open_node, upstream, contained)
        contained[unique_id] = True
        contained[open_node] = True

        # Use "Recurse" for this node's clone to avoid exponential name growth
        graph = GraphBuilder()
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id)
            node.set_override_display_id(node_id)
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.lookup_node("Recurse" if node_id == unique_id else node_id)
            for k, v in original_node["inputs"].items():
                if is_link(v) and v[0] in contained:
                    parent = graph.lookup_node(v[0])
                    node.set_input(k, parent.out(v[1]))
                else:
                    node.set_input(k, v)

        new_open = graph.lookup_node(open_node)
        for i in range(NUM_FLOW_SOCKETS):
            new_open.set_input(f"initial_value{i}", values[i])

        my_clone = graph.lookup_node("Recurse")
        result = tuple(my_clone.out(x) for x in range(NUM_FLOW_SOCKETS))
        return io.NodeOutput(*result, expand=graph.finalize())


class _IntOperations(io.ComfyNode):
    """Integer arithmetic / comparison; outputs both ``int(result)`` and ``bool(result)``."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_IntOperations",
            display_name="Int Operations",
            category="looping/logic",
            is_dev_only=True,
            inputs=[
                io.Int.Input("a", default=0, min=-0xffffffffffffffff, max=0xffffffffffffffff, step=1),
                io.Int.Input("b", default=0, min=-0xffffffffffffffff, max=0xffffffffffffffff, step=1),
                io.Combo.Input("operation", options=[
                    "add", "subtract", "multiply", "divide", "modulo", "power",
                    "==", "!=", "<", ">", "<=", ">=",
                ]),
            ],
            outputs=[
                io.Int.Output(),
                io.Boolean.Output(),
            ],
        )

    # Division/modulo by zero yield 0 instead of raising.
    OPS = {
        "add": lambda a, b: a + b,
        "subtract": lambda a, b: a - b,
        "multiply": lambda a, b: a * b,
        "divide": lambda a, b: a // b if b else 0,
        "modulo": lambda a, b: a % b if b else 0,
        "power": lambda a, b: a ** b,
        "==": lambda a, b: a == b,
        "!=": lambda a, b: a != b,
        "<": lambda a, b: a < b,
        ">": lambda a, b: a > b,
        "<=": lambda a, b: a <= b,
        ">=": lambda a, b: a >= b,
    }

    @classmethod
    def execute(cls, a: int, b: int, operation: str) -> io.NodeOutput:
        result = cls.OPS[operation](a, b)
        return io.NodeOutput(int(result), bool(result))


def _get_frame_count(x, dim):
    """Return the size along the given dimension."""
    if isinstance(x, dict) and "samples" in x:
        s = x["samples"]
        return s.tensors[0].shape[dim] if _is_nested(s) else s.shape[dim]
    return x.shape[dim]


def _get_cat_dim(x):
    """Return the concatenation dimension for a tensor: dim 2 for 5D video, dim 0 for everything else."""
    if isinstance(x, dict):
        s = x["samples"]
        if _is_nested(s):
            return 2  # NestedTensor sub-tensors use temporal dim
        return 2 if s.ndim == 5 else 0
    return 0  # IMAGE/MASK: batch dim


def _slice_tensor(x, dim, start=None, end=None):
    """Slice a tensor, latent dict, or NestedTensor along the given dim.

    For latent dicts a shallow copy is returned; ``noise_mask`` is sliced too
    when it has the same structure (nested, or same ndim) as ``samples``.
    """
    index = (slice(None),) * dim + (slice(start, end),)
    if isinstance(x, dict) and "samples" in x:
        result = x.copy()
        samples = x["samples"]
        if _is_nested(samples):
            result["samples"] = NestedTensor([t[index] for t in samples.tensors])
            mask = x.get("noise_mask")
            if mask is not None and _is_nested(mask):
                result["noise_mask"] = NestedTensor([t[index] for t in mask.tensors])
        else:
            result["samples"] = samples[index]
            mask = x.get("noise_mask")
            if mask is not None and mask.ndim == samples.ndim:
                result["noise_mask"] = mask[index]
        return result
    return x[index]


def _concat_tensor(a, b, dim):
    """Concatenate two tensors, latent dicts, or NestedTensors along the given dim.

    For latent dicts the result is a shallow copy of ``a``; ``noise_mask`` is
    concatenated only when both sides carry a compatible mask.
    """
    if isinstance(a, dict) and "samples" in a:
        result = a.copy()
        sa, sb = a["samples"], b["samples"]
        if _is_nested(sa):
            result["samples"] = NestedTensor([torch.cat([ta, tb], dim=dim) for ta, tb in zip(sa.tensors, sb.tensors)])
            ma, mb = a.get("noise_mask"), b.get("noise_mask")
            if ma is not None and mb is not None and _is_nested(ma):
                result["noise_mask"] = NestedTensor([torch.cat([ta, tb], dim=dim) for ta, tb in zip(ma.tensors, mb.tensors)])
        else:
            result["samples"] = torch.cat([sa, sb], dim=dim)
            ma, mb = a.get("noise_mask"), b.get("noise_mask")
            if ma is not None and mb is not None and ma.ndim == sa.ndim:
                result["noise_mask"] = torch.cat([ma, mb], dim=dim)
        return result
    return torch.cat([a, b], dim=dim)


def _fade_weights(t, ref, dim):
    """Shape the 1-D blend ramp ``t`` to broadcast against ``ref`` along ``dim``.

    The ramp is moved to ``ref``'s device and, when ``ref`` is floating point,
    cast to ``ref``'s dtype so the blended region keeps the same dtype as the
    rest of the output (an fp32 ramp would otherwise promote fp16/bf16 data).
    """
    dtype = ref.dtype if ref.is_floating_point() else t.dtype
    shape = [1] * dim + [t.shape[0]] + [1] * (ref.ndim - dim - 1)
    return t.to(device=ref.device, dtype=dtype).reshape(shape)


def _blend_overlap(items, overlap_frames, mode):
    """Concatenate items with crossfade blending in the overlap regions.

    Args:
        items: list of IMAGE/MASK tensors or LATENT dicts, all of the same kind.
        overlap_frames: requested overlap; clamped so every item keeps at least
            one frame outside the overlap.
        mode: "fade_linear" or "fade_smooth" (smoothstep ramp).

    Returns:
        The blended concatenation, or ``None`` for an empty list. When no
        overlap is possible (an item has <= 1 frame) all items are plainly
        concatenated.
    """
    if not items:
        return None
    dim = _get_cat_dim(items[0])

    # Clamp overlap to the smallest item size to avoid slicing beyond bounds
    min_frames = min(_get_frame_count(item, dim) for item in items)
    overlap_frames = min(overlap_frames, min_frames - 1) if min_frames > 1 else 0
    if overlap_frames <= 0:
        # Nothing to blend — fall through to simple concat of *every* item
        # (previously only the first two were kept and the rest silently dropped).
        result = items[0]
        for item in items[1:]:
            result = _concat_tensor(result, item, dim)
        return result

    t = torch.linspace(0, 1, overlap_frames)
    if mode == "fade_smooth":
        t = t * t * (3 - 2 * t)

    result = items[0]
    for i in range(1, len(items)):
        prev_tail = _slice_tensor(result, dim, start=-overlap_frames)
        curr_head = _slice_tensor(items[i], dim, end=overlap_frames)
        curr_rest = _slice_tensor(items[i], dim, start=overlap_frames)
        result_base = _slice_tensor(result, dim, end=-overlap_frames)

        # Lerp: blended = prev_tail * (1-w) + curr_head * w
        if isinstance(prev_tail, dict):
            blended = prev_tail.copy()
            ps, cs = prev_tail["samples"], curr_head["samples"]
            if _is_nested(ps):
                blended_tensors = []
                for pt, ch in zip(ps.tensors, cs.tensors):
                    w = _fade_weights(t, pt, dim)
                    blended_tensors.append(pt * (1 - w) + ch * w)
                blended["samples"] = NestedTensor(blended_tensors)
            else:
                w = _fade_weights(t, ps, dim)
                blended["samples"] = ps * (1 - w) + cs * w
        else:
            w = _fade_weights(t, prev_tail, dim)
            blended = prev_tail * (1 - w) + curr_head * w

        result = _concat_tensor(_concat_tensor(result_base, blended, dim), curr_rest, dim)
    return result


class _AccumulationToImageBatch(io.ComfyNode):
    """Concatenates an ACCUMULATION of IMAGE/MASK tensors or LATENT dicts into a single batch."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_AccumulationToImageBatch",
            display_name="Accumulation to Batch",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[
                io.Accumulation.Input("accumulation"),
                io.Int.Input("overlap_frames", default=0, min=0),
                io.Combo.Input("overlap_mode", options=["disabled", "fade_linear", "fade_smooth"], default="disabled"),
            ],
            outputs=[io.AnyType.Output("result")],
        )

    @classmethod
    def execute(cls, accumulation, overlap_frames=0, overlap_mode="disabled") -> io.NodeOutput:
        items = accumulation["accum"]
        if not items:
            return io.NodeOutput(None)

        # Fade modes: blend overlap regions between consecutive items
        if overlap_mode in ("fade_linear", "fade_smooth") and overlap_frames > 0 and len(items) > 1:
            return io.NodeOutput(_blend_overlap(items, overlap_frames, overlap_mode))

        # Standard concatenation (no overlap or start/end trim was handled per-iteration)
        if isinstance(items[0], dict):
            samples = items[0]["samples"]
            if _is_nested(samples):
                num_sub = len(samples.tensors)
                catted = [torch.cat([item["samples"].tensors[i] for item in items], dim=2) for i in range(num_sub)]
                result = items[0].copy()
                result["samples"] = NestedTensor(catted)
                result.pop("noise_mask", None)
                masks = [item.get("noise_mask") for item in items if item.get("noise_mask") is not None]
                if masks and _is_nested(masks[0]):
                    mask_catted = [torch.cat([m.tensors[i] for m in masks], dim=2) for i in range(len(masks[0].tensors))]
                    result["noise_mask"] = NestedTensor(mask_catted)
                return io.NodeOutput(result)
            elif samples.ndim == 5:
                result = items[0].copy()
                result["samples"] = torch.cat([item["samples"] for item in items], dim=2)
                result.pop("noise_mask", None)
                masks = [item.get("noise_mask") for item in items if item.get("noise_mask") is not None]
                if masks and masks[0].ndim == 5:
                    result["noise_mask"] = torch.cat(masks, dim=2)
                return io.NodeOutput(result)
            else:
                # Image latent — batch along dim 0
                result = items[0].copy()
                result["samples"] = torch.cat([item["samples"] for item in items], dim=0)
                return io.NodeOutput(result)
        else:
            return io.NodeOutput(torch.cat(items, dim=0))


class _ConditionalSelect(io.ComfyNode):
    """Returns value_if_true when condition is True, value_if_false otherwise.
    Uses lazy evaluation so only the selected branch is resolved."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_ConditionalSelect",
            display_name="Conditional Select",
            category="looping/logic",
            is_dev_only=True,
            inputs=[
                io.Boolean.Input("condition"),
                io.AnyType.Input("value_if_true", lazy=True),
                io.AnyType.Input("value_if_false", lazy=True),
            ],
            outputs=[io.AnyType.Output("result")],
        )

    @classmethod
    def check_lazy_status(cls, condition, value_if_true=None, value_if_false=None) -> list[str]:
        # Unevaluated lazy inputs arrive as None
        if condition and value_if_true is None:
            return ["value_if_true"]
        if not condition and value_if_false is None:
            return ["value_if_false"]
        return []

    @classmethod
    def execute(cls, condition, **kwargs) -> io.NodeOutput:
        selected = kwargs.get("value_if_true") if condition else kwargs.get("value_if_false")
        return io.NodeOutput(selected)


class _ImageAccumStatePack(io.ComfyNode):
    """Packs loop state into a single dict for TensorLoopOpen's initial_value0."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_ImageAccumStatePack",
            display_name="Image Accum State Pack",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[
                io.AnyType.Input("remaining"),
                io.Accumulation.Input("accum", optional=True),
                io.AnyType.Input("previous_value"),
                io.AnyType.Input("count"),
                io.AnyType.Input("open_node_id"),
                io.AnyType.Input("total_frames"),
                io.Int.Input("prev_accumulated_count", default=0),
            ],
            outputs=[
                io.AnyType.Output("loop_state"),
                io.Boolean.Output("should_continue"),
            ],
        )

    @classmethod
    def execute(cls, remaining, accum, previous_value, count, open_node_id, total_frames, prev_accumulated_count=0) -> io.NodeOutput:
        accumulated_count = _accum_count(accum)
        if total_frames > 0:
            should_continue = accumulated_count < total_frames
            # Bail if the last iteration added nothing — the loop would never reach the target
            if accumulated_count == prev_accumulated_count:
                should_continue = False
            comfy.utils.ProgressBar(total_frames, node_id=open_node_id).update_absolute(accumulated_count)
        else:
            should_continue = remaining > 0
            current_iteration = count - remaining
            comfy.utils.ProgressBar(count, node_id=open_node_id).update_absolute(current_iteration)

        return io.NodeOutput(
            {"remaining": remaining, "accum": accum, "previous_value": previous_value,
             "count": count, "open_node_id": open_node_id, "total_frames": total_frames},
            should_continue,
        )


class _ImageAccumStateUnpack(io.ComfyNode):
    """Unpacks loop_state from TensorLoopOpen."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_ImageAccumStateUnpack",
            display_name="Image Accum State Unpack",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[io.AnyType.Input("loop_state")],
            outputs=[
                io.Int.Output("remaining"),
                io.Accumulation.Output("accumulation"),
                io.AnyType.Output("previous_value"),
                io.Int.Output("accumulated_count"),
                io.Int.Output("count"),
                io.AnyType.Output("open_node_id"),
                io.Int.Output("total_frames"),
            ],
        )

    @classmethod
    def execute(cls, loop_state) -> io.NodeOutput:
        remaining = loop_state["remaining"]
        accum = loop_state["accum"]
        previous_value = loop_state["previous_value"]
        count = loop_state.get("count", 0)
        open_node_id = loop_state.get("open_node_id")
        total_frames = loop_state.get("total_frames", 0)
        accumulated_count = _accum_count(accum)
        return io.NodeOutput(remaining, accum, previous_value, accumulated_count, count, open_node_id, total_frames)


class _BatchOps(io.ComfyNode):
    """Batch slicing, trimming, and concatenation operations."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_BatchOps",
            display_name="Batch Ops",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[
                io.AnyType.Input("batch"),
                io.Combo.Input("operation", options=["max_count", "trim_start", "trim_end", "keep_end", "concat"]),
                io.Int.Input("amount", default=0, min=0),
                io.AnyType.Input("batch_b", optional=True),
            ],
            outputs=[io.AnyType.Output("output")],
        )

    @classmethod
    def execute(cls, batch, operation, amount=0, batch_b=None) -> io.NodeOutput:
        dim = _get_cat_dim(batch)
        if operation == "concat" and batch_b is not None:
            return io.NodeOutput(_concat_tensor(batch, batch_b, dim))
        # A non-positive amount is a no-op for every slicing operation.
        if amount <= 0:
            return io.NodeOutput(batch)

        # Clamp to avoid slicing beyond the tensor size
        total = _get_frame_count(batch, dim)
        amount = min(amount, total)

        if operation == "max_count":
            return io.NodeOutput(_slice_tensor(batch, dim, end=amount))
        elif operation == "trim_start":
            return io.NodeOutput(_slice_tensor(batch, dim, start=amount))
        elif operation == "trim_end":
            # Trimming the whole batch is refused; the batch is returned unchanged instead.
            return io.NodeOutput(_slice_tensor(batch, dim, end=-amount) if amount < total else batch)
        elif operation == "keep_end":
            return io.NodeOutput(_slice_tensor(batch, dim, start=-amount))
        return io.NodeOutput(batch)


class LoopExtension(ComfyExtension):
    """Registers the public tensor-loop nodes and their internal helpers."""

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            TensorLoopOpen,
            TensorLoopClose,
            _WhileLoopOpen,
            _WhileLoopClose,
            _AccumulateNode,
            _IntOperations,
            _AccumulationToImageBatch,
            _ConditionalSelect,
            _ImageAccumStateUnpack,
            _ImageAccumStatePack,
            _BatchOps,
        ]


def comfy_entrypoint():
    return LoopExtension()