mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 05:23:37 +08:00
Merge a51490b68b into fc1fdf3389
This commit is contained in:
commit
0b699b9b2b
@ -79,14 +79,15 @@ class FeatureProjection(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class PositionalConvEmbedding(nn.Module):
|
class PositionalConvEmbedding(nn.Module):
|
||||||
def __init__(self, embed_dim=768, kernel_size=128, groups=16):
|
def __init__(self, embed_dim=768, kernel_size=128, groups=16, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Conv1d(
|
self.conv = operations.Conv1d(
|
||||||
embed_dim,
|
embed_dim,
|
||||||
embed_dim,
|
embed_dim,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
padding=kernel_size // 2,
|
padding=kernel_size // 2,
|
||||||
groups=groups,
|
groups=groups,
|
||||||
|
device=device, dtype=dtype
|
||||||
)
|
)
|
||||||
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
|
self.conv = torch.nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2)
|
||||||
self.activation = nn.GELU()
|
self.activation = nn.GELU()
|
||||||
@ -111,7 +112,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim)
|
self.pos_conv_embed = PositionalConvEmbedding(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
TransformerEncoderLayer(
|
TransformerEncoderLayer(
|
||||||
embed_dim=embed_dim,
|
embed_dim=embed_dim,
|
||||||
|
|||||||
767
comfy_extras/nodes_looping.py
Normal file
767
comfy_extras/nodes_looping.py
Normal file
@ -0,0 +1,767 @@
|
|||||||
|
import torch
|
||||||
|
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy.nested_tensor import NestedTensor
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
|
||||||
|
# Number of generic value sockets threaded through the while-loop open/close nodes.
NUM_FLOW_SOCKETS = 5
|
||||||
|
|
||||||
|
def _is_nested(x):
    """Return True when *x* is a NestedTensor (per-sub-tensor container) rather than a plain tensor."""
    return isinstance(x, NestedTensor)
|
||||||
|
|
||||||
|
def _temporal_frame_count(samples):
    """Count temporal frames from a samples tensor or NestedTensor.

    For a NestedTensor the count is taken from the first sub-tensor's dim 2;
    for a 5D tensor (B,C,T,H,W) it is the T dimension; otherwise (4D image
    latent, (B,C,H,W)) the batch dimension is the count.
    """
    if _is_nested(samples):
        return samples.tensors[0].shape[2]  # count from first sub-tensor (video)
    if samples.ndim == 5:
        return samples.shape[2]  # (B,C,T,H,W)
    return samples.shape[0]  # (B,C,H,W) — batch count
|
||||||
|
|
||||||
|
def _accum_count(accum):
    """Count items in an accumulation, handling tensors (Image/Mask) and dicts (Latent/Video Latent).

    Returns 0 for anything that is not an accumulation dict (e.g. None on the
    first loop iteration).
    """
    if not isinstance(accum, dict) or "accum" not in accum:
        return 0

    def _item_len(entry):
        # Latent dicts contribute temporal frames; raw tensors contribute batch entries.
        if isinstance(entry, dict):
            return _temporal_frame_count(entry["samples"])
        return entry.shape[0]

    return sum(_item_len(entry) for entry in accum["accum"])
|
||||||
|
|
||||||
|
|
||||||
|
class TensorLoopOpen(io.ComfyNode):
    """
    Opens a loop that collects outputs. Supports two modes:
    - iterations: runs a fixed number of iterations
    - total_frames: runs until the target number of frames is accumulated
    Wire flow_control → TensorLoopClose, use `previous_value` as input to your
    generation, and connect the generated output → TensorLoopClose.processed.
    Supports IMAGE, MASK, and LATENT types.
    """
    # Shared type template so initial_value / previous_value stay type-matched.
    MATCHTYPE = io.MatchType.Template("data", allowed_types=[io.Image, io.Mask, io.Latent])

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="TensorLoopOpen",
            display_name="Tensor Loop Open",
            category="looping/accumulation",
            inputs=[
                io.DynamicCombo.Input("mode", tooltip="Loop termination mode.", options=[
                    io.DynamicCombo.Option("iterations", [
                        io.Int.Input("iterations", default=4, min=0,
                                     tooltip="Number of loop iterations. Set to 0 to skip looping and pass through initial_value."),
                    ]),
                    io.DynamicCombo.Option("total_frames", [
                        io.Int.Input("total_frames", default=100, min=1,
                                     tooltip="Target number of output frames. The loop continues until this many frames are accumulated."),
                    ]),
                ]),
                io.MatchType.Input("initial_value", template=cls.MATCHTYPE, optional=True,
                                   tooltip="Optional value to use as `previous_value` on the first iteration."),
            ],
            outputs=[
                io.FlowControl.Output("flow_control"),
                io.MatchType.Output(cls.MATCHTYPE, id="previous_value",
                                    tooltip="The value from the previous iteration (or initial_value on first pass)."),
                io.Int.Output("accumulated_count", tooltip="Number of items collected so far (0 on first iteration)."),
                io.Int.Output("current_iteration", tooltip="Current iteration index (1-based)."),
            ],
            hidden=[io.Hidden.unique_id],
            # Loop state arrives via dynamically-named sockets (initial_value0).
            accept_all_inputs=True,
        )

    @classmethod
    def execute(cls, mode: dict, initial_value=None, **kwargs) -> io.NodeOutput:
        """Unpack loop state (or seed it on the first pass) and report progress.

        On iterations after the first, TensorLoopClose re-invokes this node with
        a packed state dict on the `initial_value0` socket.
        """
        unique_id = cls.hidden.unique_id
        state = kwargs.get("initial_value0")  # packed state dict or None on first pass
        if state is not None:
            count = state["count"]
            total_frames_val = state["total_frames"]
            remaining = state["remaining"]
            accum = state["accum"]
            previous_value = state["previous_value"]
            open_node_id = state["open_node_id"]
        else:
            # First iteration: derive state from the user's mode selection.
            selected_mode = mode.get("mode", "iterations")
            count = mode.get("iterations", 4) if selected_mode == "iterations" else 0
            total_frames_val = mode.get("total_frames", 100) if selected_mode == "total_frames" else 0
            remaining = count
            accum = None
            previous_value = initial_value
            open_node_id = unique_id

        accumulated_count = _accum_count(accum)
        # In total_frames mode, count=0 and remaining goes negative each iteration.
        # The math still produces correct 1-based iteration: 0-0+1=1, 0-(-1)+1=2, etc.
        current_iteration = count - remaining + 1
        if total_frames_val > 0:
            comfy.utils.ProgressBar(total_frames_val, node_id=open_node_id).update_absolute(accumulated_count)
        elif count > 0:
            comfy.utils.ProgressBar(count, node_id=open_node_id).update_absolute(count - remaining)
        loop_state = {"remaining": remaining, "accum": accum, "previous_value": previous_value, "count": count, "open_node_id": open_node_id, "total_frames": total_frames_val}
        return io.NodeOutput(loop_state, previous_value, accumulated_count, current_iteration)
|
||||||
|
|
||||||
|
|
||||||
|
class TensorLoopClose(io.ComfyNode):
    """
    Closes the loop started by TensorLoopOpen.
    Connect:
    - flow_control from TensorLoopOpen
    - processed: the output generated this iteration
    Supports IMAGE, MASK, and LATENT types.
    """
    # Must mirror TensorLoopOpen.MATCHTYPE so `processed` matches the loop's data type.
    MATCHTYPE = io.MatchType.Template("data", allowed_types=[io.Image, io.Mask, io.Latent])

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="TensorLoopClose",
            display_name="Tensor Loop Close",
            category="looping/accumulation",
            inputs=[
                io.FlowControl.Input("flow_control", raw_link=True),
                io.MatchType.Input("processed", template=cls.MATCHTYPE, raw_link=True, tooltip="Output generated this iteration."),
                io.Boolean.Input("accumulate", default=True,
                                 tooltip="When enabled, collects all iterations into a batch. When disabled, only outputs the final iteration's result."),
                io.DynamicCombo.Input("overlap", tooltip="Remove or blend duplicate frames where consecutive iterations overlap.", options=[
                    io.DynamicCombo.Option("disabled", []),
                    io.DynamicCombo.Option("start", [
                        io.Int.Input("overlap_frames", default=8, min=1,
                                     tooltip="Number of frames to trim. Use when the model re-generates context frames at the start of its output (most common for video continuation)."),
                    ]),
                    io.DynamicCombo.Option("end", [
                        io.Int.Input("overlap_frames", default=8, min=1,
                                     tooltip="Number of frames to trim. Use when the model generates look-ahead frames at the end of its output."),
                    ]),
                    io.DynamicCombo.Option("fade_linear", [
                        io.Int.Input("overlap_frames", default=8, min=1,
                                     tooltip="Number of frames to crossfade with a linear blend between consecutive iterations."),
                    ]),
                    io.DynamicCombo.Option("fade_smooth", [
                        io.Int.Input("overlap_frames", default=8, min=1,
                                     tooltip="Number of frames to crossfade with a smoothstep (ease in/out) blend between consecutive iterations."),
                    ]),
                ]),
                io.Boolean.Input("stop", optional=True, default=False, raw_link=True, force_input=True,
                                 tooltip="Optional early stop signal from inside the loop body. When True, the loop stops after the current iteration regardless of remaining iterations or total_frames target."),
            ],
            outputs=[
                io.MatchType.Output(cls.MATCHTYPE, id="output", tooltip="Accumulated batch or final iteration result, depending on 'accumulate' setting."),
            ],
            # This node expands into a subgraph rather than computing directly.
            enable_expand=True,
        )

    @classmethod
    def execute(cls, flow_control, processed, accumulate=True, overlap=None, stop=False) -> io.NodeOutput:
        """Build the per-iteration subgraph: unpack state, trim/accumulate, repack, recurse.

        All real work is expressed as graph nodes (via GraphBuilder) so the
        executor can expand one loop iteration at a time.
        """
        graph = GraphBuilder()
        open_id = flow_control[0]
        unpack = graph.node("_ImageAccumStateUnpack", loop_state=[open_id, 0])
        # unpack: 0=remaining, 1=accum, 2=previous_value, 3=accumulated_count, 4=count, 5=open_node_id, 6=total_frames
        sub = graph.node("_IntOperations", operation="subtract", a=unpack.out(0), b=1)

        overlap_mode = "disabled"
        overlap_frames = 0
        if isinstance(overlap, dict):
            overlap_mode = overlap.get("overlap", "disabled")
            overlap_frames = overlap.get("overlap_frames", 0)

        accum_out = unpack.out(1)
        if accumulate:
            to_accum = processed
            if overlap_frames > 0 and overlap_mode != "disabled":
                # First iteration has nothing accumulated yet — trimming differs there.
                is_first = graph.node("_IntOperations", a=unpack.out(3), b=0, operation="==")
                trimmed_start = graph.node("_BatchOps", batch=processed, operation="trim_start", amount=overlap_frames).out(0)

                if overlap_mode == "start":
                    to_accum = trimmed_start
                elif overlap_mode == "end":
                    trimmed_end = graph.node("_BatchOps", batch=processed, operation="trim_end", amount=overlap_frames).out(0)
                    trimmed_both = graph.node("_BatchOps", batch=trimmed_end, operation="trim_start", amount=overlap_frames).out(0)
                    to_accum = graph.node("_ConditionalSelect",
                        condition=is_first.out(1),
                        value_if_true=trimmed_both,
                        value_if_false=trimmed_end,
                    ).out(0)
                else:
                    # Fade: trim start on iter 1 only, keep full on subsequent for post-loop blend
                    to_accum = graph.node("_ConditionalSelect",
                        condition=is_first.out(1),
                        value_if_true=trimmed_start,
                        value_if_false=processed,
                    ).out(0)

            accum_out = graph.node("_AccumulateNode", to_add=to_accum, accumulation=accum_out).out(0)

        # Disable total_frames when not accumulating to avoid infinite loops
        pack = graph.node("_ImageAccumStatePack",
            remaining=sub.out(0),
            accum=accum_out,
            previous_value=processed,
            count=unpack.out(4),
            open_node_id=unpack.out(5),
            total_frames=unpack.out(6) if accumulate else 0,
            prev_accumulated_count=unpack.out(3),
        )

        # Optional early stop from loop body
        if is_link(stop):
            condition = graph.node("_ConditionalSelect",
                condition=stop,
                value_if_true=False,
                value_if_false=pack.out(1),
            ).out(0)
        else:
            condition = pack.out(1)

        while_close = graph.node(
            "_WhileLoopClose",
            flow_control=flow_control,
            condition=condition,
            initial_value0=pack.out(0),
        )

        # Post-loop: convert the final accumulation into the user-visible output.
        final_unpack = graph.node("_ImageAccumStateUnpack", loop_state=while_close.out(0))
        if accumulate:
            if overlap_mode in ("fade_linear", "fade_smooth") and overlap_frames > 0:
                result = graph.node("_AccumulationToImageBatch",
                    accumulation=final_unpack.out(1),
                    overlap_frames=overlap_frames,
                    overlap_mode=overlap_mode,
                ).out(0)
            else:
                result = graph.node("_AccumulationToImageBatch", accumulation=final_unpack.out(1)).out(0)

            # End mode: re-append the tail that was trimmed from the last iteration
            if overlap_mode == "end" and overlap_frames > 0:
                tail = graph.node("_BatchOps", batch=final_unpack.out(2), operation="keep_end", amount=overlap_frames).out(0)
                result = graph.node("_BatchOps", batch=result, operation="concat", batch_b=tail).out(0)

            # Clamp to total_frames target (no-op in iterations mode where it is 0 — TODO confirm _BatchOps max_count semantics).
            result = graph.node("_BatchOps", batch=result, operation="max_count", amount=final_unpack.out(6)).out(0)
        else:
            result = final_unpack.out(2)

        # Bypass: when iterations==0 AND total_frames==0, return initial_value directly
        count_is_zero = graph.node("_IntOperations", a=unpack.out(4), b=0, operation="==")
        tf_is_zero = graph.node("_IntOperations", a=unpack.out(6), b=0, operation="==")
        both_zero = graph.node("_IntOperations", a=count_is_zero.out(0), b=tf_is_zero.out(0), operation="multiply")
        result = graph.node("_ConditionalSelect",
            condition=both_zero.out(1),
            value_if_true=unpack.out(2),
            value_if_false=result,
        ).out(0)

        return io.NodeOutput(result, expand=graph.finalize())
|
||||||
|
|
||||||
|
# Internal helper nodes — dev only, hidden from the node menu.
|
||||||
|
|
||||||
|
class _AccumulateNode(io.ComfyNode):
    """Appends `to_add` to an accumulation dict, creating one when absent."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_AccumulateNode",
            display_name="Accumulate",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[
                io.AnyType.Input("to_add"),
                io.Accumulation.Input("accumulation", optional=True),
            ],
            outputs=[
                io.Accumulation.Output(),
            ],
        )

    @classmethod
    def execute(cls, to_add, accumulation=None) -> io.NodeOutput:
        # Always build a fresh list so the previous accumulation is not mutated.
        items = [] if accumulation is None else list(accumulation["accum"])
        items.append(to_add)
        return io.NodeOutput({"accum": items})
|
||||||
|
|
||||||
|
class _WhileLoopOpen(io.ComfyNode):
    """Loop entry point: passes its initial values straight through.

    The flow_control output is a placeholder ("stub") — the real looping is
    driven by _WhileLoopClose's graph expansion.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_WhileLoopOpen",
            display_name="While Loop Open",
            category="looping",
            is_dev_only=True,
            inputs=[
                io.Boolean.Input("condition", default=True),
                *[io.AnyType.Input(f"initial_value{i}", optional=True) for i in range(NUM_FLOW_SOCKETS)],
            ],
            outputs=[
                io.FlowControl.Output("flow_control", display_name="FLOW_CONTROL"),
                *[io.AnyType.Output(f"value{i}") for i in range(NUM_FLOW_SOCKETS)],
            ],
            accept_all_inputs=True,
        )

    @classmethod
    def execute(cls, condition: bool, **kwargs) -> io.NodeOutput:
        # Forward whatever initial values were wired in; missing sockets become None.
        values = [kwargs.get(f"initial_value{i}", None) for i in range(NUM_FLOW_SOCKETS)]
        return io.NodeOutput("stub", *values)
