This commit is contained in:
Jukka Seppänen 2026-03-28 11:19:06 -06:00 committed by GitHub
commit 0b699b9b2b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 780 additions and 6 deletions

View File

@ -79,14 +79,15 @@ class FeatureProjection(nn.Module):
class PositionalConvEmbedding(nn.Module): class PositionalConvEmbedding(nn.Module):
def __init__(self, embed_dim=768, kernel_size=128, groups=16): def __init__(self, embed_dim=768, kernel_size=128, groups=16, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
self.conv = nn.Conv1d( self.conv = operations.Conv1d(
embed_dim, embed_dim,
embed_dim, embed_dim,
kernel_size=kernel_size, kernel_size=kernel_size,
padding=kernel_size // 2, padding=kernel_size // 2,
groups=groups, groups=groups,
device=device, dtype=dtype
) )
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2) self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
self.activation = nn.GELU() self.activation = nn.GELU()
@ -111,7 +112,7 @@ class TransformerEncoder(nn.Module):
): ):
super().__init__() super().__init__()
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim) self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
TransformerEncoderLayer( TransformerEncoderLayer(
embed_dim=embed_dim, embed_dim=embed_dim,

View File

@ -0,0 +1,767 @@
import torch
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy_api.latest import ComfyExtension, io
from comfy.nested_tensor import NestedTensor
import comfy.utils
NUM_FLOW_SOCKETS = 5
def _is_nested(x):
    # True when x is a comfy NestedTensor (a bundle of sub-tensors, e.g. video latents).
    return isinstance(x, NestedTensor)
def _temporal_frame_count(samples):
    """Return the number of temporal frames in a samples tensor or NestedTensor."""
    if _is_nested(samples):
        # Video NestedTensor: the first sub-tensor carries the temporal axis at dim 2.
        return samples.tensors[0].shape[2]
    # 5D latents are (B, C, T, H, W); 4D (B, C, H, W) tensors count along the batch axis.
    return samples.shape[2] if samples.ndim == 5 else samples.shape[0]
def _accum_count(accum):
"""Count items in an accumulation, handling tensors (Image/Mask) and dicts (Latent/Video Latent)."""
if not isinstance(accum, dict) or "accum" not in accum:
return 0
total = 0
for item in accum["accum"]:
if isinstance(item, dict):
total += _temporal_frame_count(item["samples"])
else:
total += item.shape[0] # IMAGE/MASK: count batch items
return total
class TensorLoopOpen(io.ComfyNode):
    """
    Opens a loop that collects outputs. Supports two modes:
    - iterations: runs a fixed number of iterations
    - total_frames: runs until the target number of frames is accumulated
    Wire flow_control to TensorLoopClose, use `previous_value` as input to your
    generation, and connect the generated output to TensorLoopClose.processed.
    Supports IMAGE, MASK, and LATENT types.
    """
    # Shared template so initial_value / previous_value resolve to one concrete
    # type per loop; TensorLoopClose declares the same template name ("data").
    MATCHTYPE = io.MatchType.Template("data", allowed_types=[io.Image, io.Mask, io.Latent])

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="TensorLoopOpen",
            display_name="Tensor Loop Open",
            category="looping/accumulation",
            inputs=[
                io.DynamicCombo.Input("mode", tooltip="Loop termination mode.", options=[
                    io.DynamicCombo.Option("iterations", [
                        io.Int.Input("iterations", default=4, min=0,
                                     tooltip="Number of loop iterations. Set to 0 to skip looping and pass through initial_value."),
                    ]),
                    io.DynamicCombo.Option("total_frames", [
                        io.Int.Input("total_frames", default=100, min=1,
                                     tooltip="Target number of output frames. The loop continues until this many frames are accumulated."),
                    ]),
                ]),
                io.MatchType.Input("initial_value", template=cls.MATCHTYPE, optional=True,
                                   tooltip="Optional value to use as `previous_value` on the first iteration."),
            ],
            outputs=[
                io.FlowControl.Output("flow_control"),
                io.MatchType.Output(cls.MATCHTYPE, id="previous_value",
                                    tooltip="The value from the previous iteration (or initial_value on first pass)."),
                io.Int.Output("accumulated_count", tooltip="Number of items collected so far (0 on first iteration)."),
                io.Int.Output("current_iteration", tooltip="Current iteration index (1-based)."),
            ],
            hidden=[io.Hidden.unique_id],
            # Required so the packed state dict can arrive on the undeclared
            # `initial_value0` socket when the loop re-expands.
            accept_all_inputs=True,
        )

    @classmethod
    def execute(cls, mode: dict, initial_value=None, **kwargs) -> io.NodeOutput:
        unique_id = cls.hidden.unique_id
        state = kwargs.get("initial_value0")  # packed state dict or None on first pass
        if state is not None:
            # Re-entry: restore everything from the state packed by _ImageAccumStatePack.
            count = state["count"]
            total_frames_val = state["total_frames"]
            remaining = state["remaining"]
            accum = state["accum"]
            previous_value = state["previous_value"]
            open_node_id = state["open_node_id"]
        else:
            # First pass: initialize from the widget values of the selected mode.
            selected_mode = mode.get("mode", "iterations")
            count = mode.get("iterations", 4) if selected_mode == "iterations" else 0
            total_frames_val = mode.get("total_frames", 100) if selected_mode == "total_frames" else 0
            remaining = count
            accum = None
            previous_value = initial_value
            open_node_id = unique_id
        accumulated_count = _accum_count(accum)
        # In total_frames mode, count=0 and remaining goes negative each iteration.
        # The math still produces correct 1-based iteration: 0-0+1=1, 0-(-1)+1=2, etc.
        current_iteration = count - remaining + 1
        if total_frames_val > 0:
            comfy.utils.ProgressBar(total_frames_val, node_id=open_node_id).update_absolute(accumulated_count)
        elif count > 0:
            comfy.utils.ProgressBar(count, node_id=open_node_id).update_absolute(count - remaining)
        loop_state = {"remaining": remaining, "accum": accum, "previous_value": previous_value, "count": count, "open_node_id": open_node_id, "total_frames": total_frames_val}
        return io.NodeOutput(loop_state, previous_value, accumulated_count, current_iteration)
class TensorLoopClose(io.ComfyNode):
    """
    Closes the loop started by TensorLoopOpen.
    Connect:
    - flow_control from TensorLoopOpen
    - processed: the output generated this iteration
    Supports IMAGE, MASK, and LATENT types.
    """
    # Mirrors TensorLoopOpen.MATCHTYPE so both loop ends agree on the data type.
    MATCHTYPE = io.MatchType.Template("data", allowed_types=[io.Image, io.Mask, io.Latent])

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="TensorLoopClose",
            display_name="Tensor Loop Close",
            category="looping/accumulation",
            inputs=[
                io.FlowControl.Input("flow_control", raw_link=True),
                io.MatchType.Input("processed", template=cls.MATCHTYPE, raw_link=True, tooltip="Output generated this iteration."),
                io.Boolean.Input("accumulate", default=True,
                                 tooltip="When enabled, collects all iterations into a batch. When disabled, only outputs the final iteration's result."),
                io.DynamicCombo.Input("overlap", tooltip="Remove or blend duplicate frames where consecutive iterations overlap.", options=[
                    io.DynamicCombo.Option("disabled", []),
                    io.DynamicCombo.Option("start", [
                        io.Int.Input("overlap_frames", default=8, min=1,
                                     tooltip="Number of frames to trim. Use when the model re-generates context frames at the start of its output (most common for video continuation)."),
                    ]),
                    io.DynamicCombo.Option("end", [
                        io.Int.Input("overlap_frames", default=8, min=1,
                                     tooltip="Number of frames to trim. Use when the model generates look-ahead frames at the end of its output."),
                    ]),
                    io.DynamicCombo.Option("fade_linear", [
                        io.Int.Input("overlap_frames", default=8, min=1,
                                     tooltip="Number of frames to crossfade with a linear blend between consecutive iterations."),
                    ]),
                    io.DynamicCombo.Option("fade_smooth", [
                        io.Int.Input("overlap_frames", default=8, min=1,
                                     tooltip="Number of frames to crossfade with a smoothstep (ease in/out) blend between consecutive iterations."),
                    ]),
                ]),
                io.Boolean.Input("stop", optional=True, default=False, raw_link=True, force_input=True,
                                 tooltip="Optional early stop signal from inside the loop body. When True, the loop stops after the current iteration regardless of remaining iterations or total_frames target."),
            ],
            outputs=[
                io.MatchType.Output(cls.MATCHTYPE, id="output", tooltip="Accumulated batch or final iteration result, depending on 'accumulate' setting."),
            ],
            enable_expand=True,
        )

    @classmethod
    def execute(cls, flow_control, processed, accumulate=True, overlap=None, stop=False) -> io.NodeOutput:
        # This node never touches tensors directly: it expands into a subgraph of
        # helper nodes (_ImageAccumStateUnpack/_Pack, _BatchOps, _WhileLoopClose, ...)
        # that performs one iteration step and decides whether to loop again.
        graph = GraphBuilder()
        open_id = flow_control[0]
        unpack = graph.node("_ImageAccumStateUnpack", loop_state=[open_id, 0])
        # unpack: 0=remaining, 1=accum, 2=previous_value, 3=accumulated_count, 4=count, 5=open_node_id, 6=total_frames
        sub = graph.node("_IntOperations", operation="subtract", a=unpack.out(0), b=1)
        overlap_mode = "disabled"
        overlap_frames = 0
        if isinstance(overlap, dict):
            overlap_mode = overlap.get("overlap", "disabled")
            overlap_frames = overlap.get("overlap_frames", 0)
        accum_out = unpack.out(1)
        if accumulate:
            to_accum = processed
            if overlap_frames > 0 and overlap_mode != "disabled":
                # _IntOperations out(1) is the boolean view of the comparison result.
                is_first = graph.node("_IntOperations", a=unpack.out(3), b=0, operation="==")
                trimmed_start = graph.node("_BatchOps", batch=processed, operation="trim_start", amount=overlap_frames).out(0)
                if overlap_mode == "start":
                    to_accum = trimmed_start
                elif overlap_mode == "end":
                    # Always drop the look-ahead tail; on the first iteration the
                    # start is trimmed as well (NOTE(review): confirm first-pass
                    # start-trim is intended for this mode).
                    trimmed_end = graph.node("_BatchOps", batch=processed, operation="trim_end", amount=overlap_frames).out(0)
                    trimmed_both = graph.node("_BatchOps", batch=trimmed_end, operation="trim_start", amount=overlap_frames).out(0)
                    to_accum = graph.node("_ConditionalSelect",
                                          condition=is_first.out(1),
                                          value_if_true=trimmed_both,
                                          value_if_false=trimmed_end,
                                          ).out(0)
                else:
                    # Fade: trim start on iter 1 only, keep full on subsequent for post-loop blend
                    to_accum = graph.node("_ConditionalSelect",
                                          condition=is_first.out(1),
                                          value_if_true=trimmed_start,
                                          value_if_false=processed,
                                          ).out(0)
            accum_out = graph.node("_AccumulateNode", to_add=to_accum, accumulation=accum_out).out(0)
        # Disable total_frames when not accumulating to avoid infinite loops
        pack = graph.node("_ImageAccumStatePack",
                          remaining=sub.out(0),
                          accum=accum_out,
                          previous_value=processed,
                          count=unpack.out(4),
                          open_node_id=unpack.out(5),
                          total_frames=unpack.out(6) if accumulate else 0,
                          prev_accumulated_count=unpack.out(3),
                          )
        # Optional early stop from loop body
        if is_link(stop):
            # stop=True forces should_continue to False; otherwise defer to the packer.
            condition = graph.node("_ConditionalSelect",
                                   condition=stop,
                                   value_if_true=False,
                                   value_if_false=pack.out(1),
                                   ).out(0)
        else:
            condition = pack.out(1)
        while_close = graph.node(
            "_WhileLoopClose",
            flow_control=flow_control,
            condition=condition,
            initial_value0=pack.out(0),
        )
        # Everything below runs on the FINAL state, after the while loop terminates.
        final_unpack = graph.node("_ImageAccumStateUnpack", loop_state=while_close.out(0))
        if accumulate:
            if overlap_mode in ("fade_linear", "fade_smooth") and overlap_frames > 0:
                result = graph.node("_AccumulationToImageBatch",
                                    accumulation=final_unpack.out(1),
                                    overlap_frames=overlap_frames,
                                    overlap_mode=overlap_mode,
                                    ).out(0)
            else:
                result = graph.node("_AccumulationToImageBatch", accumulation=final_unpack.out(1)).out(0)
            # End mode: re-append the tail that was trimmed from the last iteration
            if overlap_mode == "end" and overlap_frames > 0:
                tail = graph.node("_BatchOps", batch=final_unpack.out(2), operation="keep_end", amount=overlap_frames).out(0)
                result = graph.node("_BatchOps", batch=result, operation="concat", batch_b=tail).out(0)
                # Clamp to the total_frames target so re-appending never overshoots.
                result = graph.node("_BatchOps", batch=result, operation="max_count", amount=final_unpack.out(6)).out(0)
        else:
            result = final_unpack.out(2)
        # Bypass: when iterations==0 AND total_frames==0, return initial_value directly
        count_is_zero = graph.node("_IntOperations", a=unpack.out(4), b=0, operation="==")
        tf_is_zero = graph.node("_IntOperations", a=unpack.out(6), b=0, operation="==")
        both_zero = graph.node("_IntOperations", a=count_is_zero.out(0), b=tf_is_zero.out(0), operation="multiply")
        result = graph.node("_ConditionalSelect",
                            condition=both_zero.out(1),
                            value_if_true=unpack.out(2),
                            value_if_false=result,
                            ).out(0)
        return io.NodeOutput(result, expand=graph.finalize())
