VocoderWithBWE: use autocast(enabled=False) instead of dtype=float32.

torch.autocast does not support float32 as a target dtype. Use
enabled=False to disable autocast for the block, relying on the
explicit .float() cast on the input.
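A minimal sketch (not part of the commit) of the pattern this fix relies on: disabling autocast for a block and casting inputs to fp32, so compute runs in full precision even inside an outer bf16 autocast region. The function name and shapes here are illustrative.

```python
import torch

def run_fp32_region(x: torch.Tensor) -> torch.Tensor:
    # enabled=False turns autocast off for this block; the explicit
    # .float() cast then guarantees fp32 compute regardless of the
    # caller's autocast context or the input dtype.
    with torch.autocast(device_type=x.device.type, enabled=False):
        return x.float() @ x.float().T

x = torch.randn(4, 4, dtype=torch.bfloat16)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = run_fp32_region(x)
print(y.dtype)  # torch.float32
```

By contrast, passing `dtype=torch.float32` to `torch.autocast` raises an error, since autocast only accepts low-precision target dtypes.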

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
btalesnik 2026-03-29 12:23:43 +00:00
parent 652778b7cc
commit 7d2b6f74bf


@@ -659,14 +659,13 @@ class VocoderWithBWE(torch.nn.Module):
     Chains a base vocoder (mel -> low-rate waveform) with a BWE stage that upsamples
     to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
-    The forward pass is wrapped in ``torch.autocast(dtype=float32)`` so that
-    all operations run in fp32 regardless of weight dtype or caller context.
-    The BigVGAN v2 architecture passes signals through 108 sequential
-    convolutions (18 AMPBlocks x 6 convs) plus 36 anti-aliased activations;
-    bfloat16 accumulation errors compound through this chain and degrade
-    spectral metrics by 40-90% while adding only ~70 MB peak VRAM and ~20 ms
-    latency compared to native bf16. Weights may remain in bf16 for storage
-    savings -- autocast upcasts them per-op at kernel level.
+    The forward pass disables autocast and explicitly casts inputs to fp32
+    so that all operations run in full precision regardless of any outer
+    autocast context. The BigVGAN v2 architecture passes signals through
+    108 sequential convolutions (18 AMPBlocks x 6 convs) plus 36
+    anti-aliased activations; bfloat16 accumulation errors compound through
+    this chain and degrade spectral metrics by 40-90% while adding only
+    ~70 MB peak VRAM and ~20 ms latency compared to native bf16.
     """
     def __init__(self, config):
@@ -707,7 +706,7 @@ class VocoderWithBWE(torch.nn.Module):
     def forward(self, mel_spec):
         input_dtype = mel_spec.dtype
-        with torch.autocast(device_type=mel_spec.device.type, dtype=torch.float32):
+        with torch.autocast(device_type=mel_spec.device.type, enabled=False):
             x = self.vocoder(mel_spec.float())
             _, _, T_low = x.shape
             T_out = T_low * self.output_sample_rate // self.input_sample_rate
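The `T_out` line in the hunk above scales the low-rate sample count by the sample-rate ratio; a quick sketch with illustrative values (the actual rates are not stated in this commit):

```python
# Assumed rates for illustration only: a 24 kHz base vocoder output
# upsampled by the BWE stage to 48 kHz doubles the sample count.
input_sample_rate = 24_000
output_sample_rate = 48_000
T_low = 12_000  # 0.5 s of audio at the low rate
# Multiply before the floor division to avoid truncating the ratio.
T_out = T_low * output_sample_rate // input_sample_rate
print(T_out)  # 24000
```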