From 652778b7cc5c22fea6143b7ef1fa88e7b4d77b78 Mon Sep 17 00:00:00 2001 From: btalesnik Date: Thu, 26 Mar 2026 07:50:57 +0000 Subject: [PATCH 1/3] VocoderWithBWE: run forward pass in fp32. --- comfy/ldm/lightricks/vocoders/vocoder.py | 35 ++++++++++++++++-------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index 2481d8bdd..797484547 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -656,8 +656,17 @@ class MelSTFT(nn.Module): class VocoderWithBWE(torch.nn.Module): """Vocoder with bandwidth extension (BWE) for higher sample rate output. - Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples + Chains a base vocoder (mel -> low-rate waveform) with a BWE stage that upsamples to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform. + + The forward pass is wrapped in ``torch.autocast(dtype=float32)`` so that + all operations run in fp32 regardless of weight dtype or caller context. + The BigVGAN v2 architecture passes signals through 108 sequential + convolutions (18 AMPBlocks x 6 convs) plus 36 anti-aliased activations; + bfloat16 accumulation errors compound through this chain and degrade + spectral metrics by 40-90% while adding only ~70 MB peak VRAM and ~20 ms + latency compared to native bf16. Weights may remain in bf16 for storage + savings -- autocast upcasts them per-op at kernel level. """ def __init__(self, config): @@ -697,17 +706,19 @@ class VocoderWithBWE(torch.nn.Module): return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames) def forward(self, mel_spec): - x = self.vocoder(mel_spec) - _, _, T_low = x.shape - T_out = T_low * self.output_sample_rate // self.input_sample_rate + input_dtype = mel_spec.dtype + with torch.autocast(device_type=mel_spec.device.type, dtype=torch.float32): + x = self.vocoder(mel_spec.float()) + _, _, T_low = x.shape + T_out = T_low * self.output_sample_rate // self.input_sample_rate - remainder = T_low % self.hop_length - if remainder != 0: - x = F.pad(x, (0, self.hop_length - remainder)) + remainder = T_low % self.hop_length + if remainder != 0: + x = F.pad(x, (0, self.hop_length - remainder)) - mel = self._compute_mel(x) - residual = self.bwe_generator(mel) - skip = self.resampler(x) - assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}" + mel = self._compute_mel(x) + residual = self.bwe_generator(mel) + skip = self.resampler(x) + assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}" - return torch.clamp(residual + skip, -1, 1)[..., :T_out] + return torch.clamp(residual + skip, -1, 1)[..., :T_out].to(input_dtype) From 7d2b6f74bf1ca4e6cc05376b3c1fafee167d1a4d Mon Sep 17 00:00:00 2001 From: btalesnik Date: Sun, 29 Mar 2026 12:23:43 +0000 Subject: [PATCH 2/3] VocoderWithBWE: use autocast(enabled=False) instead of dtype=float32. torch.autocast does not support float32 as a target dtype. Use enabled=False to disable autocast for the block, relying on the explicit .float() cast on the input. Co-Authored-By: Claude Opus 4.6 (1M context) --- comfy/ldm/lightricks/vocoders/vocoder.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index 797484547..a2ab7b2b6 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -659,14 +659,13 @@ class VocoderWithBWE(torch.nn.Module): Chains a base vocoder (mel -> low-rate waveform) with a BWE stage that upsamples to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform. - The forward pass is wrapped in ``torch.autocast(dtype=float32)`` so that - all operations run in fp32 regardless of weight dtype or caller context. - The BigVGAN v2 architecture passes signals through 108 sequential - convolutions (18 AMPBlocks x 6 convs) plus 36 anti-aliased activations; - bfloat16 accumulation errors compound through this chain and degrade - spectral metrics by 40-90% while adding only ~70 MB peak VRAM and ~20 ms - latency compared to native bf16. Weights may remain in bf16 for storage - savings -- autocast upcasts them per-op at kernel level. + The forward pass disables autocast and explicitly casts inputs to fp32 + so that all operations run in full precision regardless of any outer + autocast context. The BigVGAN v2 architecture passes signals through + 108 sequential convolutions (18 AMPBlocks x 6 convs) plus 36 + anti-aliased activations; bfloat16 accumulation errors compound through + this chain and degrade spectral metrics by 40-90% while adding only + ~70 MB peak VRAM and ~20 ms latency compared to native bf16. """ def __init__(self, config): @@ -707,7 +706,7 @@ class VocoderWithBWE(torch.nn.Module): def forward(self, mel_spec): input_dtype = mel_spec.dtype - with torch.autocast(device_type=mel_spec.device.type, dtype=torch.float32): + with torch.autocast(device_type=mel_spec.device.type, enabled=False): x = self.vocoder(mel_spec.float()) _, _, T_low = x.shape T_out = T_low * self.output_sample_rate // self.input_sample_rate From 87e2477a94bca82c70223f45ff11c98dcf095ffb Mon Sep 17 00:00:00 2001 From: btalesnik Date: Sun, 29 Mar 2026 12:25:49 +0000 Subject: [PATCH 3/3] Revert "VocoderWithBWE: use autocast(enabled=False) instead of dtype=float32." This reverts commit 7d2b6f74bf1ca4e6cc05376b3c1fafee167d1a4d. --- comfy/ldm/lightricks/vocoders/vocoder.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index a2ab7b2b6..797484547 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -659,13 +659,14 @@ class VocoderWithBWE(torch.nn.Module): Chains a base vocoder (mel -> low-rate waveform) with a BWE stage that upsamples to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform. - The forward pass disables autocast and explicitly casts inputs to fp32 - so that all operations run in full precision regardless of any outer - autocast context. The BigVGAN v2 architecture passes signals through - 108 sequential convolutions (18 AMPBlocks x 6 convs) plus 36 - anti-aliased activations; bfloat16 accumulation errors compound through - this chain and degrade spectral metrics by 40-90% while adding only - ~70 MB peak VRAM and ~20 ms latency compared to native bf16. + The forward pass is wrapped in ``torch.autocast(dtype=float32)`` so that + all operations run in fp32 regardless of weight dtype or caller context. + The BigVGAN v2 architecture passes signals through 108 sequential + convolutions (18 AMPBlocks x 6 convs) plus 36 anti-aliased activations; + bfloat16 accumulation errors compound through this chain and degrade + spectral metrics by 40-90% while adding only ~70 MB peak VRAM and ~20 ms + latency compared to native bf16. Weights may remain in bf16 for storage + savings -- autocast upcasts them per-op at kernel level. """ def __init__(self, config): @@ -706,7 +707,7 @@ class VocoderWithBWE(torch.nn.Module): def forward(self, mel_spec): input_dtype = mel_spec.dtype - with torch.autocast(device_type=mel_spec.device.type, enabled=False): + with torch.autocast(device_type=mel_spec.device.type, dtype=torch.float32): x = self.vocoder(mel_spec.float()) _, _, T_low = x.shape T_out = T_low * self.output_sample_rate // self.input_sample_rate