# Internal helper nodes — dev only, hidden from the node menu.
class _AccumulateNode(io.ComfyNode):
    """Appends one value to an ACCUMULATION, creating the accumulation if absent."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_AccumulateNode",
            display_name="Accumulate",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[
                io.AnyType.Input("to_add"),
                io.Accumulation.Input("accumulation", optional=True),
            ],
            outputs=[
                io.Accumulation.Output(),
            ],
        )

    @classmethod
    def execute(cls, to_add, accumulation=None) -> io.NodeOutput:
        # Build a fresh list instead of mutating the incoming accumulation.
        existing = [] if accumulation is None else accumulation["accum"]
        return io.NodeOutput({"accum": existing + [to_add]})
class _WhileLoopOpen(io.ComfyNode):
    """Generic while-loop entry: passes its initial values straight through on the first pass."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        value_inputs = [io.AnyType.Input(f"initial_value{i}", optional=True) for i in range(NUM_FLOW_SOCKETS)]
        value_outputs = [io.AnyType.Output(f"value{i}") for i in range(NUM_FLOW_SOCKETS)]
        return io.Schema(
            node_id="_WhileLoopOpen",
            display_name="While Loop Open",
            category="looping",
            is_dev_only=True,
            inputs=[io.Boolean.Input("condition", default=True), *value_inputs],
            outputs=[io.FlowControl.Output("flow_control", display_name="FLOW_CONTROL"), *value_outputs],
            accept_all_inputs=True,
        )

    @classmethod
    def execute(cls, condition: bool, **kwargs) -> io.NodeOutput:
        # "stub" stands in for the FLOW_CONTROL payload; the close node does the real work.
        passthrough = [kwargs.get(f"initial_value{i}") for i in range(NUM_FLOW_SOCKETS)]
        return io.NodeOutput("stub", *passthrough)
