Add inputs for character replacement to the WanAnimateToVideo node. (#9960)

Author: comfyanonymous
Date:   2025-09-19 23:24:10 -07:00 (committed by GitHub)
Commit: 66241cef31
Parent: e8df53b764


@@ -1127,6 +1127,8 @@ class WanAnimateToVideo(io.ComfyNode):
                 io.Image.Input("face_video", optional=True),
                 io.Image.Input("pose_video", optional=True),
                 io.Int.Input("continue_motion_max_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.Image.Input("background_video", optional=True),
+                io.Mask.Input("character_mask", optional=True),
                 io.Image.Input("continue_motion", optional=True),
                 io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="The amount of frames to seek in all the input videos. Used for generating longer videos by chunk. Connect to the video_frame_offset output of the previous node for extending a video."),
             ],
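
Note on the new inputs: they follow the usual ComfyUI tensor conventions, where an IMAGE is (frames, height, width, channels) with values in [0, 1] and a MASK is (frames, height, width). A minimal sketch of well-formed inputs, with illustrative dimensions (1.0 in the mask marks the region to regenerate, matching how mask_refmotion below defaults to ones for frames that get generated):

    import torch

    frames, height, width = 77, 480, 832

    # background_video: IMAGE batch, one RGB frame per video frame, in [0, 1].
    background_video = torch.rand(frames, height, width, 3)

    # character_mask: MASK batch. A single-frame mask also works; execute()
    # repeats it across the whole video (see the repeat() call below).
    character_mask = torch.zeros(frames, height, width)
    character_mask[:, 100:380, 220:600] = 1.0  # illustrative character region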
@@ -1142,7 +1144,7 @@ class WanAnimateToVideo(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, video_frame_offset, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None) -> io.NodeOutput:
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, continue_motion_max_frames, video_frame_offset, reference_image=None, clip_vision_output=None, face_video=None, pose_video=None, continue_motion=None, background_video=None, character_mask=None) -> io.NodeOutput:
         trim_to_pose_video = False
         latent_length = ((length - 1) // 4) + 1
         latent_width = width // 8
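
For reference, the unchanged lines above carry the pixel-to-latent bookkeeping the rest of the diff relies on: the Wan VAE compresses 8x spatially, and temporally the first latent frame covers 1 pixel frame while every further latent frame covers 4. A quick check of that arithmetic:

    # Pixel frame count -> latent frame count, as computed in execute().
    def to_latent_frames(length: int) -> int:
        return ((length - 1) // 4) + 1

    assert to_latent_frames(1) == 1    # 1 pixel frame -> 1 latent frame
    assert to_latent_frames(5) == 2    # 1 + 4 frames  -> 2 latent frames
    assert to_latent_frames(77) == 20  # 1 + 19 * 4    -> 20 latent frames

    width, height = 832, 480
    assert (width // 8, height // 8) == (104, 60)  # 8x spatial compression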
@@ -1154,7 +1156,7 @@ class WanAnimateToVideo(io.ComfyNode):
             image = comfy.utils.common_upscale(reference_image[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
             concat_latent_image = vae.encode(image[:, :, :, :3])
-            mask = torch.zeros((1, 1, concat_latent_image.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype)
+            mask = torch.zeros((1, 4, concat_latent_image.shape[-3], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=concat_latent_image.device, dtype=concat_latent_image.dtype)
             trim_latent += concat_latent_image.shape[2]
 
         ref_motion_latent_length = 0
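
This shape fix sets up the new code below: the concat mask is laid out as (batch, 4, latent_frames, h, w), one channel per pixel frame packed into each latent frame, and shape[-3] (rather than shape[2]) indexes the temporal axis of the 5-D (b, c, t, h, w) latent. The mask_refmotion.view(...).transpose(1, 2) in the next hunk builds exactly this layout from a per-pixel-frame mask; a small sketch with illustrative sizes:

    import torch

    latent_length, h, w = 20, 60, 104  # e.g. 77 frames at 832x480

    # One mask value per pixel frame: 4 pixel frames per latent frame.
    per_pixel_frame = torch.ones(1, 1, latent_length * 4, h, w)

    # Fold each latent frame's 4 sub-frame values into the channel axis,
    # mirroring mask_refmotion.view(1, t4 // 4, 4, h, w).transpose(1, 2).
    folded = per_pixel_frame.view(1, latent_length, 4, h, w).transpose(1, 2)
    assert folded.shape == (1, 4, latent_length, h, w)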
@@ -1206,11 +1208,37 @@ class WanAnimateToVideo(io.ComfyNode):
             positive = node_helpers.conditioning_set_values(positive, {"face_video_pixels": face_video})
             negative = node_helpers.conditioning_set_values(negative, {"face_video_pixels": face_video * 0.0 - 1.0})
 
-        concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2)
-        mask_refmotion = torch.ones((1, 1, latent_length, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype)
-        if continue_motion is not None:
-            mask_refmotion[:, :, :ref_motion_latent_length] = 0.0
+        ref_images_num = max(0, ref_motion_latent_length * 4 - 3)
+        if background_video is not None:
+            if background_video.shape[0] > video_frame_offset:
+                background_video = background_video[video_frame_offset:]
+                background_video = comfy.utils.common_upscale(background_video[:length].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
+                if background_video.shape[0] > ref_images_num:
+                    image[ref_images_num:background_video.shape[0] - ref_images_num] = background_video[ref_images_num:]
+
+        mask_refmotion = torch.ones((1, 1, latent_length * 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=mask.device, dtype=mask.dtype)
+        if continue_motion is not None:
+            mask_refmotion[:, :, :ref_motion_latent_length * 4] = 0.0
+
+        if character_mask is not None:
+            if character_mask.shape[0] > video_frame_offset or character_mask.shape[0] == 1:
+                if character_mask.shape[0] == 1:
+                    character_mask = character_mask.repeat((length,) + (1,) * (character_mask.ndim - 1))
+                else:
+                    character_mask = character_mask[video_frame_offset:]
+                if character_mask.ndim == 3:
+                    character_mask = character_mask.unsqueeze(1)
+                character_mask = character_mask.movedim(0, 1)
+                if character_mask.ndim == 4:
+                    character_mask = character_mask.unsqueeze(1)
+                character_mask = comfy.utils.common_upscale(character_mask[:, :, :length], concat_latent_image.shape[-1], concat_latent_image.shape[-2], "nearest-exact", "center")
+                if character_mask.shape[2] > ref_images_num:
+                    mask_refmotion[:, :, ref_images_num:character_mask.shape[2] + ref_images_num] = character_mask[:, :, ref_images_num:]
+
+        concat_latent_image = torch.cat((concat_latent_image, vae.encode(image[:, :, :, :3])), dim=2)
+        mask_refmotion = mask_refmotion.view(1, mask_refmotion.shape[2] // 4, 4, mask_refmotion.shape[3], mask_refmotion.shape[4]).transpose(1, 2)
         mask = torch.cat((mask, mask_refmotion), dim=2)
         positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
         negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
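
Two details in the new block are easy to misread. First, ref_images_num inverts the latent-frame formula: max(0, n * 4 - 3) is the number of pixel frames already covered by n continue-motion latent frames, which is why the background and mask writes start at that offset. Second, the ndim checks normalize a MASK of shape (frames, h, w) to the (1, 1, frames, h, w) layout that mask_refmotion uses. A sketch of both (helper name is ours):

    import torch

    def covered_pixel_frames(latent_frames: int) -> int:
        # Inverse of ((length - 1) // 4) + 1: the first latent frame holds
        # 1 pixel frame, each additional one holds 4.
        return max(0, latent_frames * 4 - 3)

    assert covered_pixel_frames(0) == 0  # no continue_motion connected
    assert covered_pixel_frames(2) == 5  # the continue_motion_max_frames default

    # MASK (frames, h, w) -> (1, 1, frames, h, w), tracing the ndim juggling:
    m = torch.zeros(77, 480, 832)  # character_mask as it arrives
    m = m.unsqueeze(1)             # ndim == 3 branch: (77, 1, 480, 832)
    m = m.movedim(0, 1)            # (1, 77, 480, 832)
    m = m.unsqueeze(1)             # ndim == 4 branch: (1, 1, 77, 480, 832)
    assert m.shape == (1, 1, 77, 480, 832)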