|
||||||
|
|
||||||
|
class _WhileLoopClose(io.ComfyNode):
    """Loop back-edge: when `condition` holds, clones the loop body subgraph for
    another iteration; otherwise returns the current values.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_WhileLoopClose",
            display_name="While Loop Close",
            category="looping",
            is_dev_only=True,
            inputs=[
                io.FlowControl.Input("flow_control", raw_link=True),
                io.Boolean.Input("condition", force_input=True),
                *[io.AnyType.Input(f"initial_value{i}", optional=True) for i in range(NUM_FLOW_SOCKETS)],
            ],
            outputs=[
                *[io.AnyType.Output(f"value{i}") for i in range(NUM_FLOW_SOCKETS)],
            ],
            hidden=[io.Hidden.dynprompt, io.Hidden.unique_id],
            enable_expand=True,
            accept_all_inputs=True,
        )

    @staticmethod
    def _explore_dependencies(node_id, dynprompt, upstream):
        """Walk input links recursively, recording for each upstream node the
        list of nodes that consume it (reverse-dependency map)."""
        node_info = dynprompt.get_node(node_id)
        if "inputs" not in node_info:
            return
        for k, v in node_info["inputs"].items():
            if is_link(v):
                parent_id = v[0]
                if parent_id not in upstream:
                    upstream[parent_id] = []
                    _WhileLoopClose._explore_dependencies(parent_id, dynprompt, upstream)
                upstream[parent_id].append(node_id)

    @staticmethod
    def _collect_contained(node_id, upstream, contained):
        """Mark every node reachable downstream of `node_id` via the reverse map
        as contained in the loop body."""
        if node_id not in upstream:
            return
        for child_id in upstream[node_id]:
            if child_id not in contained:
                contained[child_id] = True
                _WhileLoopClose._collect_contained(child_id, upstream, contained)

    @classmethod
    def execute(cls, flow_control, condition: bool, **kwargs) -> io.NodeOutput:
        dynprompt = cls.hidden.dynprompt
        unique_id = cls.hidden.unique_id
        values = [kwargs.get(f"initial_value{i}", None) for i in range(NUM_FLOW_SOCKETS)]

        if not condition:  # Done with the loop — return current values
            return io.NodeOutput(*values)

        # Build the graph expansion for the next loop iteration
        upstream = {}
        cls._explore_dependencies(unique_id, dynprompt, upstream)

        # The loop body = everything between the open node and this close node.
        contained = {}
        open_node = flow_control[0]
        cls._collect_contained(open_node, upstream, contained)
        contained[unique_id] = True
        contained[open_node] = True

        # Use "Recurse" for this node's clone to avoid exponential name growth
        graph = GraphBuilder()
        # Pass 1: create a clone of every contained node.
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id)
            node.set_override_display_id(node_id)
        # Pass 2: rewire inputs — internal links point at clones, external values copy through.
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.lookup_node("Recurse" if node_id == unique_id else node_id)
            for k, v in original_node["inputs"].items():
                if is_link(v) and v[0] in contained:
                    parent = graph.lookup_node(v[0])
                    node.set_input(k, parent.out(v[1]))
                else:
                    node.set_input(k, v)

        # Feed this iteration's outputs into the cloned open node as next inputs.
        new_open = graph.lookup_node(open_node)
        for i in range(NUM_FLOW_SOCKETS):
            new_open.set_input(f"initial_value{i}", values[i])

        my_clone = graph.lookup_node("Recurse")
        result = tuple(my_clone.out(x) for x in range(NUM_FLOW_SOCKETS))
        return io.NodeOutput(*result, expand=graph.finalize())
|
||||||
|
|
||||||
|
|
||||||
|
class _IntOperations(io.ComfyNode):
    """Integer arithmetic/comparison node; outputs both int and bool views of the result."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_IntOperations",
            display_name="Int Operations",
            category="looping/logic",
            is_dev_only=True,
            inputs=[
                io.Int.Input("a", default=0, min=-0xffffffffffffffff, max=0xffffffffffffffff, step=1),
                io.Int.Input("b", default=0, min=-0xffffffffffffffff, max=0xffffffffffffffff, step=1),
                io.Combo.Input("operation", options=[
                    "add", "subtract", "multiply", "divide", "modulo", "power",
                    "==", "!=", "<", ">", "<=", ">=",
                ]),
            ],
            outputs=[
                io.Int.Output(),
                io.Boolean.Output(),
            ],
        )

    # Dispatch table: operation name → binary callable.
    # divide/modulo return 0 on b == 0 instead of raising.
    OPS = {
        "add": lambda a, b: a + b, "subtract": lambda a, b: a - b,
        "multiply": lambda a, b: a * b, "divide": lambda a, b: a // b if b else 0,
        "modulo": lambda a, b: a % b if b else 0, "power": lambda a, b: a ** b,
        "==": lambda a, b: a == b, "!=": lambda a, b: a != b,
        "<": lambda a, b: a < b, ">": lambda a, b: a > b,
        "<=": lambda a, b: a <= b, ">=": lambda a, b: a >= b,
    }

    @classmethod
    def execute(cls, a: int, b: int, operation: str) -> io.NodeOutput:
        result = cls.OPS[operation](a, b)
        # Comparisons yield bools; int()/bool() normalize both output sockets.
        return io.NodeOutput(int(result), bool(result))