class _WhileLoopClose(io.ComfyNode):
    # Generic while-loop close: while `condition` is True, clones the loop-body
    # subgraph (via graph expansion) to run another iteration.

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_WhileLoopClose",
            display_name="While Loop Close",
            category="looping",
            is_dev_only=True,
            inputs=[
                io.FlowControl.Input("flow_control", raw_link=True),
                io.Boolean.Input("condition", force_input=True),
                *[io.AnyType.Input(f"initial_value{i}", optional=True) for i in range(NUM_FLOW_SOCKETS)],
            ],
            outputs=[
                *[io.AnyType.Output(f"value{i}") for i in range(NUM_FLOW_SOCKETS)],
            ],
            hidden=[io.Hidden.dynprompt, io.Hidden.unique_id],
            enable_expand=True,
            accept_all_inputs=True,
        )

    @staticmethod
    def _explore_dependencies(node_id, dynprompt, upstream):
        """Build a reverse dependency map: upstream[parent_id] -> nodes that consume it."""
        node_info = dynprompt.get_node(node_id)
        if "inputs" not in node_info:
            return
        for k, v in node_info["inputs"].items():
            if is_link(v):
                parent_id = v[0]
                if parent_id not in upstream:
                    upstream[parent_id] = []
                    # Recurse only on first visit so shared parents are walked once.
                    _WhileLoopClose._explore_dependencies(parent_id, dynprompt, upstream)
                upstream[parent_id].append(node_id)

    @staticmethod
    def _collect_contained(node_id, upstream, contained):
        """Mark every node reachable downstream of node_id as inside the loop body."""
        if node_id not in upstream:
            return
        for child_id in upstream[node_id]:
            if child_id not in contained:
                contained[child_id] = True
                _WhileLoopClose._collect_contained(child_id, upstream, contained)

    @classmethod
    def execute(cls, flow_control, condition: bool, **kwargs) -> io.NodeOutput:
        dynprompt = cls.hidden.dynprompt
        unique_id = cls.hidden.unique_id
        values = [kwargs.get(f"initial_value{i}", None) for i in range(NUM_FLOW_SOCKETS)]
        if not condition:  # Done with the loop — return current values
            return io.NodeOutput(*values)
        # Build the graph expansion for the next loop iteration
        upstream = {}
        cls._explore_dependencies(unique_id, dynprompt, upstream)
        contained = {}
        open_node = flow_control[0]
        cls._collect_contained(open_node, upstream, contained)
        contained[unique_id] = True
        contained[open_node] = True
        # Use "Recurse" for this node's clone to avoid exponential name growth
        graph = GraphBuilder()
        # Pass 1: create every contained node so links can resolve in pass 2.
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id)
            node.set_override_display_id(node_id)
        # Pass 2: wire inputs, remapping links that stay inside the loop body.
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.lookup_node("Recurse" if node_id == unique_id else node_id)
            for k, v in original_node["inputs"].items():
                if is_link(v) and v[0] in contained:
                    parent = graph.lookup_node(v[0])
                    node.set_input(k, parent.out(v[1]))
                else:
                    node.set_input(k, v)
        # Feed this iteration's outputs into the cloned open node as the next inputs.
        new_open = graph.lookup_node(open_node)
        for i in range(NUM_FLOW_SOCKETS):
            new_open.set_input(f"initial_value{i}", values[i])
        my_clone = graph.lookup_node("Recurse")
        result = tuple(my_clone.out(x) for x in range(NUM_FLOW_SOCKETS))
        return io.NodeOutput(*result, expand=graph.finalize())
