# ComfyUI/comfy/ldm/hunyuan_foley/vae.py
# (file-viewer metadata from the paste: 2025-09-29 22:44:40 +03:00, 239 lines, 8.1 KiB, Python)
import math
import torch
import numpy as np
from typing import List
import torch.nn as nn
from einops import rearrange
from torchvision.transforms import v2
from torch.nn.utils import weight_norm
from comfy.ldm.hunyuan_foley.syncformer import Synchformer
import comfy.ops
ops = comfy.ops.disable_weight_init
# until the higgsv2 pr gets accepted
def WNConv1d(*args, device = None, dtype = None, operations = None, **kwargs):
    """Build a weight-normalized Conv1d from the supplied `operations` namespace."""
    conv = operations.Conv1d(*args, device=device, dtype=dtype, **kwargs)
    return weight_norm(conv)
def WNConvTranspose1d(*args, device = None, dtype = None, operations = None, **kwargs):
    """Build a weight-normalized ConvTranspose1d from the supplied `operations` namespace."""
    deconv = operations.ConvTranspose1d(*args, device=device, dtype=dtype, **kwargs)
    return weight_norm(deconv)
@torch.jit.script
def snake(x, alpha):
    # Snake activation: x + sin^2(alpha * x) / alpha, applied elementwise on a
    # (B, C, -1) flattened view; the 1e-9 guards against division by zero alpha.
    orig_shape = x.shape
    flat = x.reshape(orig_shape[0], orig_shape[1], -1)
    flat = flat + torch.sin(alpha * flat).pow(2) * (alpha + 1e-9).reciprocal()
    return flat.reshape(orig_shape)
class Snake1d(nn.Module):
    """Channel-wise Snake activation with a learnable per-channel alpha."""
    def __init__(self, channels, device = None, dtype = None):
        super().__init__()
        # One alpha per channel, broadcast over batch and time dimensions.
        init = torch.ones(1, channels, 1, device=device, dtype=dtype)
        self.alpha = nn.Parameter(init)
    def forward(self, x):
        return snake(x, self.alpha)
class DACResidualUnit(nn.Module):
    """DAC residual unit: Snake -> 7-tap dilated conv -> Snake -> 1x1 conv,
    added back onto a (possibly center-cropped) skip connection."""
    def __init__(self, dim: int = 16, dilation: int = 1, device = None, dtype = None, operations = None):
        super().__init__()
        # "same"-style padding for the 7-tap dilated convolution.
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim, device=device, dtype=dtype),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad,
                     device=device, dtype=dtype, operations=operations),
            Snake1d(dim, device=device, dtype=dtype),
            WNConv1d(dim, dim, kernel_size=1, device=device, dtype=dtype, operations=operations),
        )
    def forward(self, x):
        residual = self.block(x)
        # Crop the skip path symmetrically if the conv path shrank the time axis.
        trim = (x.shape[-1] - residual.shape[-1]) // 2
        if trim > 0:
            x = x[..., trim:-trim]
        return x + residual
class DACEncoderBlock(nn.Module):
    """One DAC encoder stage: three dilated residual units (at dim // 2
    channels) followed by a strided conv that downsamples time by `stride`
    and doubles the channel count to `dim`."""
    def __init__(self, dim: int = 16, stride: int = 1, device = None, dtype = None, operations = None):
        super().__init__()
        self.block = nn.Sequential(
            DACResidualUnit(dim // 2, dilation=1, device = device, dtype = dtype, operations = operations),
            DACResidualUnit(dim // 2, dilation=3, device = device, dtype = dtype, operations = operations),
            DACResidualUnit(dim // 2, dilation=9, device = device, dtype = dtype, operations = operations),
            # fix: pass device/dtype so the learnable alpha is created like the
            # rest of the block instead of on the default device/dtype
            Snake1d(dim // 2, device = device, dtype = dtype),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                device = device, dtype = dtype, operations = operations
            ),
        )
    def forward(self, x):
        return self.block(x)
class DACEncoder(nn.Module):
    """DAC encoder: input conv, one downsampling block per stride (doubling
    channels each time), then Snake + a final projection to `d_latent`.

    After construction, `enc_dim` holds the channel count reached before the
    final projection, i.e. d_model * 2 ** len(strides).
    """
    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 256,
        device = None, dtype = None, operations = None
    ):
        super().__init__()
        # Create first convolution
        layers = [WNConv1d(1, d_model, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations)]
        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            layers.append(DACEncoderBlock(d_model, stride=stride, device = device, dtype = dtype, operations = operations))
        # Final activation + projection to the latent width.
        # fix: pass device/dtype to Snake1d so its alpha parameter is created
        # alongside the rest of the model's parameters
        layers += [
            Snake1d(d_model, device = device, dtype = dtype),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1, device = device, dtype = dtype, operations = operations),
        ]
        # Wrap the layer list into nn.Sequential
        self.block = nn.Sequential(*layers)
        self.enc_dim = d_model
    def forward(self, x):
        return self.block(x)
class DACDecoderBlock(nn.Module):
    """One DAC decoder stage: Snake activation, transposed-conv upsampling
    from `input_dim` to `output_dim` channels, then three dilated residual units."""
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, device = None, dtype = None, operations = None):
        super().__init__()
        upsample = WNConvTranspose1d(
            input_dim,
            output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
            output_padding=stride % 2,
            device=device, dtype=dtype, operations=operations,
        )
        # Dilations 1/3/9 mirror the encoder's residual stack.
        units = [
            DACResidualUnit(output_dim, dilation=d, device=device, dtype=dtype, operations=operations)
            for d in (1, 3, 9)
        ]
        self.block = nn.Sequential(Snake1d(input_dim, device=device, dtype=dtype), upsample, *units)
    def forward(self, x):
        return self.block(x)
