diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index e44048447..119799592 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -1342,6 +1342,8 @@ class NaDiT(nn.Module):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
         conditions = kwargs.get("condition")
+        x = x.movedim(1, -1)
+        conditions = conditions.movedim(1, -1)
 
         try:
             neg_cond, pos_cond = context.chunk(2, dim=0)
@@ -1418,6 +1420,9 @@ class NaDiT(nn.Module):
         out = torch.stack(vid)
         try:
             pos, neg = out.chunk(2)
-            return torch.cat([neg, pos])
-        except:
+            out = torch.cat([neg, pos])
+            out = out.movedim(-1, 1)
+            return out
+        except:
+            out = out.movedim(-1, 1)
             return out
diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
index ef07b24e0..277f7a697 100644
--- a/comfy/ldm/seedvr/vae.py
+++ b/comfy/ldm/seedvr/vae.py
@@ -6,6 +6,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from torch import Tensor
 
+import comfy.model_management
 from comfy.ldm.seedvr.model import safe_pad_operation
 from comfy.ldm.modules.attention import optimized_attention
 
@@ -1552,15 +1553,19 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         return z, p
 
     def decode(self, z: torch.FloatTensor):
+        z = z.movedim(1, -1)
         latent = z.unsqueeze(0)
         scale = 0.9152
         shift = 0
         latent = latent / scale + shift
         latent = rearrange(latent, "b ... c -> b c ...")
         latent = latent.squeeze(2)
+
+        if latent.ndim == 4:
+            latent = latent.unsqueeze(2)
 
-        if z.ndim == 4:
-            z = z.unsqueeze(2)
+        target_device = comfy.model_management.get_torch_device()
+        self.to(target_device)
 
         x = super().decode(latent).squeeze(2)
         input = rearrange(self.original_image_video[0], "c t h w -> t c h w")
diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py
index 08009b4d9..e4022d209 100644
--- a/comfy_extras/nodes_seedvr.py
+++ b/comfy_extras/nodes_seedvr.py
@@ -4,6 +4,7 @@ import torch
 import math
 from einops import rearrange
 
+import comfy.model_management
 import torch.nn.functional as F
 from torchvision.transforms import functional as TVF
 from torchvision.transforms import Lambda, Normalize
@@ -116,11 +117,12 @@ class SeedVR2InputProcessing(io.ComfyNode):
 
     @classmethod
     def execute(cls, images, vae, resolution_height, resolution_width):
+        comfy.model_management.load_models_gpu([vae.patcher], force_full_load=True)
         device = vae.patcher.load_device
-        offload_device = vae.patcher.offload_device
-        vae = vae.first_stage_model
-        scale = 0.9152; shift = 0
+        offload_device = comfy.model_management.intermediate_device()
+        vae_model = vae.first_stage_model
+        scale = 0.9152; shift = 0
         if images.dim() != 5: # add the t dim
             images = images.unsqueeze(0)
         images = images.permute(0, 1, 4, 2, 3)
@@ -142,14 +144,14 @@ class SeedVR2InputProcessing(io.ComfyNode):
             images = cut_videos(images)
 
         images = rearrange(images, "b t c h w -> b c t h w")
-        vae = vae.to(device)
         images = images.to(device)
-        latent = vae.encode(images)[0]
-        vae = vae.to(offload_device)
+        latent = vae_model.encode(images)[0]
+        latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
         latent = rearrange(latent, "b c ... -> b ... c")
         latent = (latent - shift) * scale
+        latent = latent.to(offload_device)
 
         return io.NodeOutput({"samples": latent})
@@ -189,6 +191,8 @@ class SeedVR2Conditioning(io.ComfyNode):
         t = timestep_transform(t, shape)
         cond = inter(vae_conditioning, aug_noises, t)
         condition = torch.stack([get_conditions(noise, c) for noise, c in zip(noises, cond)])
+        condition = condition.movedim(-1, 1)
+        noises = noises.movedim(-1, 1)
 
         pos_shape = pos_cond.shape[0]
         neg_shape = neg_cond.shape[0]