class _IntOperations(io.ComfyNode):
    """Integer arithmetic and comparisons; emits both int and bool views of the result."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_IntOperations",
            display_name="Int Operations",
            category="looping/logic",
            is_dev_only=True,
            inputs=[
                io.Int.Input("a", default=0, min=-0xffffffffffffffff, max=0xffffffffffffffff, step=1),
                io.Int.Input("b", default=0, min=-0xffffffffffffffff, max=0xffffffffffffffff, step=1),
                io.Combo.Input("operation", options=[
                    "add", "subtract", "multiply", "divide", "modulo", "power",
                    "==", "!=", "<", ">", "<=", ">=",
                ]),
            ],
            outputs=[
                io.Int.Output(),
                io.Boolean.Output(),
            ],
        )

    # Operation name -> binary callable. Division/modulo by zero yield 0
    # instead of raising so graph-built expressions never crash the loop.
    OPS = {
        "add": lambda a, b: a + b,
        "subtract": lambda a, b: a - b,
        "multiply": lambda a, b: a * b,
        "divide": lambda a, b: a // b if b else 0,
        "modulo": lambda a, b: a % b if b else 0,
        "power": lambda a, b: a ** b,
        "==": lambda a, b: a == b,
        "!=": lambda a, b: a != b,
        "<": lambda a, b: a < b,
        ">": lambda a, b: a > b,
        "<=": lambda a, b: a <= b,
        ">=": lambda a, b: a >= b,
    }

    @classmethod
    def execute(cls, a: int, b: int, operation: str) -> io.NodeOutput:
        outcome = cls.OPS[operation](a, b)
        return io.NodeOutput(int(outcome), bool(outcome))
def _get_frame_count(x, dim):
"""Return the size along the given dimension."""
if isinstance(x, dict) and "samples" in x:
s = x["samples"]
return s.tensors[0].shape[dim] if _is_nested(s) else s.shape[dim]
return x.shape[dim]
def _get_cat_dim(x):
"""Return the concatenation dimension for a tensor: dim 2 for 5D video, dim 0 for everything else."""
if isinstance(x, dict):
s = x["samples"]
if _is_nested(s):
return 2 # NestedTensor sub-tensors use temporal dim
return 2 if s.ndim == 5 else 0
return 0 # IMAGE/MASK: batch dim
def _slice_tensor(x, dim, start=None, end=None):
"""Slice a tensor, latent dict, or NestedTensor along the given dim."""
if isinstance(x, dict) and "samples" in x:
result = x.copy()
samples = x["samples"]
if _is_nested(samples):
sliced = []
for t in samples.tensors:
sliced.append(t[(slice(None),) * dim + (slice(start, end),)])
result["samples"] = NestedTensor(sliced)
mask = x.get("noise_mask")
if mask is not None and _is_nested(mask):
mask_sliced = []
for t in mask.tensors:
mask_sliced.append(t[(slice(None),) * dim + (slice(start, end),)])
result["noise_mask"] = NestedTensor(mask_sliced)
else:
result["samples"] = samples[(slice(None),) * dim + (slice(start, end),)]
mask = x.get("noise_mask")
if mask is not None and mask.ndim == samples.ndim:
result["noise_mask"] = mask[(slice(None),) * dim + (slice(start, end),)]
return result
else:
return x[(slice(None),) * dim + (slice(start, end),)]
def _concat_tensor(a, b, dim):
"""Concatenate two tensors, latent dicts, or NestedTensors along the given dim."""
if isinstance(a, dict) and "samples" in a:
result = a.copy()
sa, sb = a["samples"], b["samples"]
if _is_nested(sa):
result["samples"] = NestedTensor([torch.cat([ta, tb], dim=dim) for ta, tb in zip(sa.tensors, sb.tensors)])
ma, mb = a.get("noise_mask"), b.get("noise_mask")
if ma is not None and mb is not None and _is_nested(ma):
result["noise_mask"] = NestedTensor([torch.cat([ta, tb], dim=dim) for ta, tb in zip(ma.tensors, mb.tensors)])
else:
result["samples"] = torch.cat([sa, sb], dim=dim)
ma, mb = a.get("noise_mask"), b.get("noise_mask")
if ma is not None and mb is not None and ma.ndim == sa.ndim:
result["noise_mask"] = torch.cat([ma, mb], dim=dim)
return result
else:
return torch.cat([a, b], dim=dim)
def _blend_overlap(items, overlap_frames, mode):
    """Concatenate items with crossfade blending in the overlap regions.

    items: list of tensors or latent dicts, all of the same kind.
    overlap_frames: number of frames shared by consecutive items.
    mode: "fade_smooth" applies smoothstep easing to the blend weights;
          any other value uses a linear ramp.
    Returns the merged result, or None for an empty list.
    """
    if not items:
        return None
    dim = _get_cat_dim(items[0])
    # Clamp overlap to the smallest item size to avoid slicing beyond bounds
    min_frames = min(_get_frame_count(item, dim) for item in items)
    overlap_frames = min(overlap_frames, min_frames - 1) if min_frames > 1 else 0
    if overlap_frames <= 0:
        # Nothing to blend — fall back to a plain concat of ALL items.
        # (Bug fix: the previous fallback returned only items[0] unless there
        # were exactly two items, silently dropping the rest of the batch.)
        result = items[0]
        for item in items[1:]:
            result = _concat_tensor(result, item, dim)
        return result
    t = torch.linspace(0, 1, overlap_frames)
    if mode == "fade_smooth":
        t = t * t * (3 - 2 * t)  # smoothstep easing: 3t^2 - 2t^3
    result = items[0]
    for i in range(1, len(items)):
        prev_tail = _slice_tensor(result, dim, start=-overlap_frames)
        curr_head = _slice_tensor(items[i], dim, end=overlap_frames)
        curr_rest = _slice_tensor(items[i], dim, start=overlap_frames)
        result_base = _slice_tensor(result, dim, end=-overlap_frames)
        # Lerp: blended = prev_tail * (1-w) + curr_head * w
        if isinstance(prev_tail, dict):
            blended = prev_tail.copy()
            ps, cs = prev_tail["samples"], curr_head["samples"]
            if _is_nested(ps):
                blended_tensors = []
                for pt, ch in zip(ps.tensors, cs.tensors):
                    # Reshape weights to broadcast along `dim` of each sub-tensor.
                    w = t.to(pt.device).reshape([1] * dim + [overlap_frames] + [1] * (pt.ndim - dim - 1))
                    blended_tensors.append(pt * (1 - w) + ch * w)
                blended["samples"] = NestedTensor(blended_tensors)
            else:
                w = t.to(ps.device).reshape([1] * dim + [overlap_frames] + [1] * (ps.ndim - dim - 1))
                blended["samples"] = ps * (1 - w) + cs * w
        else:
            w = t.to(prev_tail.device).reshape([1] * dim + [overlap_frames] + [1] * (prev_tail.ndim - dim - 1))
            blended = prev_tail * (1 - w) + curr_head * w
        result = _concat_tensor(_concat_tensor(result_base, blended, dim), curr_rest, dim)
    return result
class _AccumulationToImageBatch(io.ComfyNode):
    """Concatenates an ACCUMULATION of IMAGE/MASK tensors or LATENT dicts into a single batch."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_AccumulationToImageBatch",
            display_name="Accumulation to Batch",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[
                io.Accumulation.Input("accumulation"),
                io.Int.Input("overlap_frames", default=0, min=0),
                io.Combo.Input("overlap_mode", options=["disabled", "fade_linear", "fade_smooth"], default="disabled"),
            ],
            outputs=[io.AnyType.Output("result")],
        )

    @classmethod
    def execute(cls, accumulation, overlap_frames=0, overlap_mode="disabled") -> io.NodeOutput:
        items = accumulation["accum"]
        if not items:
            return io.NodeOutput(None)
        # Fade modes: blend overlap regions between consecutive items
        if overlap_mode in ("fade_linear", "fade_smooth") and overlap_frames > 0 and len(items) > 1:
            return io.NodeOutput(_blend_overlap(items, overlap_frames, overlap_mode))
        # Standard concatenation (no overlap or start/end trim was handled per-iteration)
        if isinstance(items[0], dict):
            samples = items[0]["samples"]
            if _is_nested(samples):
                # Video NestedTensor latents: concat each sub-tensor along the temporal dim.
                num_sub = len(samples.tensors)
                catted = []
                for i in range(num_sub):
                    catted.append(torch.cat([item["samples"].tensors[i] for item in items], dim=2))
                result = items[0].copy()
                result["samples"] = NestedTensor(catted)
                # Drop the stale mask first; rebuild it only when nested masks exist.
                result.pop("noise_mask", None)
                masks = [item.get("noise_mask") for item in items if item.get("noise_mask") is not None]
                if masks and _is_nested(masks[0]):
                    mask_catted = []
                    for i in range(len(masks[0].tensors)):
                        mask_catted.append(torch.cat([m.tensors[i] for m in masks], dim=2))
                    result["noise_mask"] = NestedTensor(mask_catted)
                return io.NodeOutput(result)
            elif samples.ndim == 5:
                # Plain 5D video latent (B, C, T, H, W): concat along temporal dim 2.
                result = items[0].copy()
                result["samples"] = torch.cat([item["samples"] for item in items], dim=2)
                result.pop("noise_mask", None)
                masks = [item.get("noise_mask") for item in items if item.get("noise_mask") is not None]
                if masks and masks[0].ndim == 5:
                    result["noise_mask"] = torch.cat(masks, dim=2)
                return io.NodeOutput(result)
            else:
                # Image latent — batch along dim 0
                # NOTE(review): this branch keeps items[0]'s noise_mask as-is (via copy),
                # which may not match the batched samples — confirm intended.
                result = items[0].copy()
                result["samples"] = torch.cat([item["samples"] for item in items], dim=0)
                return io.NodeOutput(result)
        else:
            # IMAGE/MASK tensors: batch along dim 0.
            return io.NodeOutput(torch.cat(items, dim=0))