|
||||||
|
|
||||||
|
def _get_frame_count(x, dim):
    """Return the size along the given dimension.

    Accepts a raw tensor, a latent dict ({"samples": ...}), or a latent dict
    whose samples are a NestedTensor (size taken from the first sub-tensor).
    """
    if isinstance(x, dict) and "samples" in x:
        s = x["samples"]
        return s.tensors[0].shape[dim] if _is_nested(s) else s.shape[dim]
    return x.shape[dim]
|
||||||
|
|
||||||
|
def _get_cat_dim(x):
    """Return the concatenation dimension for a tensor: dim 2 for 5D video, dim 0 for everything else."""
    if not isinstance(x, dict):
        return 0  # IMAGE/MASK: batch dim
    samples = x["samples"]
    # NestedTensor sub-tensors and 5D video latents both concatenate on the temporal dim.
    if _is_nested(samples) or samples.ndim == 5:
        return 2
    return 0
|
||||||
|
|
||||||
|
def _slice_tensor(x, dim, start=None, end=None):
    """Slice a tensor, latent dict, or NestedTensor along the given dim.

    For latent dicts, the noise_mask (when present and shape-compatible) is
    sliced in lockstep with the samples; all other dict keys are shallow-copied.
    """
    if isinstance(x, dict) and "samples" in x:
        result = x.copy()
        samples = x["samples"]
        if _is_nested(samples):
            # Slice each sub-tensor independently along `dim`.
            sliced = []
            for t in samples.tensors:
                sliced.append(t[(slice(None),) * dim + (slice(start, end),)])
            result["samples"] = NestedTensor(sliced)
            mask = x.get("noise_mask")
            if mask is not None and _is_nested(mask):
                mask_sliced = []
                for t in mask.tensors:
                    mask_sliced.append(t[(slice(None),) * dim + (slice(start, end),)])
                result["noise_mask"] = NestedTensor(mask_sliced)
        else:
            result["samples"] = samples[(slice(None),) * dim + (slice(start, end),)]
            mask = x.get("noise_mask")
            # Only slice the mask when its rank matches the samples' rank.
            if mask is not None and mask.ndim == samples.ndim:
                result["noise_mask"] = mask[(slice(None),) * dim + (slice(start, end),)]
        return result
    else:
        return x[(slice(None),) * dim + (slice(start, end),)]
