diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py
index 797484547..a2ab7b2b6 100644
--- a/comfy/ldm/lightricks/vocoders/vocoder.py
+++ b/comfy/ldm/lightricks/vocoders/vocoder.py
@@ -659,14 +659,13 @@ class VocoderWithBWE(torch.nn.Module):
     Chains a base vocoder (mel -> low-rate waveform) with a BWE stage that upsamples
     to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
 
-    The forward pass is wrapped in ``torch.autocast(dtype=float32)`` so that
-    all operations run in fp32 regardless of weight dtype or caller context.
-    The BigVGAN v2 architecture passes signals through 108 sequential
-    convolutions (18 AMPBlocks x 6 convs) plus 36 anti-aliased activations;
-    bfloat16 accumulation errors compound through this chain and degrade
-    spectral metrics by 40-90% while adding only ~70 MB peak VRAM and ~20 ms
-    latency compared to native bf16. Weights may remain in bf16 for storage
-    savings -- autocast upcasts them per-op at kernel level.
+    The forward pass disables autocast and explicitly casts inputs to fp32
+    so that all operations run in full precision regardless of any outer
+    autocast context. The BigVGAN v2 architecture passes signals through
+    108 sequential convolutions (18 AMPBlocks x 6 convs) plus 36
+    anti-aliased activations; bfloat16 accumulation errors compound through
+    this chain and degrade spectral metrics by 40-90% while adding only
+    ~70 MB peak VRAM and ~20 ms latency compared to native bf16.
     """
 
     def __init__(self, config):
@@ -707,7 +706,7 @@ class VocoderWithBWE(torch.nn.Module):
 
     def forward(self, mel_spec):
         input_dtype = mel_spec.dtype
-        with torch.autocast(device_type=mel_spec.device.type, dtype=torch.float32):
+        with torch.autocast(device_type=mel_spec.device.type, enabled=False):
             x = self.vocoder(mel_spec.float())
             _, _, T_low = x.shape
             T_out = T_low * self.output_sample_rate // self.input_sample_rate