class _ConditionalSelect(io.ComfyNode):
    """Returns value_if_true when condition is True, value_if_false otherwise.
    Uses lazy evaluation so only the selected branch is resolved."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_ConditionalSelect",
            display_name="Conditional Select",
            category="looping/logic",
            is_dev_only=True,
            inputs=[
                io.Boolean.Input("condition"),
                io.AnyType.Input("value_if_true", lazy=True),
                io.AnyType.Input("value_if_false", lazy=True),
            ],
            outputs=[io.AnyType.Output("result")],
        )

    @classmethod
    def check_lazy_status(cls, condition, value_if_true=None, value_if_false=None) -> list[str]:
        # A lazy input that has not been evaluated yet arrives as None; request
        # only the branch the condition actually selects.
        if condition:
            return ["value_if_true"] if value_if_true is None else []
        return ["value_if_false"] if value_if_false is None else []

    @classmethod
    def execute(cls, condition, **kwargs) -> io.NodeOutput:
        branch = "value_if_true" if condition else "value_if_false"
        return io.NodeOutput(kwargs.get(branch))
class _ImageAccumStatePack(io.ComfyNode):
    """Packs loop state into a single dict for TensorLoopOpen's initial_value0."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_ImageAccumStatePack",
            display_name="Image Accum State Pack",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[
                io.AnyType.Input("remaining"),
                io.Accumulation.Input("accum", optional=True),
                io.AnyType.Input("previous_value"),
                io.AnyType.Input("count"),
                io.AnyType.Input("open_node_id"),
                io.AnyType.Input("total_frames"),
                io.Int.Input("prev_accumulated_count", default=0),
            ],
            outputs=[
                io.AnyType.Output("loop_state"),
                # Drives _WhileLoopClose.condition: True means run another iteration.
                io.Boolean.Output("should_continue"),
            ],
        )

    @classmethod
    def execute(cls, remaining, accum, previous_value, count, open_node_id, total_frames, prev_accumulated_count=0) -> io.NodeOutput:
        accumulated_count = _accum_count(accum)
        if total_frames > 0:
            # total_frames mode: keep looping until enough frames are collected.
            should_continue = accumulated_count < total_frames
            # Bail if the last iteration added nothing — the loop would never reach the target
            if accumulated_count == prev_accumulated_count:
                should_continue = False
            comfy.utils.ProgressBar(total_frames, node_id=open_node_id).update_absolute(accumulated_count)
        else:
            # iterations mode: keep looping while the countdown is positive.
            should_continue = remaining > 0
            current_iteration = count - remaining
            comfy.utils.ProgressBar(count, node_id=open_node_id).update_absolute(current_iteration)
        return io.NodeOutput(
            {"remaining": remaining, "accum": accum, "previous_value": previous_value, "count": count, "open_node_id": open_node_id, "total_frames": total_frames},
            should_continue,
        )