|
||||||
|
|
||||||
|
def _concat_tensor(a, b, dim):
    """Concatenate two tensors, latent dicts, or NestedTensors along the given dim.

    For latent dicts, metadata is shallow-copied from `a`; noise_masks are
    concatenated only when both sides have one of compatible structure.
    """
    if isinstance(a, dict) and "samples" in a:
        result = a.copy()
        sa, sb = a["samples"], b["samples"]
        if _is_nested(sa):
            # Pairwise concat of corresponding sub-tensors (assumes matching layouts — TODO confirm).
            result["samples"] = NestedTensor([torch.cat([ta, tb], dim=dim) for ta, tb in zip(sa.tensors, sb.tensors)])
            ma, mb = a.get("noise_mask"), b.get("noise_mask")
            if ma is not None and mb is not None and _is_nested(ma):
                result["noise_mask"] = NestedTensor([torch.cat([ta, tb], dim=dim) for ta, tb in zip(ma.tensors, mb.tensors)])
        else:
            result["samples"] = torch.cat([sa, sb], dim=dim)
            ma, mb = a.get("noise_mask"), b.get("noise_mask")
            if ma is not None and mb is not None and ma.ndim == sa.ndim:
                result["noise_mask"] = torch.cat([ma, mb], dim=dim)
        return result
    else:
        return torch.cat([a, b], dim=dim)
