Revert "VocoderWithBWE: use autocast(enabled=False) instead of dtype=float32."

This reverts commit 7d2b6f74bf.
This commit is contained in:
btalesnik 2026-03-29 12:25:49 +00:00
parent 7d2b6f74bf
commit 87e2477a94

View File

@@ -659,13 +659,14 @@ class VocoderWithBWE(torch.nn.Module):
Chains a base vocoder (mel -> low-rate waveform) with a BWE stage that upsamples
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
The forward pass disables autocast and explicitly casts inputs to fp32
so that all operations run in full precision regardless of any outer
autocast context. The BigVGAN v2 architecture passes signals through
108 sequential convolutions (18 AMPBlocks x 6 convs) plus 36
anti-aliased activations; bfloat16 accumulation errors compound through
this chain and degrade spectral metrics by 40-90% while adding only
~70 MB peak VRAM and ~20 ms latency compared to native bf16.
The forward pass is wrapped in ``torch.autocast(dtype=float32)`` so that
all operations run in fp32 regardless of weight dtype or caller context.
The BigVGAN v2 architecture passes signals through 108 sequential
convolutions (18 AMPBlocks x 6 convs) plus 36 anti-aliased activations;
bfloat16 accumulation errors compound through this chain and degrade
spectral metrics by 40-90% while adding only ~70 MB peak VRAM and ~20 ms
latency compared to native bf16. Weights may remain in bf16 for storage
savings -- autocast upcasts them per-op at kernel level.
"""
def __init__(self, config):
@@ -706,7 +707,7 @@ class VocoderWithBWE(torch.nn.Module):
def forward(self, mel_spec):
input_dtype = mel_spec.dtype
with torch.autocast(device_type=mel_spec.device.type, enabled=False):
with torch.autocast(device_type=mel_spec.device.type, dtype=torch.float32):
x = self.vocoder(mel_spec.float())
_, _, T_low = x.shape
T_out = T_low * self.output_sample_rate // self.input_sample_rate