From d53e62913d6f1fead857f3bfa9afcf25152e5e5b Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 3 Nov 2025 21:12:02 +0200
Subject: [PATCH] Remove looping functionality, keep extension functionality

---
 comfy/ldm/wan/model_multitalk.py | 164 +++++-------------------------
 comfy_extras/nodes_wan.py        |  55 ++++++-----
 2 files changed, 55 insertions(+), 164 deletions(-)

diff --git a/comfy/ldm/wan/model_multitalk.py b/comfy/ldm/wan/model_multitalk.py
index 864cae82c..d6bb64672 100644
--- a/comfy/ldm/wan/model_multitalk.py
+++ b/comfy/ldm/wan/model_multitalk.py
@@ -1,10 +1,7 @@
 import torch
 from einops import rearrange, repeat
-import math
 import comfy
 from comfy.ldm.modules.attention import optimized_attention
-import logging
-import latent_preview


 def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks):
@@ -408,12 +405,12 @@ class MultiTalkCrossAttnPatch:
         self.ref_target_masks = ref_target_masks

     def __call__(self, kwargs):
         x = kwargs["x"]
-        block_idx = kwargs.get("block_idx", 0)
+        transformer_options = kwargs.get("transformer_options", {})
+        block_idx = transformer_options.get("block_idx", None)
         if block_idx is None:
             return torch.zeros_like(x)

-        transformer_options = kwargs.get("transformer_options", {})
         audio_embeds = transformer_options.get("audio_embeds")

         x_ref_attn_map = None
@@ -440,152 +437,35 @@ class MultiTalkApplyModelWrapper:
         return samples


-class InfiniteTalkOuterSampleLoopingWrapper:
-    def __init__(self, init_previous_frames, encoded_audio, model_patch, audio_scale, max_frames, frame_window_size, motion_frame_count=9, vae=None, ref_target_masks=None):
-        self.init_previous_frames = init_previous_frames
-        self.encoded_audio = encoded_audio
-        self.total_audio_frames = encoded_audio[0].shape[0]
-        self.max_frames = max_frames
-        self.frame_window_size = frame_window_size
-        self.latent_window_size = (frame_window_size - 1) // 4 + 1
+class InfiniteTalkOuterSampleWrapper:
+    def __init__(self, motion_frames_latent, model_patch, is_extend=False):
+        self.motion_frames_latent = motion_frames_latent
         self.model_patch = model_patch
-        self.audio_scale = audio_scale
-        self.motion_frame_count = motion_frame_count
-        self.vae = vae
-        self.ref_target_masks = ref_target_masks
-
-    def __call__(self, executor, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, **kwargs):
-        # init variables
-        previous_frames = motion_frames_latent = None
-        init_from_cond = False
-        frame_offset = audio_start = latent_frame_offset = latent_start_idx = 0
-        audio_end = self.frame_window_size
-        latent_end_idx = self.latent_window_size
-        decoded_results = []
+        self.is_extend = is_extend
+
+    def __call__(self, executor, *args, **kwargs):
         model_patcher = executor.class_obj.model_patcher
         model_options = executor.class_obj.model_options
         process_latent_in = model_patcher.model.process_latent_in
-        dtype = model_patcher.model_dtype()

-        # when extending from previous frames
-        if self.init_previous_frames is not None:
-            decoded_results.append(self.init_previous_frames.unsqueeze(0))
-            previous_frames = self.init_previous_frames # should we grow the results here or rely on using batch image nodes in the workflow?
-            if previous_frames.shape[0] < self.motion_frame_count:
-                previous_frames = torch.cat([previous_frames[:1].repeat(self.motion_frame_count - previous_frames.shape[0], 1, 1, 1), previous_frames], dim=0)
-            motion_frames = previous_frames[-self.motion_frame_count:]
-            frame_offset = previous_frames.shape[0] - self.motion_frame_count
+        # for InfiniteTalk, the first latent(s) of the model input always need to be replaced on every step
+        if self.motion_frames_latent is not None:
+            wrappers = model_options["transformer_options"]["wrappers"]
+            w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
+            w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))]

-        # add/replace current cross-attention patch to model options
-        model_options["transformer_options"].setdefault("patches", {}).setdefault("cross_attn", []).append(
-            MultiTalkCrossAttnPatch(self.model_patch, self.audio_scale, ref_target_masks=self.ref_target_masks)
-        )
+        # run the sampling process
+        result = executor(*args, **kwargs)

-        frames_needed = math.ceil(min(self.max_frames, self.total_audio_frames) / 81) * 81
-        estimated_iterations = frames_needed // (self.frame_window_size - self.motion_frame_count)
-        total_steps = (sigmas.shape[-1] - 1) * estimated_iterations
-        logging.info(f"InfiniteTalk estimated loop iterations: {estimated_iterations}, Total steps: {total_steps}")
-
-        # custom previewer callback for full loop progress bar
-        pbar = comfy.utils.ProgressBar(total_steps)
-        previewer = latent_preview.get_previewer(model_patcher.load_device, model_patcher.model.latent_format)
-
-        def custom_callback(step, x0, x, total_steps):
-            preview_bytes = None
-            if previewer:
-                preview_bytes = previewer.decode_latent_to_preview_image("JPEG", x0)
-            pbar.update_absolute(pbar.current+1, preview=preview_bytes)
-
-        # outer loop start for multiple frame windows
-        for i in range(estimated_iterations):
-
-            # first frame to InfinityTalk always has to be noise free encoded image
-            # if no previous samples provided, try to get I2V cond latent from positive cond
-
-            if previous_frames is None:
-                concat_latent_image = executor.class_obj.conds["positive"][0].get("concat_latent_image", None)
-                if concat_latent_image is not None:
-                    motion_frames_latent = concat_latent_image[:, :, :1]
-                    overlap = 1
-                    init_from_cond = True
-            # else, use previous samples' last frames as first frame
-            else:
-                audio_start = frame_offset
-                audio_end = audio_start + self.frame_window_size
-                latent_start_idx = latent_frame_offset
-                latent_end_idx = latent_start_idx + self.latent_window_size
-
-                if len(motion_frames.shape) == 5:
-                    motion_frames = motion_frames.squeeze(0)
-                spacial_compression = self.vae.spacial_compression_encode()
-                if (motion_frames.shape[-3], motion_frames.shape[-2]) != (noise.shape[-2] * spacial_compression, noise.shape[-1] * spacial_compression):
-                    motion_frames = comfy.utils.common_upscale(
-                        motion_frames.movedim(-1, 1),
-                        noise.shape[-1] * spacial_compression, noise.shape[-2] * spacial_compression,
-                        "bilinear", "center")
-
-                motion_frames_latent = self.vae.encode(motion_frames)
-                overlap = motion_frames_latent.shape[2]
-
-            audio_embed = project_audio_features(self.model_patch.model.audio_proj, self.encoded_audio, audio_start, audio_end).to(dtype)
-            model_options["transformer_options"]["audio_embeds"] = audio_embed
-
-            # model input first latents need to always be replaced on every step
-            if motion_frames_latent is not None:
-                wrappers = model_options["transformer_options"]["wrappers"]
-                w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
-                w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(motion_frames_latent))]
-
-            # Slice possible encoded latent_image for vid2vid
-            if latent_image is not None and torch.count_nonzero(latent_image) > 0:
-                # Check if we have enough latents
-                if latent_end_idx > latent_image.shape[2]:
-                    # This window needs more frames - pad the latent_image at the end
-                    pad_length = latent_end_idx - latent_image.shape[2]
-                    last_frame = latent_image[:, :, -1:].repeat(1, 1, pad_length, 1, 1)
-                    latent_image = torch.cat([latent_image, last_frame], dim=2)
-                    new_noise_frames = torch.randn_like(latent_image[:, :, -pad_length:], device=noise.device, dtype=noise.dtype)
-                    noise = torch.cat([noise, new_noise_frames], dim=2)
-                noise = noise[:, :, latent_start_idx:latent_end_idx]
-                latent_image = latent_image[:, :, latent_start_idx:latent_end_idx]
-                #if denoise_mask is not None: # todo: check if denoise mask needs adjustment for latent_image changes
-
-
-            # run the sampling process
-            result = executor(noise, latent_image, sampler, sigmas, denoise_mask=denoise_mask, callback=custom_callback, disable_pbar=False, seed=seed, **kwargs)
-
-            #insert motion frames before decoding
-            if previous_frames is not None and not init_from_cond:
-                result = torch.cat([motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)
-
-            previous_frames = self.vae.decode(result)
-            motion_frames = previous_frames[:, -self.motion_frame_count:]
-
-            # Track frame progress
-            new_frame_count = previous_frames.shape[1] - self.motion_frame_count
-            frame_offset += new_frame_count
-
-            motion_latent_count = (self.motion_frame_count - 1) // 4 + 1 if self.motion_frame_count > 0 else 0
-            new_latent_count = self.latent_window_size - motion_latent_count
-
-            latent_frame_offset += new_latent_count
-
-            if init_from_cond:
-                decoded_results.append(previous_frames)
-                init_from_cond = False
-            else:
-                decoded_results.append(previous_frames[:, self.motion_frame_count:])
-
-        return torch.cat(decoded_results, dim=1)
+        # insert motion frames before decoding
+        if self.is_extend:
+            overlap = self.motion_frames_latent.shape[2]
+            result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)
+        return result

     def to(self, device_or_dtype):
         if isinstance(device_or_dtype, torch.device):
-            if self.init_previous_frames is not None:
-                self.init_previous_frames = self.init_previous_frames.to(device_or_dtype)
-            if self.encoded_audio is not None:
-                self.encoded_audio = [ea.to(device_or_dtype) for ea in self.encoded_audio]
-            if self.ref_target_masks is not None:
-                self.ref_target_masks = self.ref_target_masks.to(device_or_dtype)
+            if self.motion_frames_latent is not None:
+                self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype)

         return self

diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
index d80704a12..c82ba0e78 100644
--- a/comfy_extras/nodes_wan.py
+++ b/comfy_extras/nodes_wan.py
@@ -11,6 +11,7 @@ import numpy as np
 from typing import Tuple
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, io
+import logging


 class WanImageToVideo(io.ComfyNode):
@@ -1288,7 +1289,7 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
         return io.NodeOutput(out_latent)

-from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleLoopingWrapper
+from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, project_audio_features

 class WanInfiniteTalkToVideo(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -1310,7 +1311,6 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
                 io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True),
                 io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."),
                 io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."),
-                io.Int.Input("frame_window_size", default=81, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of frames to generate in one window."),
                 io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."),
                 io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
                 io.Image.Input("previous_frames", optional=True),
@@ -1320,15 +1320,16 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
                 io.Conditioning.Output(display_name="positive"),
                 io.Conditioning.Output(display_name="negative"),
                 io.Latent.Output(display_name="latent"),
+                io.Int.Output(display_name="trim_image"),
             ],
         )

     @classmethod
-    def execute(cls, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count, frame_window_size,
+    def execute(cls, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
                 start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput:
-        if frame_window_size > length:
-            frame_window_size = length
+        if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
+            raise ValueError("Not enough previous frames provided.")
         if audio_encoder_output_2 is not None:
             if mask_1 is None or mask_2 is None:
                 raise ValueError("Masks must be provided if two audio encoder outputs are used.")
@@ -1339,10 +1340,10 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
                 raise ValueError("Second audio encoder output must be provided if two masks are used.")
             ref_masks = torch.cat([mask_1, mask_2])

-        latent = torch.zeros([1, 16, ((frame_window_size - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         if start_image is not None:
-            start_image = comfy.utils.common_upscale(start_image[:frame_window_size].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
-            image = torch.ones((frame_window_size, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
+            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
             image[:start_image.shape[0]] = start_image

             concat_latent_image = vae.encode(image[:, :, :, :3])
@@ -1399,30 +1400,40 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
                 ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0]
             token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1)

-
-        init_previous_frames = None
+        # when extending from previous frames
         if previous_frames is not None:
-            init_previous_frames = previous_frames[:, :, :, :3]
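+            # NOTE: previous_frames is expected to hold the full video generated so far;
+            # only its last motion_frame_count frames are re-encoded as motion context,
+            # and audio is consumed from the frame where the previous window stopped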
+            motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            frame_offset = previous_frames.shape[0] - motion_frame_count
+            audio_start = frame_offset
+            audio_end = audio_start + length
+            logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")
+            motion_frames_latent = vae.encode(motion_frames[:, :, :, :3])
+            trim_image = motion_frame_count
+        else:
+            audio_start = trim_image = 0
+            audio_end = length
+            motion_frames_latent = concat_latent_image[:, :, :1]
+
+        audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype())
+        model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed
+
+        # add outer sample wrapper
         model_patched.add_wrapper_with_key(
             comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
             "infinite_talk_outer_sample",
-            InfiniteTalkOuterSampleLoopingWrapper(
-                init_previous_frames,
-                encoded_audio_list,
+            InfiniteTalkOuterSampleWrapper(
+                motion_frames_latent,
                 model_patch,
-                audio_scale,
-                length,
-                frame_window_size,
-                motion_frame_count,
-                vae=vae,
-                ref_target_masks=token_ref_target_masks)
-        )
+                is_extend=previous_frames is not None,
+            ))
+        # add cross-attention patch
+        model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale, ref_target_masks=token_ref_target_masks), "cross_attn")

         out_latent = {}
         out_latent["samples"] = latent
-        return io.NodeOutput(model_patched, positive, negative, out_latent)
+        return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)


 class WanExtension(ComfyExtension):
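
With the internal loop removed, window chaining is handled at the workflow level. Below is a minimal sketch of the extension pattern this patch enables; it is illustrative only: `infinitetalk`, `sample`, `decode` and `num_windows` are hypothetical placeholders for the WanInfiniteTalkToVideo, sampler and VAE-decode node executions (and a value derived from the audio length) wired up in a ComfyUI graph, not real API calls.

    import torch

    frames = None  # accumulated output images, shape (T, H, W, C)
    for _ in range(num_windows):
        # previous_frames=None on the first window; later windows reuse the last
        # motion_frame_count frames as motion context and continue the audio from there
        model, positive, negative, latent, trim_image = infinitetalk(..., previous_frames=frames)
        decoded = decode(vae, sample(model, positive, negative, latent))
        # trim_image is 0 on the first window and motion_frame_count when extending,
        # so this drops the duplicated motion-context frames before concatenating
        decoded = decoded[trim_image:]
        frames = decoded if frames is None else torch.cat([frames, decoded], dim=0)

The trim_image output exists precisely because the node re-inserts the motion-frame latents before decoding: the decoded window starts with frames that duplicate the tail of the previous window, and the consumer is expected to drop that overlap when batching.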