diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 1843d8c83..d106cba3d 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -154,35 +154,35 @@ class HunyuanVideo15FirstLastFrameToVideo(io.ComfyNode): @classmethod def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput: - latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], + latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) - - concat_latent_image = torch.zeros((batch_size, 32, latent.shape[2], latent.shape[3], latent.shape[4]), + + concat_latent_image = torch.zeros((batch_size, 32, latent.shape[2], latent.shape[3], latent.shape[4]), device=comfy.model_management.intermediate_device()) - - mask = torch.ones((1, 1, latent.shape[2], latent.shape[3], latent.shape[4]), + + mask = torch.ones((1, 1, latent.shape[2], latent.shape[3], latent.shape[4]), device=comfy.model_management.intermediate_device()) - + if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - + encoded_start = vae.encode(start_image[:, :, :, :3]) - + concat_latent_image[:, :, :encoded_start.shape[2], :, :] = encoded_start - + start_frames_in_latent = ((start_image.shape[0] - 1) // 4) + 1 mask[:, :, :start_frames_in_latent] = 0.0 - + if end_image is not None: end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - + encoded_end = vae.encode(end_image[:, :, :, :3]) - + end_frames_in_latent = ((end_image.shape[0] - 1) // 4) + 1 concat_latent_image[:, :, -end_frames_in_latent:, :, :] = encoded_end[:, :, -end_frames_in_latent:, :, :] - + mask[:, :, -end_frames_in_latent:] = 0.0 - + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})