|
||||||
|
|
||||||
|
def _blend_overlap(items, overlap_frames, mode):
    """Concatenate items with crossfade blending in the overlap regions.

    Args:
        items: list of IMAGE/MASK tensors or latent dicts (samples may be NestedTensor).
        overlap_frames: number of frames shared between consecutive items.
        mode: "fade_smooth" applies a smoothstep weight curve; anything else is linear.

    Returns:
        A single concatenated tensor/dict, or None for an empty list.
    """
    if not items:
        return None
    dim = _get_cat_dim(items[0])
    # Clamp overlap to the smallest item size to avoid slicing beyond bounds
    min_frames = min(_get_frame_count(item, dim) for item in items)
    overlap_frames = min(overlap_frames, min_frames - 1) if min_frames > 1 else 0
    if overlap_frames <= 0:
        # Nothing to blend — fall back to a plain concatenation of ALL items.
        # (Bug fix: the previous fallback only handled len(items) == 2 and
        # silently dropped every item after the first for longer lists.)
        result = items[0]
        for item in items[1:]:
            result = _concat_tensor(result, item, dim)
        return result
    # Crossfade weights: 0 → 1 across the overlap region.
    t = torch.linspace(0, 1, overlap_frames)
    if mode == "fade_smooth":
        t = t * t * (3 - 2 * t)  # smoothstep: ease in/out

    result = items[0]
    for i in range(1, len(items)):
        prev_tail = _slice_tensor(result, dim, start=-overlap_frames)
        curr_head = _slice_tensor(items[i], dim, end=overlap_frames)
        curr_rest = _slice_tensor(items[i], dim, start=overlap_frames)
        result_base = _slice_tensor(result, dim, end=-overlap_frames)

        # Lerp: blended = prev_tail * (1-w) + curr_head * w
        if isinstance(prev_tail, dict):
            blended = prev_tail.copy()
            ps, cs = prev_tail["samples"], curr_head["samples"]
            if _is_nested(ps):
                blended_tensors = []
                for pt, ch in zip(ps.tensors, cs.tensors):
                    # Reshape weights to broadcast along `dim` only.
                    w = t.to(pt.device).reshape([1] * dim + [overlap_frames] + [1] * (pt.ndim - dim - 1))
                    blended_tensors.append(pt * (1 - w) + ch * w)
                blended["samples"] = NestedTensor(blended_tensors)
            else:
                w = t.to(ps.device).reshape([1] * dim + [overlap_frames] + [1] * (ps.ndim - dim - 1))
                blended["samples"] = ps * (1 - w) + cs * w
        else:
            w = t.to(prev_tail.device).reshape([1] * dim + [overlap_frames] + [1] * (prev_tail.ndim - dim - 1))
            blended = prev_tail * (1 - w) + curr_head * w

        result = _concat_tensor(_concat_tensor(result_base, blended, dim), curr_rest, dim)
    return result