class DACDecoder(nn.Module):
    """DAC decoder: input conv, one upsampling block per rate (halving
    channels each time), then Snake + output conv + tanh."""
    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
        device = None, dtype = None, operations = None
    ):
        super().__init__()
        modules = [WNConv1d(input_channel, channels, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations)]
        for idx, rate in enumerate(rates):
            in_dim = channels // 2 ** idx
            out_dim = channels // 2 ** (idx + 1)
            modules.append(DACDecoderBlock(in_dim, out_dim, rate, device = device, dtype = dtype, operations = operations))
        # NOTE(review): assumes `rates` is non-empty — the head reuses the
        # last block's out_dim from the loop above.
        modules += [
            Snake1d(out_dim, device = device, dtype = dtype),
            WNConv1d(out_dim, d_out, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*modules)
    def forward(self, x):
        return self.model(x)
class DAC(torch.nn.Module):
    """Descript Audio Codec (DAC) autoencoder; only `decode` is used here.

    Defaults correspond to a 44.1 kHz model: latents of shape
    (B, latent_dim, T) are decoded to audio, with `hop_length` samples of
    audio per latent frame.
    """
    def __init__(
        self,
        encoder_dim: int = 128,
        encoder_rates: List[int] = [2, 3, 4, 5],
        latent_dim: int = 128,
        decoder_dim: int = 2048,
        decoder_rates: List[int] = [8, 5, 4, 3],
        sample_rate: int = 44100,
    ):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        # fix: sample_rate was assigned twice (also after building the
        # decoder); a single assignment is sufficient
        self.sample_rate = sample_rate
        if latent_dim is None:
            # Default the latent width to the encoder's final channel count.
            latent_dim = encoder_dim * (2 ** len(encoder_rates))
        self.latent_dim = latent_dim
        # Total temporal downsampling factor of the encoder.
        self.hop_length = np.prod(encoder_rates)
        self.encoder = DACEncoder(encoder_dim, encoder_rates, latent_dim, operations = ops)
        self.decoder = DACDecoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
            operations = ops
        )
    def decode(self, z: torch.Tensor):
        """Decode latents z of shape (B, latent_dim, T) to a waveform."""
        return self.decoder(z)
    def forward(self):
        # Full encode/decode training path is unused in this integration.
        pass
class FoleyVae(torch.nn.Module):
    """Hunyuan Foley "VAE" wrapper: a DAC audio decoder plus a Synchformer
    video encoder with its torchvision preprocessing pipeline."""
    def __init__(self):
        super().__init__()
        # DAC handles latent -> waveform decoding.
        self.dac = DAC()
        # Synchformer video encoder; config args are left as None here —
        # presumably resolved inside Synchformer, verify against syncformer.py.
        self.syncformer = Synchformer(None, None, operations = ops)
        # Per-frame preprocessing: bicubic resize + 224x224 center crop,
        # float32 scaled to [0, 1], then normalized to [-1, 1].
        self.syncformer_preprocess = v2.Compose(
            [
                v2.Resize(224, interpolation=v2.InterpolationMode.BICUBIC),
                v2.CenterCrop(224),
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
    def decode(self, x, vae_options = {}):
        """Decode audio latents via DAC. `vae_options` is currently ignored.
        NOTE(review): mutable default argument — harmless as written since it
        is never mutated, but worth replacing with None at some point."""
        return self.dac.decode(x)
    def encode(self, x):
        """Run segmented video data through the Synchformer encoder."""
        return self.syncformer(x)
    def video_encoding(self, video, step: int):
        """Cut a video into 16-frame segments spaced `step` frames apart.

        `video` is either a (T, H, W, C) numpy array (converted to
        (T, C, H, W) here) or a tensor — assumed already (T, C, H, W),
        TODO confirm with callers. Returns (data, nseg, unflatten) where
        `data` is rearranged to "(b s) 1 t c h w" for Synchformer and
        `unflatten` restores the batch/segment layout of its output.

        NOTE(review): when t < 16, nseg is 0 and torch.stack on the empty
        clip list raises — confirm callers guarantee at least 16 frames.
        """
        if not isinstance(video, torch.Tensor):
            video = torch.from_numpy(video).permute(0, 3, 1, 2)
        video = self.syncformer_preprocess(video).unsqueeze(0)
        seg_len = 16
        t = video.size(1)
        # Number of full segments that fit with hop `step`.
        nseg = max(0, (t - seg_len) // step + 1)
        clips = [video[:, i*step:i*step + seg_len] for i in range(nseg)]
        data = torch.stack(clips, dim=1)
        data = rearrange(data, "b s t c h w -> (b s) 1 t c h w")
        return data, nseg, lambda x: rearrange(x, "(b s) 1 t d -> b (s t) d", b=video.size(0))