fix vae issue

Yousef Rafat 2025-12-18 14:13:41 +02:00
parent acb9a11c6f
commit db74a27870
3 changed files with 24 additions and 10 deletions

@@ -1342,6 +1342,8 @@ class NaDiT(nn.Module):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
         conditions = kwargs.get("condition")
+        x = x.movedim(1, -1)
+        conditions = conditions.movedim(1, -1)
         try:
             neg_cond, pos_cond = context.chunk(2, dim=0)
@@ -1418,6 +1420,9 @@ class NaDiT(nn.Module):
         out = torch.stack(vid)
         try:
             pos, neg = out.chunk(2)
-            return torch.cat([neg, pos])
-        except:
+            out = torch.cat([neg, pos])
+            out = out.movedim(-1, 1)
+            return out
+        except:
+            out = out.movedim(-1, 1)
             return out

@@ -6,6 +6,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from torch import Tensor
+import comfy.model_management
 from comfy.ldm.seedvr.model import safe_pad_operation
 from comfy.ldm.modules.attention import optimized_attention
@@ -1552,15 +1553,19 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         return z, p

     def decode(self, z: torch.FloatTensor):
+        z = z.movedim(1, -1)
         latent = z.unsqueeze(0)
         scale = 0.9152
         shift = 0
         latent = latent / scale + shift
         latent = rearrange(latent, "b ... c -> b c ...")
         latent = latent.squeeze(2)
-        if latent.ndim == 4:
-            latent = latent.unsqueeze(2)
+        if z.ndim == 4:
+            z = z.unsqueeze(2)
+        target_device = comfy.model_management.get_torch_device()
+        self.to(target_device)
         x = super().decode(latent).squeeze(2)
         input = rearrange(self.original_image_video[0], "c t h w -> t c h w")

@@ -4,6 +4,7 @@ import torch
 import math
 from einops import rearrange
+import comfy.model_management
 import torch.nn.functional as F
 from torchvision.transforms import functional as TVF
 from torchvision.transforms import Lambda, Normalize
@@ -116,11 +117,12 @@ class SeedVR2InputProcessing(io.ComfyNode):
     @classmethod
     def execute(cls, images, vae, resolution_height, resolution_width):
+        comfy.model_management.load_models_gpu([vae.patcher], force_full_load=True)
         device = vae.patcher.load_device
-        offload_device = vae.patcher.offload_device
-        vae = vae.first_stage_model
-        scale = 0.9152; shift = 0
+        offload_device = comfy.model_management.intermediate_device()
+        vae_model = vae.first_stage_model
+        scale = 0.9152; shift = 0
         if images.dim() != 5: # add the t dim
             images = images.unsqueeze(0)
         images = images.permute(0, 1, 4, 2, 3)
@@ -142,14 +144,14 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = cut_videos(images)
         images = rearrange(images, "b t c h w -> b c t h w")
-        vae = vae.to(device)
         images = images.to(device)
-        latent = vae.encode(images)[0]
-        vae = vae.to(offload_device)
+        latent = vae_model.encode(images)[0]
         latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
         latent = rearrange(latent, "b c ... -> b ... c")
         latent = (latent - shift) * scale
+        latent = latent.to(offload_device)
         return io.NodeOutput({"samples": latent})
@@ -189,6 +191,8 @@ class SeedVR2Conditioning(io.ComfyNode):
         t = timestep_transform(t, shape)
         cond = inter(vae_conditioning, aug_noises, t)
         condition = torch.stack([get_conditions(noise, c) for noise, c in zip(noises, cond)])
+        condition = condition.movedim(-1, 1)
+        noises = noises.movedim(-1, 1)
         pos_shape = pos_cond.shape[0]
         neg_shape = neg_cond.shape[0]