class _ImageAccumStateUnpack(io.ComfyNode):
    """Unpacks loop_state from TensorLoopOpen."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_ImageAccumStateUnpack",
            display_name="Image Accum State Unpack",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[io.AnyType.Input("loop_state")],
            outputs=[
                io.Int.Output("remaining"),
                io.Accumulation.Output("accumulation"),
                io.AnyType.Output("previous_value"),
                io.Int.Output("accumulated_count"),
                io.Int.Output("count"),
                io.AnyType.Output("open_node_id"),
                io.Int.Output("total_frames"),
            ],
        )

    @classmethod
    def execute(cls, loop_state) -> io.NodeOutput:
        # accumulated_count is derived on the fly so it can never go stale.
        accum = loop_state["accum"]
        return io.NodeOutput(
            loop_state["remaining"],
            accum,
            loop_state["previous_value"],
            _accum_count(accum),
            loop_state.get("count", 0),
            loop_state.get("open_node_id"),
            loop_state.get("total_frames", 0),
        )
class _BatchOps(io.ComfyNode):
    """Batch slicing, trimming, and concatenation operations."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_BatchOps",
            display_name="Batch Ops",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[
                io.AnyType.Input("batch"),
                io.Combo.Input("operation", options=["max_count", "trim_start", "trim_end", "keep_end", "concat"]),
                io.Int.Input("amount", default=0, min=0),
                io.AnyType.Input("batch_b", optional=True),
            ],
            outputs=[io.AnyType.Output("output")],
        )

    @classmethod
    def execute(cls, batch, operation, amount=0, batch_b=None) -> io.NodeOutput:
        dim = _get_cat_dim(batch)
        # concat ignores `amount` entirely.
        if operation == "concat" and batch_b is not None:
            return io.NodeOutput(_concat_tensor(batch, batch_b, dim))
        # A non-positive amount is a no-op for every slicing operation.
        if amount <= 0:
            return io.NodeOutput(batch)
        # Clamp so we never slice beyond the available frames.
        total = _get_frame_count(batch, dim)
        amount = min(amount, total)
        if operation == "max_count":
            out = _slice_tensor(batch, dim, end=amount)
        elif operation == "trim_start":
            out = _slice_tensor(batch, dim, start=amount)
        elif operation == "trim_end":
            # Trimming everything would produce an empty slice; return unchanged instead.
            out = batch if amount >= total else _slice_tensor(batch, dim, end=-amount)
        elif operation == "keep_end":
            out = _slice_tensor(batch, dim, start=-amount)
        else:
            out = batch
        return io.NodeOutput(out)
