vae fix
parent 58e7cea796
commit ebd945ce3d
@@ -4,6 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor

from comfy.ldm.seedvr.model import safe_pad_operation
from comfy.ldm.modules.attention import optimized_attention
@@ -398,6 +399,11 @@ class InflatedCausalConv3d(nn.Conv3d):
            error_msgs,
        )

def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
    if times == 0:
        return tensor
    return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2)
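For intuition, a quick sketch of what remove_head drops, assuming the function above is in scope (toy shapes, not from the diff):

    import torch

    t = torch.arange(5).reshape(1, 1, 5, 1, 1)  # [B, C, T, H, W] with frames 0..4
    out = remove_head(t, times=1)               # keep frame 0, drop the next `times` frames
    print(out.flatten().tolist())               # [0, 2, 3, 4]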

class Upsample3D(nn.Module):

    def __init__(
@@ -509,6 +515,9 @@ class Upsample3D(nn.Module):
            z=self.temporal_ratio,
        )

        if self.temporal_up:
            hidden_states[0] = remove_head(hidden_states[0])

        if not self.slicing:
            hidden_states = hidden_states[0]

@@ -1296,11 +1305,55 @@ class Decoder3D(nn.Module):

        return sample


def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.
    """
    # input shape: (1, 3, H, W)
    # convolution kernel
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # add channel dimensions to the kernel to make it a 4D tensor
    kernel = kernel[None, None]
    # repeat the kernel across all input channels
    kernel = kernel.repeat(3, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # apply convolution
    output = F.conv2d(image, kernel, groups=3, dilation=radius)
    return output
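As a sanity check on wavelet_blur: the 3x3 kernel sums to 1, and dilation=radius widens its footprint to 2*radius+1, which the replicate padding exactly offsets, so both shape and mean are preserved. A minimal sketch, assuming wavelet_blur is in scope:

    import torch

    img = torch.full((1, 3, 8, 8), 0.5)  # constant image
    out = wavelet_blur(img, radius=2)    # effective kernel size 2*2+1 = 5
    assert out.shape == img.shape        # padding offsets the dilated kernel
    assert torch.allclose(out, img)      # kernel sums to 1, so the mean is kept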

def wavelet_decomposition(image: Tensor, levels=5):
    """
    Apply wavelet decomposition to the input tensor.
    This function returns the accumulated high frequency and the final low frequency.
    """
    high_freq = torch.zeros_like(image)
    for i in range(levels):
        radius = 2 ** i
        low_freq = wavelet_blur(image, radius)
        high_freq += (image - low_freq)
        image = low_freq

    return high_freq, low_freq
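Because each level adds image - low_freq and then recurses on low_freq, the sum telescopes, so high_freq + low_freq recovers the input exactly (up to float error). A quick check, assuming wavelet_decomposition is in scope:

    import torch

    x = torch.rand(1, 3, 16, 16)
    high, low = wavelet_decomposition(x, levels=3)
    # (x - l1) + (l1 - l2) + (l2 - l3) + l3 == x
    assert torch.allclose(high + low, x, atol=1e-5)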

def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
    """
    Apply wavelet decomposition so that the content will have the same color as the style.
    """
    # calculate the wavelet decomposition of the content feature
    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
    del content_low_freq
    # calculate the wavelet decomposition of the style feature
    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
    del style_high_freq
    # reconstruct the content feature with the style's low frequency
    return content_high_freq + style_low_freq
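This is the colour fix used in postprocess below: the decoded frames supply the detail (high frequency) while the original frames supply the colour (low frequency). A usage sketch with illustrative shapes, assuming wavelet_reconstruction is in scope:

    import torch

    decoded = torch.rand(4, 3, 64, 64)   # decoded frames with drifted colour
    original = torch.rand(4, 3, 64, 64)  # matching source frames
    fixed = wavelet_reconstruction(decoded, original)
    assert fixed.shape == decoded.shape  # decoded detail, original colour balance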

class VideoAutoencoderKL(nn.Module):
    """
    We simply inherit the model code from diffusers
    """

    def __init__(
        self,
        in_channels: int = 3,
@@ -1478,6 +1531,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
        self.spatial_downsample_factor = spatial_downsample_factor
        self.temporal_downsample_factor = temporal_downsample_factor
        self.freeze_encoder = freeze_encoder
        self.original_image_video = None
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.FloatTensor):
@@ -1487,6 +1541,8 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
        return x, z, p

    def encode(self, x: torch.FloatTensor):
        # keep a reference to the image/video so we can apply the colour fix later
        self.original_image_video = x
        if x.ndim == 4:
            x = x.unsqueeze(2)
        x = x.to(next(self.parameters()).dtype)
@@ -1502,18 +1558,13 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
        latent = latent / scale + shift
        latent = rearrange(latent, "b ... c -> b c ...")
        latent = latent.squeeze(2)

        if z.ndim == 4:
            z = z.unsqueeze(2)
        x = super().decode(latent).squeeze(2)
        return x

    def preprocess(self, x: torch.Tensor):
        # x should be in [B, C, T, H, W] or [B, C, H, W]
        assert x.ndim == 4 or x.size(2) % 4 == 1
        return x

    def postprocess(self, x: torch.Tensor):
        # x should be in [B, C, T, H, W] or [B, C, H, W]
        # reshape the stored reference to per-frame [T, C, H, W] so it lines up with x
        input = rearrange(self.original_image_video[0], "c t h w -> t c h w")
        x = wavelet_reconstruction(x, input)
        return x
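For the shape bookkeeping in postprocess, a sketch of what the rearrange does to the stored reference (sizes are illustrative, not from the diff):

    import torch
    from einops import rearrange

    stored = torch.rand(1, 3, 4, 64, 64)                 # [B, C, T, H, W], B = 1
    frames = rearrange(stored[0], "c t h w -> t c h w")  # -> [4, 3, 64, 64]
    assert frames.shape == (4, 3, 64, 64)                # one row per frame, matching x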

    def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):