|
||||||
|
|
||||||
|
|
||||||
|
class _AccumulationToImageBatch(io.ComfyNode):
    """Concatenates an ACCUMULATION of IMAGE/MASK tensors or LATENT dicts into a single batch."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_AccumulationToImageBatch",
            display_name="Accumulation to Batch",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[
                io.Accumulation.Input("accumulation"),
                io.Int.Input("overlap_frames", default=0, min=0),
                io.Combo.Input("overlap_mode", options=["disabled", "fade_linear", "fade_smooth"], default="disabled"),
            ],
            outputs=[io.AnyType.Output("result")],
        )

    @classmethod
    def execute(cls, accumulation, overlap_frames=0, overlap_mode="disabled") -> io.NodeOutput:
        """Concatenate accumulated items along the appropriate dim, optionally crossfading overlaps."""
        items = accumulation["accum"]
        if not items:
            return io.NodeOutput(None)

        # Fade modes: blend overlap regions between consecutive items
        if overlap_mode in ("fade_linear", "fade_smooth") and overlap_frames > 0 and len(items) > 1:
            return io.NodeOutput(_blend_overlap(items, overlap_frames, overlap_mode))

        # Standard concatenation (no overlap or start/end trim was handled per-iteration)
        if isinstance(items[0], dict):
            samples = items[0]["samples"]
            if _is_nested(samples):
                # Concatenate each sub-tensor stream along the temporal dim.
                num_sub = len(samples.tensors)
                catted = []
                for i in range(num_sub):
                    catted.append(torch.cat([item["samples"].tensors[i] for item in items], dim=2))
                result = items[0].copy()
                result["samples"] = NestedTensor(catted)
                # Drop the first item's mask; rebuild only when every relevant item carries one.
                result.pop("noise_mask", None)
                masks = [item.get("noise_mask") for item in items if item.get("noise_mask") is not None]
                if masks and _is_nested(masks[0]):
                    mask_catted = []
                    for i in range(len(masks[0].tensors)):
                        mask_catted.append(torch.cat([m.tensors[i] for m in masks], dim=2))
                    result["noise_mask"] = NestedTensor(mask_catted)
                return io.NodeOutput(result)
            elif samples.ndim == 5:
                # Video latent (B,C,T,H,W) — concatenate on the temporal dim.
                result = items[0].copy()
                result["samples"] = torch.cat([item["samples"] for item in items], dim=2)
                result.pop("noise_mask", None)
                masks = [item.get("noise_mask") for item in items if item.get("noise_mask") is not None]
                if masks and masks[0].ndim == 5:
                    result["noise_mask"] = torch.cat(masks, dim=2)
                return io.NodeOutput(result)
            else:
                # Image latent — batch along dim 0
                result = items[0].copy()
                result["samples"] = torch.cat([item["samples"] for item in items], dim=0)
                return io.NodeOutput(result)
        else:
            # Raw IMAGE/MASK tensors — batch along dim 0.
            return io.NodeOutput(torch.cat(items, dim=0))
|
||||||
|
|
||||||
|
|
||||||
|
class _ConditionalSelect(io.ComfyNode):
    """Returns value_if_true when condition is True, value_if_false otherwise.
    Uses lazy evaluation so only the selected branch is resolved."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_ConditionalSelect",
            display_name="Conditional Select",
            category="looping/logic",
            is_dev_only=True,
            inputs=[
                io.Boolean.Input("condition"),
                io.AnyType.Input("value_if_true", lazy=True),
                io.AnyType.Input("value_if_false", lazy=True),
            ],
            outputs=[io.AnyType.Output("result")],
        )

    @classmethod
    def check_lazy_status(cls, condition, value_if_true=None, value_if_false=None) -> list[str]:
        """Request evaluation only of the branch selected by `condition`."""
        # Unevaluated lazy inputs arrive as None
        if condition and value_if_true is None:
            return ["value_if_true"]
        if not condition and value_if_false is None:
            return ["value_if_false"]
        return []

    @classmethod
    def execute(cls, condition, **kwargs) -> io.NodeOutput:
        # Pick the branch; the unselected one was never evaluated (stays None).
        selected = kwargs.get("value_if_true") if condition else kwargs.get("value_if_false")
        return io.NodeOutput(selected)
