import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import comfy.ops

ops = comfy.ops.disable_weight_init

LRELU_SLOPE = 0.1


def get_padding(kernel_size, dilation=1):
    """Return the per-side padding ``(kernel_size - 1) * dilation / 2`` as an int.

    For an odd ``kernel_size`` this keeps a stride-1 dilated convolution
    length-preserving.
    """
    return int((kernel_size * dilation - dilation) / 2)


def _dilated_conv_stack(channels, kernel_size, dilations):
    """Build a ModuleList of stride-1 ``ops.Conv1d`` layers, one per dilation.

    Every layer maps ``channels -> channels`` and is padded with
    ``get_padding(kernel_size, dilation)``.
    """
    return nn.ModuleList(
        [
            ops.Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=d,
                padding=get_padding(kernel_size, d),
            )
            for d in dilations
        ]
    )


# ---------------------------------------------------------------------------
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
# Adopted from https://github.com/NVIDIA/BigVGAN
# ---------------------------------------------------------------------------


def _sinc(x: torch.Tensor):
    """Return the normalized sinc ``sin(pi * x) / (pi * x)``, with value 1 at ``x == 0``."""
    # The expression is kept in this exact form (rather than torch.sinc) so the
    # resulting filter taps do not change numerically.
    return torch.where(
        x == 0,
        torch.tensor(1.0, device=x.device, dtype=x.dtype),
        torch.sin(math.pi * x) / math.pi / x,
    )


def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
    """Build a Kaiser-windowed sinc low-pass filter.

    Args:
        cutoff: Cutoff frequency used to scale the sinc. ``0`` yields an
            all-zero filter.
        half_width: Transition half-width; together with ``kernel_size`` it
            determines the Kaiser window ``beta``.
        kernel_size: Number of filter taps.

    Returns:
        Tensor of shape ``(1, 1, kernel_size)``. For a non-zero cutoff the taps
        are normalized to sum to 1.
    """
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # Derive the Kaiser beta from the attenuation estimate A.
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # Tap positions centred on zero (half-sample offset for even sizes).
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size

    if cutoff == 0:
        # Use the window's floating dtype: `time` is an integer tensor for odd
        # kernel sizes, which would produce an unusable int64 filter.
        filter_ = torch.zeros_like(window)
    else:
        filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
        # Normalize to unit sum so the constant component passes unchanged.
        filter_ /= filter_.sum()

    return filter_.view(1, 1, kernel_size)


class LowPassFilter1d(nn.Module):
    """Depthwise Kaiser-sinc low-pass filter with optional striding.

    The filter taps are stored in the ``filter`` buffer and applied to every
    channel independently.
    """

    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride=1,
        padding=True,
        padding_mode="replicate",
        kernel_size=12,
    ):
        """Create the filter.

        Args:
            cutoff: Filter cutoff; must lie in ``[0, 0.5]``.
            half_width: Transition half-width passed to the filter design.
            stride: Convolution stride.
            padding: Whether ``forward`` pads the input before filtering.
            padding_mode: Mode handed to ``F.pad``.
            kernel_size: Number of filter taps.

        Raises:
            ValueError: If ``cutoff`` is negative or greater than 0.5.
        """
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        # Even kernels pad one sample less on the left.
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        taps = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", taps)

    def forward(self, x):
        """Filter a 3-D tensor ``(batch, channels, time)`` channel by channel."""
        _, C, _ = x.shape
        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        # groups=C with the single filter expanded to C copies -> depthwise conv.
        return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)


