Mirror of https://github.com/comfyanonymous/ComfyUI.git
Remove looping functionality, keep extension functionality

commit d53e62913d (parent 8d62661a9f)
@@ -1,10 +1,7 @@
import torch
from einops import rearrange, repeat
import math
import comfy
from comfy.ldm.modules.attention import optimized_attention
import logging
import latent_preview


def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks):
@@ -408,12 +405,12 @@ class MultiTalkCrossAttnPatch:
        self.ref_target_masks = ref_target_masks

    def __call__(self, kwargs):
        x = kwargs["x"]
        block_idx = kwargs.get("block_idx", 0)
        transformer_options = kwargs.get("transformer_options", {})
        block_idx = transformer_options.get("block_idx", None)
        if block_idx is None:
            return torch.zeros_like(x)
        x = kwargs["x"]

        transformer_options = kwargs.get("transformer_options", {})
        audio_embeds = transformer_options.get("audio_embeds")

        x_ref_attn_map = None
@@ -440,152 +437,35 @@ class MultiTalkApplyModelWrapper:
        return samples


class InfiniteTalkOuterSampleLoopingWrapper:
    def __init__(self, init_previous_frames, encoded_audio, model_patch, audio_scale, max_frames, frame_window_size, motion_frame_count=9, vae=None, ref_target_masks=None):
        self.init_previous_frames = init_previous_frames
        self.encoded_audio = encoded_audio
        self.total_audio_frames = encoded_audio[0].shape[0]
        self.max_frames = max_frames
        self.frame_window_size = frame_window_size
        self.latent_window_size = (frame_window_size - 1) // 4 + 1
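The `latent_window_size` arithmetic is consistent with the 4x temporal compression of the Wan VAE: the first pixel frame maps to one latent frame and each further group of four pixel frames to one more. A quick sanity check of the formula, using the node defaults quoted further down (illustrative only):

# Sanity check of (frames - 1) // 4 + 1; values are the node defaults below.
frame_window_size = 81
latent_window_size = (frame_window_size - 1) // 4 + 1    # 21 latent frames per window
assert latent_window_size == 21

motion_frame_count = 9
motion_latent_count = (motion_frame_count - 1) // 4 + 1  # 3 latent frames of motion context
assert motion_latent_count == 3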
class InfiniteTalkOuterSampleWrapper:
    def __init__(self, motion_frames_latent, model_patch, is_extend=False):
        self.motion_frames_latent = motion_frames_latent
        self.model_patch = model_patch
        self.audio_scale = audio_scale
        self.motion_frame_count = motion_frame_count
        self.vae = vae
        self.ref_target_masks = ref_target_masks

    def __call__(self, executor, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, **kwargs):
        # init variables
        previous_frames = motion_frames_latent = None
        init_from_cond = False
        frame_offset = audio_start = latent_frame_offset = latent_start_idx = 0
        audio_end = self.frame_window_size
        latent_end_idx = self.latent_window_size
        decoded_results = []
        self.is_extend = is_extend

    def __call__(self, executor, *args, **kwargs):
        model_patcher = executor.class_obj.model_patcher
        model_options = executor.class_obj.model_options
        process_latent_in = model_patcher.model.process_latent_in
        dtype = model_patcher.model_dtype()

        # when extending from previous frames
        if self.init_previous_frames is not None:
            decoded_results.append(self.init_previous_frames.unsqueeze(0))
            previous_frames = self.init_previous_frames  # should we grow the results here or rely on using batch image nodes in the workflow?
            if previous_frames.shape[0] < self.motion_frame_count:
                previous_frames = torch.cat([previous_frames[:1].repeat(self.motion_frame_count - previous_frames.shape[0], 1, 1, 1), previous_frames], dim=0)
            motion_frames = previous_frames[-self.motion_frame_count:]
            frame_offset = previous_frames.shape[0] - self.motion_frame_count
        # for InfiniteTalk, model input first latent(s) need to always be replaced on every step
        if self.motion_frames_latent is not None:
            wrappers = model_options["transformer_options"]["wrappers"]
            w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
            w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))]

        # add/replace current cross-attention patch to model options
        model_options["transformer_options"].setdefault("patches", {}).setdefault("cross_attn", []).append(
            MultiTalkCrossAttnPatch(self.model_patch, self.audio_scale, ref_target_masks=self.ref_target_masks)
        )
        # run the sampling process
        result = executor(*args, **kwargs)
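Both the removed looping wrapper and the new `InfiniteTalkOuterSampleWrapper` follow the same OUTER_SAMPLE wrapper shape: the callable receives an `executor` whose `class_obj` exposes the guider state used above (`model_patcher`, `model_options` and, further down, `conds`), and it must call `executor(...)` to run the wrapped sampling. A minimal pass-through sketch of that shape, under the assumption that the `add_wrapper_with_key` registration shown later in this diff is the only wiring needed; the class below is illustrative, not part of the commit:

# Minimal sketch only; PassThroughOuterSampleWrapper is illustrative.
class PassThroughOuterSampleWrapper:
    def __call__(self, executor, *args, **kwargs):
        guider = executor.class_obj            # object whose outer_sample is being wrapped
        _ = guider.model_options               # same dict the wrappers above mutate
        return executor(*args, **kwargs)       # defer to the original sampling path

# Registration would mirror the node code further down:
#   model_patched.add_wrapper_with_key(
#       comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
#       "infinite_talk_outer_sample",
#       PassThroughOuterSampleWrapper())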

        frames_needed = math.ceil(min(self.max_frames, self.total_audio_frames) / 81) * 81
        estimated_iterations = frames_needed // (self.frame_window_size - self.motion_frame_count)
        total_steps = (sigmas.shape[-1] - 1) * estimated_iterations
        logging.info(f"InfiniteTalk estimated loop iterations: {estimated_iterations}, Total steps: {total_steps}")
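As a worked example of the scheduling math above (illustrative inputs, not values from the commit): with the node defaults of an 81-frame window and 9 motion frames, 300 requested frames, and 20 sampler steps per window:

import math

max_frames, total_audio_frames = 300, 1000              # example inputs
frame_window_size, motion_frame_count = 81, 9            # node defaults quoted below
steps_per_window = 20                                     # sigmas.shape[-1] - 1 in this example

frames_needed = math.ceil(min(max_frames, total_audio_frames) / 81) * 81          # 4 * 81 = 324
estimated_iterations = frames_needed // (frame_window_size - motion_frame_count)  # 324 // 72 = 4
total_steps = steps_per_window * estimated_iterations                             # 80 progress-bar steps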

        # custom previewer callback for full loop progress bar
        pbar = comfy.utils.ProgressBar(total_steps)
        previewer = latent_preview.get_previewer(model_patcher.load_device, model_patcher.model.latent_format)

        def custom_callback(step, x0, x, total_steps):
            preview_bytes = None
            if previewer:
                preview_bytes = previewer.decode_latent_to_preview_image("JPEG", x0)
            pbar.update_absolute(pbar.current + 1, preview=preview_bytes)

        # outer loop start for multiple frame windows
        for i in range(estimated_iterations):

            # first frame to InfiniteTalk always has to be a noise-free encoded image
            # if no previous samples provided, try to get I2V cond latent from positive cond

            if previous_frames is None:
                concat_latent_image = executor.class_obj.conds["positive"][0].get("concat_latent_image", None)
                if concat_latent_image is not None:
                    motion_frames_latent = concat_latent_image[:, :, :1]
                    overlap = 1
                    init_from_cond = True
            # else, use previous samples' last frames as first frame
            else:
                audio_start = frame_offset
                audio_end = audio_start + self.frame_window_size
                latent_start_idx = latent_frame_offset
                latent_end_idx = latent_start_idx + self.latent_window_size

                if len(motion_frames.shape) == 5:
                    motion_frames = motion_frames.squeeze(0)
                spacial_compression = self.vae.spacial_compression_encode()
                if (motion_frames.shape[-3], motion_frames.shape[-2]) != (noise.shape[-2] * spacial_compression, noise.shape[-1] * spacial_compression):
                    motion_frames = comfy.utils.common_upscale(
                        motion_frames.movedim(-1, 1),
                        noise.shape[-1] * spacial_compression, noise.shape[-2] * spacial_compression,
                        "bilinear", "center")

                motion_frames_latent = self.vae.encode(motion_frames)
                overlap = motion_frames_latent.shape[2]

            audio_embed = project_audio_features(self.model_patch.model.audio_proj, self.encoded_audio, audio_start, audio_end).to(dtype)
            model_options["transformer_options"]["audio_embeds"] = audio_embed

            # model input first latents need to always be replaced on every step
            if motion_frames_latent is not None:
                wrappers = model_options["transformer_options"]["wrappers"]
                w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
                w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(motion_frames_latent))]

            # Slice possible encoded latent_image for vid2vid
            if latent_image is not None and torch.count_nonzero(latent_image) > 0:
                # Check if we have enough latents
                if latent_end_idx > latent_image.shape[2]:
                    # This window needs more frames - pad the latent_image at the end
                    pad_length = latent_end_idx - latent_image.shape[2]
                    last_frame = latent_image[:, :, -1:].repeat(1, 1, pad_length, 1, 1)
                    latent_image = torch.cat([latent_image, last_frame], dim=2)
                    new_noise_frames = torch.randn_like(latent_image[:, :, -pad_length:], device=noise.device, dtype=noise.dtype)
                    noise = torch.cat([noise, new_noise_frames], dim=2)
                noise = noise[:, :, latent_start_idx:latent_end_idx]
                latent_image = latent_image[:, :, latent_start_idx:latent_end_idx]
                #if denoise_mask is not None: # todo: check if denoise mask needs adjustment for latent_image changes

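To make the vid2vid padding above concrete (illustrative numbers, not from the commit): with the default 21-latent window and 18 new latents per iteration, the second window covers indices 18..39, so a 30-frame source latent is padded by repeating its last frame.

latent_window_size = 21                       # (81 - 1) // 4 + 1
new_latent_count = latent_window_size - 3     # minus the 3 motion latents -> 18

latent_start_idx = new_latent_count * 1       # second window (i = 1) -> 18
latent_end_idx = latent_start_idx + latent_window_size   # 39

available_latents = 30                        # example vid2vid latent length
pad_length = latent_end_idx - available_latents           # 9 copies of the last frame are appended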

            # run the sampling process
            result = executor(noise, latent_image, sampler, sigmas, denoise_mask=denoise_mask, callback=custom_callback, disable_pbar=False, seed=seed, **kwargs)

            # insert motion frames before decoding
            if previous_frames is not None and not init_from_cond:
                result = torch.cat([motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)

            previous_frames = self.vae.decode(result)
            motion_frames = previous_frames[:, -self.motion_frame_count:]

            # Track frame progress
            new_frame_count = previous_frames.shape[1] - self.motion_frame_count
            frame_offset += new_frame_count

            motion_latent_count = (self.motion_frame_count - 1) // 4 + 1 if self.motion_frame_count > 0 else 0
            new_latent_count = self.latent_window_size - motion_latent_count

            latent_frame_offset += new_latent_count

            if init_from_cond:
                decoded_results.append(previous_frames)
                init_from_cond = False
            else:
                decoded_results.append(previous_frames[:, self.motion_frame_count:])

        return torch.cat(decoded_results, dim=1)
        # insert motion frames before decoding
        if self.is_extend:
            overlap = self.motion_frames_latent.shape[2]
            result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)

        return result
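The `is_extend` branch above re-inserts the encoded motion frames over the start of the sampled latent before returning, so the window length is unchanged. Illustrative shapes (the spatial size is an assumed example):

import torch

result = torch.randn(1, 16, 21, 60, 104)                 # sampled latent window (example shape)
motion_frames_latent = torch.randn(1, 16, 3, 60, 104)     # 9 motion frames -> 3 latent frames

overlap = motion_frames_latent.shape[2]                   # 3
result = torch.cat([motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)
assert result.shape[2] == 21                               # first 3 latent frames replaced, length kept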

    def to(self, device_or_dtype):
        if isinstance(device_or_dtype, torch.device):
            if self.init_previous_frames is not None:
                self.init_previous_frames = self.init_previous_frames.to(device_or_dtype)
            if self.encoded_audio is not None:
                self.encoded_audio = [ea.to(device_or_dtype) for ea in self.encoded_audio]
            if self.ref_target_masks is not None:
                self.ref_target_masks = self.ref_target_masks.to(device_or_dtype)
            if self.motion_frames_latent is not None:
                self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype)
        return self
@@ -11,6 +11,7 @@ import numpy as np
from typing import Tuple
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import logging

class WanImageToVideo(io.ComfyNode):
    @classmethod
@@ -1288,7 +1289,7 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
        return io.NodeOutput(out_latent)


from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleLoopingWrapper
from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, project_audio_features
class WanInfiniteTalkToVideo(io.ComfyNode):
    @classmethod
    def define_schema(cls):
@@ -1310,7 +1311,6 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
                io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True),
                io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."),
                io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."),
                io.Int.Input("frame_window_size", default=81, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of frames to generate in one window."),
                io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."),
                io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
                io.Image.Input("previous_frames", optional=True),
@@ -1320,15 +1320,16 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
                io.Int.Output(display_name="trim_image"),
            ],
        )

    @classmethod
    def execute(cls, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count, frame_window_size,
    def execute(cls, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
                start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput:

        if frame_window_size > length:
            frame_window_size = length
        if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
            raise ValueError("Not enough previous frames provided.")
        if audio_encoder_output_2 is not None:
            if mask_1 is None or mask_2 is None:
                raise ValueError("Masks must be provided if two audio encoder outputs are used.")
@@ -1339,10 +1340,10 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
                raise ValueError("Second audio encoder output must be provided if two masks are used.")
            ref_masks = torch.cat([mask_1, mask_2])

        latent = torch.zeros([1, 16, ((frame_window_size - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        if start_image is not None:
            start_image = comfy.utils.common_upscale(start_image[:frame_window_size].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            image = torch.ones((frame_window_size, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
            start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
            image[:start_image.shape[0]] = start_image

            concat_latent_image = vae.encode(image[:, :, :, :3])
@@ -1399,30 +1400,40 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
                ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0]
            token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1)


        init_previous_frames = None
        # when extending from previous frames
        if previous_frames is not None:
            init_previous_frames = previous_frames[:, :, :, :3]
            motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
            frame_offset = previous_frames.shape[0] - motion_frame_count

            audio_start = frame_offset
            audio_end = audio_start + length
            logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")

            motion_frames_latent = vae.encode(motion_frames[:, :, :, :3])
            trim_image = motion_frame_count
        else:
            audio_start = trim_image = 0
            audio_end = length
            motion_frames_latent = concat_latent_image[:, :, :1]

        audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype())
        model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed
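A quick check of the audio offsets computed above when extending (illustrative counts, not from the commit): with 81 previously generated frames and the default 9 motion frames, the next segment starts 72 audio frames into the encoded audio.

previous_frame_count = 81                     # example: frames already generated
motion_frame_count = 9                        # node default
length = 81                                   # frames requested for this segment

frame_offset = previous_frame_count - motion_frame_count   # 72
audio_start = frame_offset                                  # 72
audio_end = audio_start + length                            # 153: audio frames 72..152 drive this segment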

        # add outer sample wrapper
        model_patched.add_wrapper_with_key(
            comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
            "infinite_talk_outer_sample",
            InfiniteTalkOuterSampleLoopingWrapper(
                init_previous_frames,
                encoded_audio_list,
            InfiniteTalkOuterSampleWrapper(
                motion_frames_latent,
                model_patch,
                audio_scale,
                length,
                frame_window_size,
                motion_frame_count,
                vae=vae,
                ref_target_masks=token_ref_target_masks)
        )
                is_extend=previous_frames is not None,
            ))
        # add cross-attention patch
        model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale, ref_target_masks=token_ref_target_masks), "cross_attn")

        out_latent = {}
        out_latent["samples"] = latent
        return io.NodeOutput(model_patched, positive, negative, out_latent)
        return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
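The new `trim_image` output appears to report how many leading frames of the decoded result are repeated motion-context frames when extending (it is `motion_frame_count` in that case, otherwise 0). A hedged sketch of how a downstream helper might consume it; the function name is illustrative and not part of this node set:

# Hypothetical downstream helper; not part of this commit.
def trim_decoded_frames(images, trim_image):
    # images: decoded frames shaped [frames, height, width, channels]
    # trim_image: number of leading motion-context frames to drop (0 when not extending)
    return images[trim_image:]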


class WanExtension(ComfyExtension):