class LoopExtension(ComfyExtension):
    """Registers the public tensor-loop nodes plus their internal helper nodes."""

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        public_nodes = [TensorLoopOpen, TensorLoopClose]
        helper_nodes = [
            _WhileLoopOpen,
            _WhileLoopClose,
            _AccumulateNode,
            _IntOperations,
            _AccumulationToImageBatch,
            _ConditionalSelect,
            _ImageAccumStateUnpack,
            _ImageAccumStatePack,
            _BatchOps,
        ]
        return public_nodes + helper_nodes
def comfy_entrypoint():
    # Entry point discovered by ComfyUI's extension loader.
    return LoopExtension()

View File

@ -1326,6 +1326,7 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context.", advanced=True), io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context.", advanced=True),
io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01), io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
io.Image.Input("previous_frames", optional=True), io.Image.Input("previous_frames", optional=True),
io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="The amount of frames to seek in the previous_frames input.")
], ],
outputs=[ outputs=[
io.Model.Output(display_name="model"), io.Model.Output(display_name="model"),
@ -1338,7 +1339,7 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
@classmethod @classmethod
def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count, def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput: start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None, video_frame_offset=0) -> io.NodeOutput:
if previous_frames is not None and previous_frames.shape[0] < motion_frame_count: if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
raise ValueError("Not enough previous frames provided.") raise ValueError("Not enough previous frames provided.")
@ -1421,11 +1422,13 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
# when extending from previous frames # when extending from previous frames
if previous_frames is not None: if previous_frames is not None:
motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if video_frame_offset is not None and video_frame_offset > 0:
frame_offset = video_frame_offset - motion_frame_count
else:
frame_offset = previous_frames.shape[0] - motion_frame_count frame_offset = previous_frames.shape[0] - motion_frame_count
audio_start = frame_offset audio_start = frame_offset
audio_end = audio_start + length audio_end = audio_start + length
logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")
motion_frames_latent = vae.encode(motion_frames[:, :, :, :3]) motion_frames_latent = vae.encode(motion_frames[:, :, :, :3])
trim_image = motion_frame_count trim_image = motion_frame_count
@ -1434,6 +1437,8 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
audio_end = length audio_end = length
motion_frames_latent = concat_latent_image[:, :, :1] motion_frames_latent = concat_latent_image[:, :, :1]
logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")
audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype()) audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype())
model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed

View File

@ -2457,6 +2457,7 @@ async def init_builtin_extra_nodes():
"nodes_number_convert.py", "nodes_number_convert.py",
"nodes_painter.py", "nodes_painter.py",
"nodes_curve.py", "nodes_curve.py",
"nodes_looping.py",
] ]
import_failed = [] import_failed = []