VocoderWithBWE: use autocast(enabled=False) instead of dtype=float32.

torch.autocast does not support float32 as a target dtype. Use
enabled=False to disable autocast for the block, relying on the
explicit .float() cast on the input.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
btalesnik 2026-03-29 12:23:43 +00:00
parent 652778b7cc
commit 7d2b6f74bf

View File

@ -659,14 +659,13 @@ class VocoderWithBWE(torch.nn.Module):
Chains a base vocoder (mel -> low-rate waveform) with a BWE stage that upsamples
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
The forward pass is wrapped in ``torch.autocast(dtype=float32)`` so that
all operations run in fp32 regardless of weight dtype or caller context.
The BigVGAN v2 architecture passes signals through 108 sequential
convolutions (18 AMPBlocks x 6 convs) plus 36 anti-aliased activations;
bfloat16 accumulation errors compound through this chain and degrade
spectral metrics by 40-90% while adding only ~70 MB peak VRAM and ~20 ms
latency compared to native bf16. Weights may remain in bf16 for storage
savings -- autocast upcasts them per-op at kernel level.
The forward pass disables autocast and explicitly casts inputs to fp32
so that all operations run in full precision regardless of any outer
autocast context. The BigVGAN v2 architecture passes signals through
108 sequential convolutions (18 AMPBlocks x 6 convs) plus 36
anti-aliased activations; bfloat16 accumulation errors compound through
this chain and degrade spectral metrics by 40-90% while adding only
~70 MB peak VRAM and ~20 ms latency compared to native bf16.
"""
def __init__(self, config):
@ -707,7 +706,7 @@ class VocoderWithBWE(torch.nn.Module):
def forward(self, mel_spec):
input_dtype = mel_spec.dtype
with torch.autocast(device_type=mel_spec.device.type, dtype=torch.float32):
with torch.autocast(device_type=mel_spec.device.type, enabled=False):
x = self.vocoder(mel_spec.float())
_, _, T_low = x.shape
T_out = T_low * self.output_sample_rate // self.input_sample_rate