mirror of https://github.com/comfyanonymous/ComfyUI.git
fix vae issue
parent acb9a11c6f
commit db74a27870
@@ -1342,6 +1342,8 @@ class NaDiT(nn.Module):
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
         conditions = kwargs.get("condition")
+        x = x.movedim(1, -1)
+        conditions = conditions.movedim(1, -1)
 
         try:
             neg_cond, pos_cond = context.chunk(2, dim=0)
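
Note: the two added movedim calls switch the NaDiT inputs from ComfyUI's channels-first layout to the channels-last layout the rest of this forward pass works in. A minimal sketch of the operation (shapes are illustrative, not taken from the model):

    import torch

    x = torch.randn(1, 16, 5, 32, 32)  # (b, c, t, h, w), channels-first
    x_last = x.movedim(1, -1)          # move the channel axis to the end
    print(x_last.shape)                # torch.Size([1, 5, 32, 32, 16])
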
@@ -1418,6 +1420,9 @@ class NaDiT(nn.Module):
         out = torch.stack(vid)
         try:
             pos, neg = out.chunk(2)
-            return torch.cat([neg, pos])
+            out = torch.cat([neg, pos])
+            out = out.movedim(-1, 1)
+            return out
         except:
+            out = out.movedim(-1, 1)
             return out
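
Note: the exit path is the mirror image of the entry hunk above. Instead of returning the stacked channels-last result directly, both the try branch and the bare-except fallback now apply movedim(-1, 1), so callers always receive channels-first tensors. The conversion is lossless, as this small check illustrates (shapes again illustrative):

    import torch

    x = torch.randn(1, 16, 5, 32, 32)                       # (b, c, t, h, w)
    assert torch.equal(x.movedim(1, -1).movedim(-1, 1), x)  # exact round trip
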
@@ -6,6 +6,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from torch import Tensor
 
+import comfy.model_management
 from comfy.ldm.seedvr.model import safe_pad_operation
 from comfy.ldm.modules.attention import optimized_attention
 
@@ -1552,15 +1553,19 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         return z, p
 
     def decode(self, z: torch.FloatTensor):
+        z = z.movedim(1, -1)
         latent = z.unsqueeze(0)
         scale = 0.9152
         shift = 0
         latent = latent / scale + shift
         latent = rearrange(latent, "b ... c -> b c ...")
         latent = latent.squeeze(2)
 
-        if z.ndim == 4:
-            z = z.unsqueeze(2)
+        if latent.ndim == 4:
+            latent = latent.unsqueeze(2)
+
+        target_device = comfy.model_management.get_torch_device()
+        self.to(target_device)
         x = super().decode(latent).squeeze(2)
 
         input = rearrange(self.original_image_video[0], "c t h w -> t c h w")
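
Note: three fixes land in decode here. The incoming z is moved to channels-last before any reshaping; the rank guard now inspects latent, the tensor actually passed to super().decode(), instead of z, whose rank stopped tracking it after the rearrange/squeeze above; and the wrapper moves itself onto the main torch device before decoding. A sketch of the rank guard in isolation (ensure_temporal_dim is a hypothetical helper, not in the codebase):

    import torch

    def ensure_temporal_dim(latent: torch.Tensor) -> torch.Tensor:
        # A single image arrives as a 4D latent (b, c, h, w); the video
        # decoder expects (b, c, t, h, w), so insert a length-1 time axis.
        if latent.ndim == 4:
            latent = latent.unsqueeze(2)
        return latent

    print(ensure_temporal_dim(torch.randn(1, 16, 32, 32)).shape)     # (1, 16, 1, 32, 32)
    print(ensure_temporal_dim(torch.randn(1, 16, 5, 32, 32)).shape)  # unchanged
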
@@ -4,6 +4,7 @@ import torch
 import math
 from einops import rearrange
 
+import comfy.model_management
 import torch.nn.functional as F
 from torchvision.transforms import functional as TVF
 from torchvision.transforms import Lambda, Normalize
@@ -116,11 +117,12 @@ class SeedVR2InputProcessing(io.ComfyNode):
 
     @classmethod
     def execute(cls, images, vae, resolution_height, resolution_width):
+        comfy.model_management.load_models_gpu([vae.patcher], force_full_load=True)
         device = vae.patcher.load_device
-        offload_device = vae.patcher.offload_device
-        vae = vae.first_stage_model
-        scale = 0.9152; shift = 0
 
+        offload_device = comfy.model_management.intermediate_device()
+        vae_model = vae.first_stage_model
+        scale = 0.9152; shift = 0
         if images.dim() != 5: # add the t dim
             images = images.unsqueeze(0)
         images = images.permute(0, 1, 4, 2, 3)
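
Note: instead of hand-placing the raw module with .to(device), the node now asks ComfyUI's model management to load the VAE's patcher, which lets the manager account for VRAM and evict other models first, and it takes the offload target from intermediate_device() rather than the patcher. A hypothetical condensation of the new flow (encode_managed is not a real helper; load_models_gpu, intermediate_device, and vae.patcher are the APIs used in the diff itself):

    import torch
    import comfy.model_management

    def encode_managed(vae, images: torch.Tensor) -> torch.Tensor:
        # Let model management place the VAE instead of calling .to() by hand.
        comfy.model_management.load_models_gpu([vae.patcher], force_full_load=True)
        device = vae.patcher.load_device
        latent = vae.first_stage_model.encode(images.to(device))[0]
        # Park the result on the intermediate device, freeing VRAM for later models.
        return latent.to(comfy.model_management.intermediate_device())
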
@@ -142,14 +144,14 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = cut_videos(images)
 
         images = rearrange(images, "b t c h w -> b c t h w")
-        vae = vae.to(device)
         images = images.to(device)
-        latent = vae.encode(images)[0]
-        vae = vae.to(offload_device)
+        latent = vae_model.encode(images)[0]
         latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
         latent = rearrange(latent, "b c ... -> b ... c")
 
         latent = (latent - shift) * scale
+        latent = latent.to(offload_device)
+
 
         return io.NodeOutput({"samples": latent})
 
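
Note: the rename to vae_model keeps the VAE wrapper reachable: the old code rebound the argument vae to vae.first_stage_model, making the wrapper (and its .patcher, which the new load_models_gpu call relies on) inaccessible under that name for the rest of the method. The manual .to(device)/.to(offload_device) round trips also disappear, since model management now owns placement. Roughly (illustrative skeletons, not the node code):

    def execute_old(vae):
        vae = vae.first_stage_model    # wrapper (and vae.patcher) is lost here

    def execute_new(vae):
        vae_model = vae.first_stage_model  # wrapper stays reachable via vae
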
@@ -189,6 +191,8 @@ class SeedVR2Conditioning(io.ComfyNode):
         t = timestep_transform(t, shape)
         cond = inter(vae_conditioning, aug_noises, t)
         condition = torch.stack([get_conditions(noise, c) for noise, c in zip(noises, cond)])
+        condition = condition.movedim(-1, 1)
+        noises = noises.movedim(-1, 1)
 
         pos_shape = pos_cond.shape[0]
         neg_shape = neg_cond.shape[0]
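
Note: this mirrors the first NaDiT hunk from the node side: the conditioning node now emits condition and noises channels-first, and NaDiT's forward immediately converts to channels-last with movedim(1, -1), so the layout contract between node and model stays consistent. Sketch (shapes illustrative):

    import torch

    condition = torch.randn(2, 5, 32, 32, 16)  # stacked channels-last: (b, t, h, w, c)
    condition = condition.movedim(-1, 1)       # hand the model (b, c, t, h, w)
    assert condition.shape == (2, 16, 5, 32, 32)
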