Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-22 12:20:16 +08:00)

vae fix

commit ebd945ce3d, parent 58e7cea796
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
+from torch import Tensor
 from comfy.ldm.seedvr.model import safe_pad_operation
 from comfy.ldm.modules.attention import optimized_attention
@@ -398,6 +399,11 @@ class InflatedCausalConv3d(nn.Conv3d):
             error_msgs,
         )
 
 
+def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
+    if times == 0:
+        return tensor
+    return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2)
+
 
 class Upsample3D(nn.Module):
     def __init__(
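For context, remove_head trims frames along the temporal axis (dim 2): it keeps frame 0 and drops the `times` frames immediately after it, so T frames become T - times. A minimal sketch of the behaviour (not part of the commit):

    import torch

    x = torch.arange(5, dtype=torch.float32).view(1, 1, 5, 1, 1)  # frames 0..4 on dim 2
    # equivalent to remove_head(x, times=1): keep frame 0, drop frame 1, keep the rest
    y = torch.cat((x[:, :, :1], x[:, :, 2:]), dim=2)
    print(y.flatten().tolist())  # [0.0, 2.0, 3.0, 4.0] -- 5 frames in, 4 frames out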
@@ -509,6 +515,9 @@ class Upsample3D(nn.Module):
                 z=self.temporal_ratio,
             )
 
+            if self.temporal_up:
+                hidden_states[0] = remove_head(hidden_states[0])
+
         if not self.slicing:
             hidden_states = hidden_states[0]
 
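The call site suggests the intent, though the commit does not state it: a rough, hedged reading is that the temporal upsample multiplies the frame count of the first slice by temporal_ratio, which duplicates frame 0, and remove_head drops the duplicate to restore an odd frame count.

    # assumption (not stated in the commit): upsampling repeats frames, so
    # frame 0 appears twice in the first slice; remove_head drops the duplicate
    t_in, ratio = 5, 2
    t_up = t_in * ratio   # 10 frames after temporal upsampling
    t_out = t_up - 1      # 9 frames after remove_head -> back to the 2k+1 convention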
@@ -1296,11 +1305,55 @@ class Decoder3D(nn.Module):
 
         return sample
 
 
-class VideoAutoencoderKL(nn.Module):
-    """
-    We simply inherit the model code from diffusers
-    """
+def wavelet_blur(image: Tensor, radius: int):
+    """
+    Apply wavelet blur to the input tensor.
+    """
+    # input shape: (1, 3, H, W)
+    # convolution kernel
+    kernel_vals = [
+        [0.0625, 0.125, 0.0625],
+        [0.125, 0.25, 0.125],
+        [0.0625, 0.125, 0.0625],
+    ]
+    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
+    # add channel dimensions to the kernel to make it a 4D tensor
+    kernel = kernel[None, None]
+    # repeat the kernel across all input channels
+    kernel = kernel.repeat(3, 1, 1, 1)
+    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
+    # apply convolution
+    output = F.conv2d(image, kernel, groups=3, dilation=radius)
+    return output
+
+
+def wavelet_decomposition(image: Tensor, levels=5):
+    """
+    Apply wavelet decomposition to the input tensor.
+    This function only returns the low frequency & the high frequency.
+    """
+    high_freq = torch.zeros_like(image)
+    for i in range(levels):
+        radius = 2 ** i
+        low_freq = wavelet_blur(image, radius)
+        high_freq += (image - low_freq)
+        image = low_freq
+
+    return high_freq, low_freq
+
+
+def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
+    """
+    Apply wavelet reconstruction, so that the content will have the same color as the style.
+    """
+    # calculate the wavelet decomposition of the content feature
+    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
+    del content_low_freq
+    # calculate the wavelet decomposition of the style feature
+    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
+    del style_high_freq
+    # reconstruct the content feature with the style's low frequency
+    return content_high_freq + style_low_freq
+
+
+class VideoAutoencoderKL(nn.Module):
     def __init__(
         self,
         in_channels: int = 3,
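Taken together, these helpers implement a standard wavelet-based colour transfer: wavelet_decomposition splits an image into a blurred low-frequency band (colour/brightness) and a residual high-frequency band (detail), and wavelet_reconstruction recombines the content's detail with the style's colour. A minimal usage sketch (tensor shapes assumed, not part of the commit):

    import torch

    decoded = torch.rand(1, 3, 64, 64)   # hypothetical VAE output with drifted colours
    original = torch.rand(1, 3, 64, 64)  # reference with the desired colour balance
    fixed = wavelet_reconstruction(decoded, original)
    # fixed keeps decoded's high-frequency detail and original's low-frequency colour
    assert fixed.shape == decoded.shape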
@@ -1478,6 +1531,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         self.spatial_downsample_factor = spatial_downsample_factor
         self.temporal_downsample_factor = temporal_downsample_factor
         self.freeze_encoder = freeze_encoder
+        self.original_image_video = None
         super().__init__(*args, **kwargs)
 
     def forward(self, x: torch.FloatTensor):
@@ -1487,6 +1541,8 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         return x, z, p
 
     def encode(self, x: torch.FloatTensor):
+        # we need to keep a reference to the image/video so we can do a colour fix later
+        self.original_image_video = x
         if x.ndim == 4:
             x = x.unsqueeze(2)
         x = x.to(next(self.parameters()).dtype)
@@ -1502,18 +1558,13 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         latent = latent / scale + shift
         latent = rearrange(latent, "b ... c -> b c ...")
         latent = latent.squeeze(2)
 
         if z.ndim == 4:
             z = z.unsqueeze(2)
         x = super().decode(latent).squeeze(2)
-        return x
-
-    def preprocess(self, x: torch.Tensor):
-        # x should in [B, C, T, H, W], [B, C, H, W]
-        assert x.ndim == 4 or x.size(2) % 4 == 1
-        return x
-
-    def postprocess(self, x: torch.Tensor):
-        # x should in [B, C, T, H, W], [B, C, H, W]
+        input = rearrange(self.original_image_video[0], "c t h w -> t c h w")
+        x = wavelet_reconstruction(x, input)
         return x
 
     def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
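End to end, the wrapper now stashes the input on encode and colour-corrects on decode. A hedged roundtrip sketch; vae is an assumed, already-constructed VideoAutoencoderKLWrapper instance (constructor arguments and exact shapes are not shown in the diff):

    video = torch.rand(1, 3, 17, 256, 256)  # [B, C, T, H, W]
    latent = vae.encode(video)               # also stashes vae.original_image_video = video
    restored = vae.decode(latent)            # wavelet_reconstruction pulls the original colours back in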