class UpSample1d(nn.Module):
    """Upsample a signal by an integer ratio with a windowed-sinc filter.

    Implemented as replicate padding, a depthwise transposed convolution and a
    crop that leaves ``ratio * T`` output samples for ``T`` input samples.
    """

    def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"):
        """Create the upsampler.

        Args:
            ratio: Integer upsampling factor (also the transposed-conv stride).
            kernel_size: Tap count for the Kaiser filter; defaults to
                ``int(6 * ratio // 2) * 2``. Ignored for ``window_type="hann"``.
            persistent: Whether the ``filter`` buffer is part of the state dict.
            window_type: ``"hann"`` selects the Hann-windowed sinc; any other
                value selects the Kaiser-windowed sinc.
        """
        super().__init__()
        self.ratio = ratio
        self.stride = ratio

        if window_type == "hann":
            # Hann-windowed sinc filter — identical to torchaudio.functional.resample
            # with its default parameters (rolloff=0.99, lowpass_filter_width=6).
            # Uses replicate boundary padding, matching the reference resampler exactly.
            rolloff = 0.99
            lowpass_filter_width = 6
            width = math.ceil(lowpass_filter_width / rolloff)
            self.kernel_size = 2 * width * ratio + 1
            self.pad = width
            self.pad_left = 2 * width * ratio
            self.pad_right = self.kernel_size - ratio
            t = (torch.arange(self.kernel_size) / ratio - width) * rolloff
            t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width)
            window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2
            taps = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1)
        else:
            # Kaiser-windowed sinc filter (BigVGAN default).
            self.kernel_size = (
                int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
            )
            self.pad = self.kernel_size // ratio - 1
            self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
            self.pad_right = (
                self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
            )
            taps = kaiser_sinc_filter1d(
                cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
            )

        self.register_buffer("filter", taps, persistent=persistent)

    def forward(self, x):
        """Upsample a 3-D tensor ``(batch, channels, time)`` along the last axis."""
        _, C, _ = x.shape
        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        # Scale by `ratio` to compensate for the zero-stuffing gain loss.
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        # Drop the samples produced by the padding and the filter tails.
        x = x[..., self.pad_left : -self.pad_right]
        return x


class DownSample1d(nn.Module):
    """Downsample by an integer ratio using a strided ``LowPassFilter1d``."""

    def __init__(self, ratio=2, kernel_size=None):
        """Create the downsampler.

        Args:
            ratio: Integer decimation factor (used as the filter stride).
            kernel_size: Filter tap count; defaults to ``int(6 * ratio // 2) * 2``.
        """
        super().__init__()
        self.ratio = ratio
        self.kernel_size = (
            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        )
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        """Low-pass filter and decimate ``x``."""
        return self.lowpass(x)


class Activation1d(nn.Module):
    """Apply an activation between an upsampling and a downsampling stage."""

    def __init__(
        self,
        activation,
        up_ratio=2,
        down_ratio=2,
        up_kernel_size=12,
        down_kernel_size=12,
    ):
        """Wrap ``activation`` with ``UpSample1d`` / ``DownSample1d`` stages."""
        super().__init__()
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x):
        """Return ``downsample(act(upsample(x)))``."""
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)
        return x


# ---------------------------------------------------------------------------
# BigVGAN v2 activations (Snake / SnakeBeta)
# ---------------------------------------------------------------------------


class Snake(nn.Module):
    """Snake activation: ``x + sin(a * x)^2 / (a + eps)`` with a per-channel ``a``."""

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
    ):
        """Create the activation.

        Args:
            in_features: Number of channels (length of the ``alpha`` parameter).
            alpha: Initial value of ``alpha`` when not in log scale.
            alpha_trainable: Value assigned to ``alpha.requires_grad``.
            alpha_logscale: If True, ``alpha`` is stored in log space
                (initialised to zeros) and exponentiated in ``forward``.
        """
        super().__init__()
        self.alpha_logscale = alpha_logscale
        self.alpha = nn.Parameter(
            torch.zeros(in_features)
            if alpha_logscale
            else torch.ones(in_features) * alpha
        )
        self.alpha.requires_grad = alpha_trainable
        self.eps = 1e-9

    def forward(self, x):
        """Apply the activation; ``alpha`` is broadcast as ``(1, C, 1)``."""
        a = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            a = torch.exp(a)
        return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)


class SnakeBeta(nn.Module):
    """SnakeBeta activation: ``x + sin(a * x)^2 / (b + eps)`` with per-channel ``a``, ``b``."""

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
    ):
        """Create the activation.

        Args:
            in_features: Number of channels (length of ``alpha`` and ``beta``).
            alpha: Initial value of both ``alpha`` and ``beta`` when not in
                log scale.
            alpha_trainable: Value assigned to ``requires_grad`` of both
                parameters.
            alpha_logscale: If True, both parameters are stored in log space
                (initialised to zeros) and exponentiated in ``forward``.
        """
        super().__init__()
        self.alpha_logscale = alpha_logscale
        self.alpha = nn.Parameter(
            torch.zeros(in_features)
            if alpha_logscale
            else torch.ones(in_features) * alpha
        )
        self.alpha.requires_grad = alpha_trainable
        self.beta = nn.Parameter(
            torch.zeros(in_features)
            if alpha_logscale
            else torch.ones(in_features) * alpha
        )
        self.beta.requires_grad = alpha_trainable
        self.eps = 1e-9

    def forward(self, x):
        """Apply the activation; parameters are broadcast as ``(1, C, 1)``."""
        a = self.alpha.unsqueeze(0).unsqueeze(-1)
        b = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            a = torch.exp(a)
            b = torch.exp(b)
        return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2)


