Remove looping functionality, keep extension functionality

kijai 2025-11-03 21:12:02 +02:00
parent 8d62661a9f
commit d53e62913d
2 changed files with 55 additions and 164 deletions

@@ -1,10 +1,7 @@
import torch
from einops import rearrange, repeat
import math
import comfy
from comfy.ldm.modules.attention import optimized_attention
import logging
import latent_preview
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks):
@@ -408,12 +405,12 @@ class MultiTalkCrossAttnPatch:
self.ref_target_masks = ref_target_masks
def __call__(self, kwargs):
x = kwargs["x"]
block_idx = kwargs.get("block_idx", 0)
transformer_options = kwargs.get("transformer_options", {})
block_idx = transformer_options.get("block_idx", None)
if block_idx is None:
return torch.zeros_like(x)
x = kwargs["x"]
transformer_options = kwargs.get("transformer_options", {})
audio_embeds = transformer_options.get("audio_embeds")
x_ref_attn_map = None
@@ -440,152 +437,35 @@ class MultiTalkApplyModelWrapper:
return samples
class InfiniteTalkOuterSampleLoopingWrapper:
def __init__(self, init_previous_frames, encoded_audio, model_patch, audio_scale, max_frames, frame_window_size, motion_frame_count=9, vae=None, ref_target_masks=None):
self.init_previous_frames = init_previous_frames
self.encoded_audio = encoded_audio
self.total_audio_frames = encoded_audio[0].shape[0]
self.max_frames = max_frames
self.frame_window_size = frame_window_size
self.latent_window_size = (frame_window_size - 1) // 4 + 1
class InfiniteTalkOuterSampleWrapper:
def __init__(self, motion_frames_latent, model_patch, is_extend=False):
self.motion_frames_latent = motion_frames_latent
self.model_patch = model_patch
self.audio_scale = audio_scale
self.motion_frame_count = motion_frame_count
self.vae = vae
self.ref_target_masks = ref_target_masks
def __call__(self, executor, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, **kwargs):
# init variables
previous_frames = motion_frames_latent = None
init_from_cond = False
frame_offset = audio_start = latent_frame_offset = latent_start_idx = 0
audio_end = self.frame_window_size
latent_end_idx = self.latent_window_size
decoded_results = []
self.is_extend = is_extend
def __call__(self, executor, *args, **kwargs):
model_patcher = executor.class_obj.model_patcher
model_options = executor.class_obj.model_options
process_latent_in = model_patcher.model.process_latent_in
dtype = model_patcher.model_dtype()
# when extending from previous frames
if self.init_previous_frames is not None:
decoded_results.append(self.init_previous_frames.unsqueeze(0))
previous_frames = self.init_previous_frames # should we grow the results here or rely on using batch image nodes in the workflow?
if previous_frames.shape[0] < self.motion_frame_count:
previous_frames = torch.cat([previous_frames[:1].repeat(self.motion_frame_count - previous_frames.shape[0], 1, 1, 1), previous_frames], dim=0)
motion_frames = previous_frames[-self.motion_frame_count:]
frame_offset = previous_frames.shape[0] - self.motion_frame_count
# for InfiniteTalk, model input first latent(s) need to always be replaced on every step
if self.motion_frames_latent is not None:
wrappers = model_options["transformer_options"]["wrappers"]
w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))]
# add/replace current cross-attention patch to model options
model_options["transformer_options"].setdefault("patches", {}).setdefault("cross_attn", []).append(
MultiTalkCrossAttnPatch(self.model_patch, self.audio_scale, ref_target_masks=self.ref_target_masks)
)
# run the sampling process
result = executor(*args, **kwargs)
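# executor is the wrapped OUTER_SAMPLE chain: the call above runs the original sampling unchanged,
# so this wrapper only installs the audio patches beforehand and post-processes the result below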
frames_needed = math.ceil(min(self.max_frames, self.total_audio_frames) / 81) * 81
estimated_iterations = frames_needed // (self.frame_window_size - self.motion_frame_count)
total_steps = (sigmas.shape[-1] - 1) * estimated_iterations
logging.info(f"InfiniteTalk estimated loop iterations: {estimated_iterations}, Total steps: {total_steps}")
# custom previewer callback for full loop progress bar
pbar = comfy.utils.ProgressBar(total_steps)
previewer = latent_preview.get_previewer(model_patcher.load_device, model_patcher.model.latent_format)
def custom_callback(step, x0, x, total_steps):
preview_bytes = None
if previewer:
preview_bytes = previewer.decode_latent_to_preview_image("JPEG", x0)
pbar.update_absolute(pbar.current+1, preview=preview_bytes)
# outer loop start for multiple frame windows
for i in range(estimated_iterations):
# first frame to InfiniteTalk always has to be a noise-free encoded image
# if no previous samples provided, try to get I2V cond latent from positive cond
if previous_frames is None:
concat_latent_image = executor.class_obj.conds["positive"][0].get("concat_latent_image", None)
if concat_latent_image is not None:
motion_frames_latent = concat_latent_image[:, :, :1]
overlap = 1
init_from_cond = True
# else, use previous samples' last frames as first frame
else:
audio_start = frame_offset
audio_end = audio_start + self.frame_window_size
latent_start_idx = latent_frame_offset
latent_end_idx = latent_start_idx + self.latent_window_size
if len(motion_frames.shape) == 5:
motion_frames = motion_frames.squeeze(0)
spacial_compression = self.vae.spacial_compression_encode()
if (motion_frames.shape[-3], motion_frames.shape[-2]) != (noise.shape[-2] * spacial_compression, noise.shape[-1] * spacial_compression):
motion_frames = comfy.utils.common_upscale(
motion_frames.movedim(-1, 1),
noise.shape[-1] * spacial_compression, noise.shape[-2] * spacial_compression,
"bilinear", "center")
motion_frames_latent = self.vae.encode(motion_frames)
overlap = motion_frames_latent.shape[2]
audio_embed = project_audio_features(self.model_patch.model.audio_proj, self.encoded_audio, audio_start, audio_end).to(dtype)
model_options["transformer_options"]["audio_embeds"] = audio_embed
# model input first latents need to always be replaced on every step
if motion_frames_latent is not None:
wrappers = model_options["transformer_options"]["wrappers"]
w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(motion_frames_latent))]
# Slice possible encoded latent_image for vid2vid
if latent_image is not None and torch.count_nonzero(latent_image) > 0:
# Check if we have enough latents
if latent_end_idx > latent_image.shape[2]:
# This window needs more frames - pad the latent_image at the end
pad_length = latent_end_idx - latent_image.shape[2]
last_frame = latent_image[:, :, -1:].repeat(1, 1, pad_length, 1, 1)
latent_image = torch.cat([latent_image, last_frame], dim=2)
new_noise_frames = torch.randn_like(latent_image[:, :, -pad_length:], device=noise.device, dtype=noise.dtype)
noise = torch.cat([noise, new_noise_frames], dim=2)
noise = noise[:, :, latent_start_idx:latent_end_idx]
latent_image = latent_image[:, :, latent_start_idx:latent_end_idx]
#if denoise_mask is not None: # todo: check if denoise mask needs adjustment for latent_image changes
# run the sampling process
result = executor(noise, latent_image, sampler, sigmas, denoise_mask=denoise_mask, callback=custom_callback, disable_pbar=False, seed=seed, **kwargs)
#insert motion frames before decoding
if previous_frames is not None and not init_from_cond:
result = torch.cat([motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)
previous_frames = self.vae.decode(result)
motion_frames = previous_frames[:, -self.motion_frame_count:]
# Track frame progress
new_frame_count = previous_frames.shape[1] - self.motion_frame_count
frame_offset += new_frame_count
motion_latent_count = (self.motion_frame_count - 1) // 4 + 1 if self.motion_frame_count > 0 else 0
new_latent_count = self.latent_window_size - motion_latent_count
latent_frame_offset += new_latent_count
if init_from_cond:
decoded_results.append(previous_frames)
init_from_cond = False
else:
decoded_results.append(previous_frames[:, self.motion_frame_count:])
return torch.cat(decoded_results, dim=1)
# insert motion frames before decoding
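# (the apply-model wrapper above forces these latents into the model input on every step,
# so the clean encoded motion frames are restored in the returned samples)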
if self.is_extend:
overlap = self.motion_frames_latent.shape[2]
result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)
return result
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
if self.init_previous_frames is not None:
self.init_previous_frames = self.init_previous_frames.to(device_or_dtype)
if self.encoded_audio is not None:
self.encoded_audio = [ea.to(device_or_dtype) for ea in self.encoded_audio]
if self.ref_target_masks is not None:
self.ref_target_masks = self.ref_target_masks.to(device_or_dtype)
if self.motion_frames_latent is not None:
self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype)
return self

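The extension path keeps relying on Wan's 4x temporal compression when converting pixel frames to latent frames; a minimal sketch of that arithmetic (plain Python, illustrative values only, same formula as latent_window_size above and the latent allocation in WanInfiniteTalkToVideo below):

def latent_frame_count(pixel_frames: int) -> int:
    # the first frame maps to one latent frame, then every further 4 pixel frames map to one more
    return (pixel_frames - 1) // 4 + 1

# an 81-frame window becomes 21 latent frames; a 9-frame motion context becomes 3
assert latent_frame_count(81) == 21
assert latent_frame_count(9) == 3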
@@ -11,6 +11,7 @@ import numpy as np
from typing import Tuple
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import logging
class WanImageToVideo(io.ComfyNode):
@classmethod
@@ -1288,7 +1289,7 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
return io.NodeOutput(out_latent)
from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleLoopingWrapper
from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, project_audio_features
class WanInfiniteTalkToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
@@ -1310,7 +1311,6 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True),
io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."),
io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."),
io.Int.Input("frame_window_size", default=81, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of frames to generate in one window."),
io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."),
io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
io.Image.Input("previous_frames", optional=True),
@@ -1320,15 +1320,16 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
io.Int.Output(display_name="trim_image"),
],
)
@classmethod
def execute(cls, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count, frame_window_size,
def execute(cls, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput:
if frame_window_size > length:
frame_window_size = length
if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
raise ValueError("Not enough previous frames provided.")
if audio_encoder_output_2 is not None:
if mask_1 is None or mask_2 is None:
raise ValueError("Masks must be provided if two audio encoder outputs are used.")
@@ -1339,10 +1340,10 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
raise ValueError("Second audio encoder output must be provided if two masks are used.")
ref_masks = torch.cat([mask_1, mask_2])
latent = torch.zeros([1, 16, ((frame_window_size - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:frame_window_size].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((frame_window_size, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
image[:start_image.shape[0]] = start_image
concat_latent_image = vae.encode(image[:, :, :, :3])
@@ -1399,30 +1400,40 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0]
token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1)
init_previous_frames = None
# when extending from previous frames
if previous_frames is not None:
init_previous_frames = previous_frames[:, :, :, :3]
motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
frame_offset = previous_frames.shape[0] - motion_frame_count
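# the audio window starts where the reused motion-context frames begin in the overall timeline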
audio_start = frame_offset
audio_end = audio_start + length
logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")
motion_frames_latent = vae.encode(motion_frames[:, :, :, :3])
trim_image = motion_frame_count
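# trim_image: how many leading frames of the decoded result duplicate the tail of previous_frames
# (the re-inserted motion context), so downstream nodes can drop them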
else:
audio_start = trim_image = 0
audio_end = length
motion_frames_latent = concat_latent_image[:, :, :1]
audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype())
model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed
# add outer sample wrapper
model_patched.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
"infinite_talk_outer_sample",
InfiniteTalkOuterSampleLoopingWrapper(
init_previous_frames,
encoded_audio_list,
InfiniteTalkOuterSampleWrapper(
motion_frames_latent,
model_patch,
audio_scale,
length,
frame_window_size,
motion_frame_count,
vae=vae,
ref_target_masks=token_ref_target_masks)
)
is_extend=previous_frames is not None,
))
# add cross-attention patch
model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale, ref_target_masks=token_ref_target_masks), "cross_attn")
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(model_patched, positive, negative, out_latent)
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
class WanExtension(ComfyExtension):
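For context, a sketch of how the new trim_image output could be consumed when chaining extension runs; the helper name and tensor shapes are assumptions, not part of this commit (in a real workflow this would typically be done with image batch/trim nodes):

import torch

def append_extension(previous_frames: torch.Tensor, decoded: torch.Tensor, trim_image: int) -> torch.Tensor:
    # previous_frames, decoded: image batches shaped (T, H, W, C)
    # drop the re-inserted motion-context frames, then append only the newly generated ones
    return torch.cat([previous_frames, decoded[trim_image:]], dim=0)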