mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-07 04:30:51 +08:00
fix vae issue
This commit is contained in:
parent
acb9a11c6f
commit
db74a27870
@ -1342,6 +1342,8 @@ class NaDiT(nn.Module):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
conditions = kwargs.get("condition")
|
||||
x = x.movedim(1, -1)
|
||||
conditions = conditions.movedim(1, -1)
|
||||
|
||||
try:
|
||||
neg_cond, pos_cond = context.chunk(2, dim=0)
|
||||
@ -1418,6 +1420,9 @@ class NaDiT(nn.Module):
|
||||
out = torch.stack(vid)
|
||||
try:
|
||||
pos, neg = out.chunk(2)
|
||||
return torch.cat([neg, pos])
|
||||
except:
|
||||
ut = torch.cat([neg, pos])
|
||||
out = out.movedim(-1, 1)
|
||||
return out
|
||||
except:
|
||||
out = out.movedim(-1, 1)
|
||||
return out
|
||||
|
||||
@ -6,6 +6,7 @@ import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.seedvr.model import safe_pad_operation
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
@ -1552,15 +1553,19 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
return z, p
|
||||
|
||||
def decode(self, z: torch.FloatTensor):
|
||||
z = z.movedim(1, -1)
|
||||
latent = z.unsqueeze(0)
|
||||
scale = 0.9152
|
||||
shift = 0
|
||||
latent = latent / scale + shift
|
||||
latent = rearrange(latent, "b ... c -> b c ...")
|
||||
latent = latent.squeeze(2)
|
||||
|
||||
if latent.ndim == 4:
|
||||
latent = latent.unsqueeze(2)
|
||||
|
||||
if z.ndim == 4:
|
||||
z = z.unsqueeze(2)
|
||||
target_device = comfy.model_management.get_torch_device()
|
||||
self.to(target_device)
|
||||
x = super().decode(latent).squeeze(2)
|
||||
|
||||
input = rearrange(self.original_image_video[0], "c t h w -> t c h w")
|
||||
|
||||
@ -4,6 +4,7 @@ import torch
|
||||
import math
|
||||
from einops import rearrange
|
||||
|
||||
import comfy.model_management
|
||||
import torch.nn.functional as F
|
||||
from torchvision.transforms import functional as TVF
|
||||
from torchvision.transforms import Lambda, Normalize
|
||||
@ -116,11 +117,12 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, vae, resolution_height, resolution_width):
|
||||
comfy.model_management.load_models_gpu([vae.patcher], force_full_load=True)
|
||||
device = vae.patcher.load_device
|
||||
offload_device = vae.patcher.offload_device
|
||||
vae = vae.first_stage_model
|
||||
scale = 0.9152; shift = 0
|
||||
|
||||
offload_device = comfy.model_management.intermediate_device()
|
||||
vae_model = vae.first_stage_model
|
||||
scale = 0.9152; shift = 0
|
||||
if images.dim() != 5: # add the t dim
|
||||
images = images.unsqueeze(0)
|
||||
images = images.permute(0, 1, 4, 2, 3)
|
||||
@ -142,14 +144,14 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
images = cut_videos(images)
|
||||
|
||||
images = rearrange(images, "b t c h w -> b c t h w")
|
||||
vae = vae.to(device)
|
||||
images = images.to(device)
|
||||
latent = vae.encode(images)[0]
|
||||
vae = vae.to(offload_device)
|
||||
latent = vae_model.encode(images)[0]
|
||||
|
||||
latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
|
||||
latent = rearrange(latent, "b c ... -> b ... c")
|
||||
|
||||
latent = (latent - shift) * scale
|
||||
latent = latent.to(offload_device)
|
||||
|
||||
return io.NodeOutput({"samples": latent})
|
||||
|
||||
@ -189,6 +191,8 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
t = timestep_transform(t, shape)
|
||||
cond = inter(vae_conditioning, aug_noises, t)
|
||||
condition = torch.stack([get_conditions(noise, c) for noise, c in zip(noises, cond)])
|
||||
condition = condition.movedim(-1, 1)
|
||||
noises = noises.movedim(-1, 1)
|
||||
|
||||
pos_shape = pos_cond.shape[0]
|
||||
neg_shape = neg_cond.shape[0]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user