# ---------------------------------------------------------------------------
# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity)
# ---------------------------------------------------------------------------


class AMPBlock1(torch.nn.Module):
    """Residual block of three (activation, dilated conv, activation, conv) stages.

    Activations are Snake/SnakeBeta wrapped in ``Activation1d``.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"):
        """Create the block.

        Args:
            channels: Channel count of every convolution and activation.
            kernel_size: Convolution kernel size.
            dilation: Dilations of the three ``convs1`` layers (the first
                three entries are used).
            activation: ``"snakebeta"`` selects ``SnakeBeta``; anything else
                selects ``Snake``.
        """
        super().__init__()
        act_cls = SnakeBeta if activation == "snakebeta" else Snake

        self.convs1 = _dilated_conv_stack(
            channels, kernel_size, (dilation[0], dilation[1], dilation[2])
        )
        self.convs2 = _dilated_conv_stack(channels, kernel_size, (1, 1, 1))

        self.acts1 = nn.ModuleList(
            [Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]
        )
        self.acts2 = nn.ModuleList(
            [Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]
        )

    def forward(self, x):
        """Run the three residual stages and return the accumulated tensor."""
        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2):
            xt = a1(x)
            xt = c1(xt)
            xt = a2(xt)
            xt = c2(xt)
            x = x + xt
        return x


# ---------------------------------------------------------------------------
# HiFi-GAN residual blocks
# ---------------------------------------------------------------------------


class ResBlock1(torch.nn.Module):
    """Residual block of three (leaky-ReLU, dilated conv, leaky-ReLU, conv) stages."""

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        """Create the block.

        Args:
            channels: Channel count of every convolution.
            kernel_size: Convolution kernel size.
            dilation: Dilations of the three ``convs1`` layers (the first
                three entries are used).
        """
        super().__init__()
        self.convs1 = _dilated_conv_stack(
            channels, kernel_size, (dilation[0], dilation[1], dilation[2])
        )
        self.convs2 = _dilated_conv_stack(channels, kernel_size, (1, 1, 1))

    def forward(self, x):
        """Run the three residual stages and return the accumulated tensor."""
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x


class ResBlock2(torch.nn.Module):
    """Residual block of two (leaky-ReLU, dilated conv) stages."""

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        """Create the block.

        Args:
            channels: Channel count of every convolution.
            kernel_size: Convolution kernel size.
            dilation: Dilations of the two layers (the first two entries are
                used).
        """
        super().__init__()
        self.convs = _dilated_conv_stack(
            channels, kernel_size, (dilation[0], dilation[1])
        )

    def forward(self, x):
        """Run the two residual stages and return the accumulated tensor."""
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x


class Vocoder(torch.nn.Module):
    """
    Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.

    Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1").
    """

    def __init__(self, config=None):
        """Build the network from a config mapping.

        Args:
            config: Mapping read with ``.get``; missing keys fall back to the
                defaults visible below. ``None`` uses ``get_default_config()``.

        Raises:
            ValueError: If ``config["resblock"]`` is not "1", "2" or "AMP1".
        """
        super().__init__()

        if config is None:
            config = self.get_default_config()

        resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
        upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2])
        upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4])
        resblock_dilation_sizes = config.get(
            "resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
        )
        upsample_initial_channel = config.get("upsample_initial_channel", 1024)
        stereo = config.get("stereo", True)
        activation = config.get("activation", "snake")
        use_bias_at_final = config.get("use_bias_at_final", True)

        # "output_sample_rate" is not present in recent checkpoint configs.
        # When absent (None), AudioVAE.output_sample_rate computes it as:
        #   sample_rate * vocoder.upsample_factor / mel_hop_length
        # where upsample_factor = product of all upsample stride lengths,
        # and mel_hop_length is loaded from the autoencoder config at
        # preprocessing.stft.hop_length (see CausalAudioAutoencoder).
        self.output_sample_rate = config.get("output_sample_rate")
        self.resblock = config.get("resblock", "1")
        self.use_tanh_at_final = config.get("use_tanh_at_final", True)
        self.apply_final_activation = config.get("apply_final_activation", True)
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        in_channels = 128 if stereo else 64
        self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)

        if self.resblock == "1":
            resblock_cls = ResBlock1
        elif self.resblock == "2":
            resblock_cls = ResBlock2
        elif self.resblock == "AMP1":
            resblock_cls = AMPBlock1
        else:
            raise ValueError(f"Unknown resblock type: {self.resblock}")

        # Each upsample stage halves the channel count.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                ops.ConvTranspose1d(
                    upsample_initial_channel // (2**i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    k,
                    u,
                    padding=(k - u) // 2,
                )
            )

        # Fallback so `ch` is defined even when no upsample stage is configured.
        ch = upsample_initial_channel
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                if self.resblock == "AMP1":
                    self.resblocks.append(resblock_cls(ch, k, d, activation=activation))
                else:
                    self.resblocks.append(resblock_cls(ch, k, d))

        out_channels = 2 if stereo else 1
        if self.resblock == "AMP1":
            act_cls = SnakeBeta if activation == "snakebeta" else Snake
            self.act_post = Activation1d(act_cls(ch))
        else:
            self.act_post = nn.LeakyReLU()

        self.conv_post = ops.Conv1d(
            ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final
        )

        # Total upsampling factor = product of the transposed-conv strides.
        self.upsample_factor = np.prod([up.stride[0] for up in self.ups])

    def get_default_config(self):
        """Generate default configuration for the vocoder."""
        config = {
            "resblock_kernel_sizes": [3, 7, 11],
            "upsample_rates": [5, 4, 2, 2, 2],
            "upsample_kernel_sizes": [16, 16, 8, 4, 4],
            "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            "upsample_initial_channel": 1024,
            "stereo": True,
            "resblock": "1",
            "activation": "snake",
            "use_bias_at_final": True,
            "use_tanh_at_final": True,
        }

        return config

    def forward(self, x):
        """
        Forward pass of the vocoder.

        Args:
            x: Input spectrogram tensor. Can be:
               - 3D: (batch_size, channels, time_steps) for mono
               - 4D: (batch_size, 2, channels, time_steps) for stereo

        Returns:
            Audio tensor of shape (batch_size, out_channels, audio_length)
        """
        if x.dim() == 4:  # stereo
            assert x.shape[1] == 2, "Input must have 2 channels for stereo"
            # Stack the two stereo spectrograms along the channel axis.
            x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)

        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            # AMP blocks carry their own activations; HiFi-GAN needs one here.
            if self.resblock != "AMP1":
                x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of this stage's parallel residual blocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = self.act_post(x)
        x = self.conv_post(x)

        if self.apply_final_activation:
            if self.use_tanh_at_final:
                x = torch.tanh(x)
            else:
                x = torch.clamp(x, -1, 1)

        return x


class _STFTFn(nn.Module):
    """Implements STFT as a convolution with precomputed DFT × Hann-window bases.

    The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
    Hann window are stored as buffers and loaded from the checkpoint. Using the exact
    bfloat16 bases from training ensures the mel values fed to the BWE generator are
    bit-identical to what it was trained on.
    """

    def __init__(self, filter_length: int, hop_length: int, win_length: int):
        """Allocate zero-filled basis buffers sized from ``filter_length``."""
        super().__init__()
        self.hop_length = hop_length
        self.win_length = win_length
        n_freqs = filter_length // 2 + 1
        # Placeholders only; the real bases come from the checkpoint.
        self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
        self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))

    def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute magnitude and phase spectrogram from a batch of waveforms.

        Applies causal (left-only) padding of win_length - hop_length samples so that
        each output frame depends only on past and present input — no lookahead.
        The STFT is computed by convolving the padded signal with forward_basis.

        Args:
            y: Waveform tensor of shape (B, T).

        Returns:
            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
            phase:     Phase spectrogram in radians, shape (B, n_freqs, T_frames).
                       Computed in float32 for numerical stability, then cast back to
                       the input dtype.
        """
        if y.dim() == 2:
            y = y.unsqueeze(1)  # (B, 1, T)
        left_pad = max(0, self.win_length - self.hop_length)  # causal: left-only
        y = F.pad(y, (left_pad, 0))
        spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
        # First half of the output channels is the real part, second half the imaginary.
        n_freqs = spec.shape[1] // 2
        real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
        magnitude = torch.sqrt(real ** 2 + imag ** 2)
        phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
        return magnitude, phase


