diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
index 6c58f044b..ef07b24e0 100644
--- a/comfy/ldm/seedvr/vae.py
+++ b/comfy/ldm/seedvr/vae.py
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
+from torch import Tensor
 from comfy.ldm.seedvr.model import safe_pad_operation
 from comfy.ldm.modules.attention import optimized_attention
 
@@ -398,6 +399,11 @@ class InflatedCausalConv3d(nn.Conv3d):
             error_msgs,
         )
 
+def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
+    if times == 0:
+        return tensor
+    return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2)
+
 
 class Upsample3D(nn.Module):
     def __init__(
@@ -509,6 +515,9 @@ class Upsample3D(nn.Module):
                 z=self.temporal_ratio,
             )
 
+        if self.temporal_up:
+            hidden_states[0] = remove_head(hidden_states[0])
+
         if not self.slicing:
             hidden_states = hidden_states[0]
 
@@ -1296,11 +1305,55 @@ class Decoder3D(nn.Module):
 
         return sample
 
-class VideoAutoencoderKL(nn.Module):
+def wavelet_blur(image: Tensor, radius: int):
     """
-    We simply inherit the model code from diffusers
+    Apply wavelet blur to the input tensor.
     """
+    # input shape: (1, 3, H, W)
+    # convolution kernel
+    kernel_vals = [
+        [0.0625, 0.125, 0.0625],
+        [0.125, 0.25, 0.125],
+        [0.0625, 0.125, 0.0625],
+    ]
+    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
+    # add channel dimensions to the kernel to make it a 4D tensor
+    kernel = kernel[None, None]
+    # repeat the kernel across all input channels
+    kernel = kernel.repeat(3, 1, 1, 1)
+    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
+    # apply convolution
+    output = F.conv2d(image, kernel, groups=3, dilation=radius)
+    return output
+def wavelet_decomposition(image: Tensor, levels=5):
+    """
+    Apply wavelet decomposition to the input tensor.
+    This function returns only the accumulated high frequency and the final low frequency.
+    """
+    high_freq = torch.zeros_like(image)
+    for i in range(levels):
+        radius = 2 ** i
+        low_freq = wavelet_blur(image, radius)
+        high_freq += (image - low_freq)
+        image = low_freq
+
+    return high_freq, low_freq
+
+def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
+    """
+    Recombine the content's high frequency with the style's low frequency,
+    so that the content will have the same color as the style.
+    """
+    # calculate the wavelet decomposition of the content feature
+    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
+    del content_low_freq
+    # calculate the wavelet decomposition of the style feature
+    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
+    del style_high_freq
+    # reconstruct the content feature with the style's low frequency
+    return content_high_freq + style_low_freq
+
+class VideoAutoencoderKL(nn.Module):
 
     def __init__(
         self,
         in_channels: int = 3,
@@ -1478,6 +1531,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         self.spatial_downsample_factor = spatial_downsample_factor
         self.temporal_downsample_factor = temporal_downsample_factor
         self.freeze_encoder = freeze_encoder
+        self.original_image_video = None
         super().__init__(*args, **kwargs)
 
     def forward(self, x: torch.FloatTensor):
@@ -1487,6 +1541,8 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         return x, z, p
 
     def encode(self, x: torch.FloatTensor):
+        # keep a reference to the input image/video so we can do a colour fix later
+        self.original_image_video = x
         if x.ndim == 4:
             x = x.unsqueeze(2)
         x = x.to(next(self.parameters()).dtype)
@@ -1502,18 +1558,13 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         latent = latent / scale + shift
         latent = rearrange(latent, "b ... c -> b c ...")
         latent = latent.squeeze(2)
+
         if z.ndim == 4:
             z = z.unsqueeze(2)
         x = super().decode(latent).squeeze(2)
-        return x
 
-    def preprocess(self, x: torch.Tensor):
-        # x should in [B, C, T, H, W], [B, C, H, W]
-        assert x.ndim == 4 or x.size(2) % 4 == 1
-        return x
-
-    def postprocess(self, x: torch.Tensor):
-        # x should in [B, C, T, H, W], [B, C, H, W]
+        input = rearrange(self.original_image_video[0], "c t h w -> t c h w")
+        x = wavelet_reconstruction(x, input)
         return x
 
     def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):