|
||||||
|
|
||||||
|
|
||||||
|
class _ImageAccumStatePack(io.ComfyNode):
    """Packs loop state into a single dict for TensorLoopOpen's initial_value0."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_ImageAccumStatePack",
            display_name="Image Accum State Pack",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[
                io.AnyType.Input("remaining"),
                io.Accumulation.Input("accum", optional=True),
                io.AnyType.Input("previous_value"),
                io.AnyType.Input("count"),
                io.AnyType.Input("open_node_id"),
                io.AnyType.Input("total_frames"),
                io.Int.Input("prev_accumulated_count", default=0),
            ],
            outputs=[
                io.AnyType.Output("loop_state"),
                io.Boolean.Output("should_continue"),
            ],
        )

    @classmethod
    def execute(cls, remaining, accum, previous_value, count, open_node_id, total_frames, prev_accumulated_count=0) -> io.NodeOutput:
        """Repack loop state and decide whether the loop should run again.

        total_frames > 0 selects frame-count termination; otherwise the
        remaining-iterations counter decides.
        """
        accumulated_count = _accum_count(accum)

        if total_frames > 0:
            should_continue = accumulated_count < total_frames
            # Bail if the last iteration added nothing — the loop would never reach the target
            if accumulated_count == prev_accumulated_count:
                should_continue = False
            comfy.utils.ProgressBar(total_frames, node_id=open_node_id).update_absolute(accumulated_count)
        else:
            should_continue = remaining > 0
            current_iteration = count - remaining
            comfy.utils.ProgressBar(count, node_id=open_node_id).update_absolute(current_iteration)

        return io.NodeOutput(
            {"remaining": remaining, "accum": accum, "previous_value": previous_value, "count": count, "open_node_id": open_node_id, "total_frames": total_frames},
            should_continue,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class _ImageAccumStateUnpack(io.ComfyNode):
    """Unpacks loop_state from TensorLoopOpen.

    Splits the packed state dict back into its individual components and
    derives the current accumulated frame count from the accumulation.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_ImageAccumStateUnpack",
            display_name="Image Accum State Unpack",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[io.AnyType.Input("loop_state")],
            outputs=[
                io.Int.Output("remaining"),
                io.Accumulation.Output("accumulation"),
                io.AnyType.Output("previous_value"),
                io.Int.Output("accumulated_count"),
                io.Int.Output("count"),
                io.AnyType.Output("open_node_id"),
                io.Int.Output("total_frames"),
            ],
        )

    @classmethod
    def execute(cls, loop_state) -> io.NodeOutput:
        """Return the state components in the same order as the outputs."""
        # These three keys are always written by the pack node; a missing one
        # is a programming error, so bracket access (KeyError) is deliberate.
        acc = loop_state["accum"]
        prev = loop_state["previous_value"]
        left = loop_state["remaining"]
        # Later-added keys fall back to safe defaults for old packed states.
        return io.NodeOutput(
            left,
            acc,
            prev,
            _accum_count(acc),
            loop_state.get("count", 0),
            loop_state.get("open_node_id"),
            loop_state.get("total_frames", 0),
        )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class _BatchOps(io.ComfyNode):
    """Batch slicing, trimming, and concatenation operations.

    Operates along the batch's concatenation dimension as reported by
    `_get_cat_dim`, so the same node works for images and latent batches.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="_BatchOps",
            display_name="Batch Ops",
            category="looping/accumulation",
            is_dev_only=True,
            inputs=[
                io.AnyType.Input("batch"),
                io.Combo.Input("operation", options=["max_count", "trim_start", "trim_end", "keep_end", "concat"]),
                io.Int.Input("amount", default=0, min=0),
                io.AnyType.Input("batch_b", optional=True),
            ],
            outputs=[io.AnyType.Output("output")],
        )

    @classmethod
    def execute(cls, batch, operation, amount=0, batch_b=None) -> io.NodeOutput:
        """Apply the selected operation; unknown ops pass the batch through."""
        axis = _get_cat_dim(batch)

        # Concat ignores `amount` entirely; it only needs a second batch.
        if operation == "concat" and batch_b is not None:
            return io.NodeOutput(_concat_tensor(batch, batch_b, axis))

        # amount == 0 (the default) means "no-op" for every slicing operation.
        if amount <= 0:
            return io.NodeOutput(batch)

        # Clamp to avoid slicing beyond the tensor size
        size = _get_frame_count(batch, axis)
        n = min(amount, size)

        if operation == "max_count":
            result = _slice_tensor(batch, axis, end=n)
        elif operation == "trim_start":
            result = _slice_tensor(batch, axis, start=n)
        elif operation == "trim_end":
            # Trimming everything would leave an empty batch; keep it whole instead.
            result = _slice_tensor(batch, axis, end=-n) if n < size else batch
        elif operation == "keep_end":
            result = _slice_tensor(batch, axis, start=-n)
        else:
            result = batch
        return io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
|
class LoopExtension(ComfyExtension):
    """Registers the looping / accumulation node pack with ComfyUI."""

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        # Underscore-prefixed nodes are dev-only plumbing that the loop
        # macro nodes expand into; they are registered so expanded graphs
        # can resolve them, not for direct user placement.
        node_classes: list[type[io.ComfyNode]] = [
            TensorLoopOpen,
            TensorLoopClose,
            _WhileLoopOpen,
            _WhileLoopClose,
            _AccumulateNode,
            _IntOperations,
            _AccumulationToImageBatch,
            _ConditionalSelect,
            _ImageAccumStateUnpack,
            _ImageAccumStatePack,
            _BatchOps,
        ]
        return node_classes
|
||||||
|
|
||||||
|
def comfy_entrypoint():
    """Module entry point discovered by ComfyUI; returns the extension."""
    return LoopExtension()
|
||||||
@ -1326,6 +1326,7 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
|
|||||||
io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context.", advanced=True),
|
io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context.", advanced=True),
|
||||||
io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
|
io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||||
io.Image.Input("previous_frames", optional=True),
|
io.Image.Input("previous_frames", optional=True),
|
||||||
|
io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="The amount of frames to seek in the previous_frames input.")
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(display_name="model"),
|
io.Model.Output(display_name="model"),
|
||||||
@ -1338,7 +1339,7 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
|
def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
|
||||||
start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput:
|
start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None, video_frame_offset=0) -> io.NodeOutput:
|
||||||
|
|
||||||
if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
|
if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
|
||||||
raise ValueError("Not enough previous frames provided.")
|
raise ValueError("Not enough previous frames provided.")
|
||||||
@ -1421,11 +1422,13 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
|
|||||||
# when extending from previous frames
|
# when extending from previous frames
|
||||||
if previous_frames is not None:
|
if previous_frames is not None:
|
||||||
motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
if video_frame_offset is not None and video_frame_offset > 0:
|
||||||
|
frame_offset = video_frame_offset - motion_frame_count
|
||||||
|
else:
|
||||||
frame_offset = previous_frames.shape[0] - motion_frame_count
|
frame_offset = previous_frames.shape[0] - motion_frame_count
|
||||||
|
|
||||||
audio_start = frame_offset
|
audio_start = frame_offset
|
||||||
audio_end = audio_start + length
|
audio_end = audio_start + length
|
||||||
logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")
|
|
||||||
|
|
||||||
motion_frames_latent = vae.encode(motion_frames[:, :, :, :3])
|
motion_frames_latent = vae.encode(motion_frames[:, :, :, :3])
|
||||||
trim_image = motion_frame_count
|
trim_image = motion_frame_count
|
||||||
@ -1434,6 +1437,8 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
|
|||||||
audio_end = length
|
audio_end = length
|
||||||
motion_frames_latent = concat_latent_image[:, :, :1]
|
motion_frames_latent = concat_latent_image[:, :, :1]
|
||||||
|
|
||||||
|
logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")
|
||||||
|
|
||||||
audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype())
|
audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype())
|
||||||
model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed
|
model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user