class MelSTFT(nn.Module):
    """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.

    Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
    waveform and projecting the linear magnitude spectrum onto the mel filterbank.

    The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
    (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
    """

    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
        sampling_rate: int,
        mel_fmin: float,
        mel_fmax: float,
    ):
        """Create the STFT sub-module and a zero-filled ``mel_basis`` buffer.

        ``sampling_rate``, ``mel_fmin`` and ``mel_fmax`` are accepted but not
        used here.
        """
        super().__init__()
        self.stft_fn = _STFTFn(filter_length, hop_length, win_length)

        n_freqs = filter_length // 2 + 1
        self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))

    def mel_spectrogram(
        self, y: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute log-mel spectrogram and auxiliary spectral quantities.

        Args:
            y: Waveform tensor of shape (B, T).

        Returns:
            log_mel:   Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
                       Computed as log(clamp(mel_basis @ magnitude, min=1e-5)).
            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
            phase:     Phase spectrogram in radians, shape (B, n_freqs, T_frames).
            energy:    Per-frame energy (L2 norm over frequency), shape (B, T_frames).
        """
        magnitude, phase = self.stft_fn(y)
        energy = torch.norm(magnitude, dim=1)
        mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
        log_mel = torch.log(torch.clamp(mel, min=1e-5))
        return log_mel, magnitude, phase, energy


class VocoderWithBWE(torch.nn.Module):
    """Vocoder with bandwidth extension (BWE) for higher sample rate output.

    Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples
    to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
    """

    def __init__(self, config):
        """Build both stages from ``config["vocoder"]`` and ``config["bwe"]``.

        The BWE generator is a ``Vocoder`` built from the BWE config with
        ``apply_final_activation`` forced to False.
        """
        super().__init__()
        vocoder_config = config["vocoder"]
        bwe_config = config["bwe"]

        self.vocoder = Vocoder(config=vocoder_config)
        self.bwe_generator = Vocoder(
            config={**bwe_config, "apply_final_activation": False}
        )

        self.input_sample_rate = bwe_config["input_sampling_rate"]
        self.output_sample_rate = bwe_config["output_sampling_rate"]
        self.hop_length = bwe_config["hop_length"]

        self.mel_stft = MelSTFT(
            filter_length=bwe_config["n_fft"],
            hop_length=bwe_config["hop_length"],
            win_length=bwe_config["n_fft"],
            n_mel_channels=bwe_config["num_mels"],
            sampling_rate=bwe_config["input_sampling_rate"],
            mel_fmin=0.0,
            mel_fmax=bwe_config["input_sampling_rate"] / 2.0,
        )

        # Non-persistent: the resampling filter is rebuilt rather than loaded.
        self.resampler = UpSample1d(
            ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"],
            persistent=False,
            window_type="hann",
        )

    def _compute_mel(self, audio):
        """Compute log-mel spectrogram from waveform using causal STFT bases."""
        B, C, T = audio.shape
        flat = audio.reshape(B * C, -1)  # (B*C, T)
        mel, _, _, _ = self.mel_stft.mel_spectrogram(flat)  # (B*C, n_mels, T_frames)
        return mel.reshape(B, C, mel.shape[1], mel.shape[2])  # (B, C, n_mels, T_frames)

    def forward(self, mel_spec):
        """Synthesize audio with the base vocoder, then extend its bandwidth.

        Returns the clamped sum of the BWE generator output and the resampled
        low-rate waveform, cropped to
        ``T_low * output_sample_rate // input_sample_rate`` samples.
        """
        x = self.vocoder(mel_spec)
        _, _, T_low = x.shape
        T_out = T_low * self.output_sample_rate // self.input_sample_rate

        # Right-pad to a whole number of hops so the mel frames cover the signal.
        remainder = T_low % self.hop_length
        if remainder != 0:
            x = F.pad(x, (0, self.hop_length - remainder))

        mel = self._compute_mel(x)
        residual = self.bwe_generator(mel)
        skip = self.resampler(x)
        assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"

        return torch.clamp(residual + skip, -1, 1)[..., :T_out]