This commit is contained in:
Yousef Rafat 2025-12-17 00:09:38 +02:00
parent 58e7cea796
commit ebd945ce3d

View File

@ -4,6 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from comfy.ldm.seedvr.model import safe_pad_operation
from comfy.ldm.modules.attention import optimized_attention
@ -398,6 +399,11 @@ class InflatedCausalConv3d(nn.Conv3d):
error_msgs,
)
def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
if times == 0:
return tensor
return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2)
class Upsample3D(nn.Module):
def __init__(
@ -509,6 +515,9 @@ class Upsample3D(nn.Module):
z=self.temporal_ratio,
)
if self.temporal_up:
hidden_states[0] = remove_head(hidden_states[0])
if not self.slicing:
hidden_states = hidden_states[0]
@ -1296,11 +1305,55 @@ class Decoder3D(nn.Module):
return sample
class VideoAutoencoderKL(nn.Module):
def wavelet_blur(image: Tensor, radius: int):
"""
We simply inherit the model code from diffusers
Apply wavelet blur to the input tensor.
"""
# input shape: (1, 3, H, W)
# convolution kernel
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
# add channel dimensions to the kernel to make it a 4D tensor
kernel = kernel[None, None]
# repeat the kernel across all input channels
kernel = kernel.repeat(3, 1, 1, 1)
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
# apply convolution
output = F.conv2d(image, kernel, groups=3, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels=5):
"""
Apply wavelet decomposition to the input tensor.
This function only returns the low frequency & the high frequency.
"""
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq += (image - low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
"""
Apply wavelet decomposition, so that the content will have the same color as the style.
"""
# calculate the wavelet decomposition of the content feature
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq
# calculate the wavelet decomposition of the style feature
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq
# reconstruct the content feature with the style's high frequency
return content_high_freq + style_low_freq
class VideoAutoencoderKL(nn.Module):
def __init__(
self,
in_channels: int = 3,
@ -1478,6 +1531,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
self.spatial_downsample_factor = spatial_downsample_factor
self.temporal_downsample_factor = temporal_downsample_factor
self.freeze_encoder = freeze_encoder
self.original_image_video = None
super().__init__(*args, **kwargs)
def forward(self, x: torch.FloatTensor):
@ -1487,6 +1541,8 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
return x, z, p
def encode(self, x: torch.FloatTensor):
# we need to keep a reference to the image/video so we later can do a colour fix later
self.original_image_video = x
if x.ndim == 4:
x = x.unsqueeze(2)
x = x.to(next(self.parameters()).dtype)
@ -1502,18 +1558,13 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
latent = latent / scale + shift
latent = rearrange(latent, "b ... c -> b c ...")
latent = latent.squeeze(2)
if z.ndim == 4:
z = z.unsqueeze(2)
x = super().decode(latent).squeeze(2)
return x
def preprocess(self, x: torch.Tensor):
# x should in [B, C, T, H, W], [B, C, H, W]
assert x.ndim == 4 or x.size(2) % 4 == 1
return x
def postprocess(self, x: torch.Tensor):
# x should in [B, C, T, H, W], [B, C, H, W]
input = rearrange(self.original_image_video[0], "c t h w -> t c h w")
x = wavelet_reconstruction(x, input)
return x
def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):