mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-31 14:03:35 +08:00
VocoderWithBWE: use autocast(enabled=False) instead of dtype=float32.
torch.autocast does not support float32 as a target dtype. Use enabled=False to disable autocast for the block, relying on the explicit .float() cast on the input. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
652778b7cc
commit
7d2b6f74bf
@ -659,14 +659,13 @@ class VocoderWithBWE(torch.nn.Module):
|
||||
Chains a base vocoder (mel -> low-rate waveform) with a BWE stage that upsamples
|
||||
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
|
||||
|
||||
The forward pass is wrapped in ``torch.autocast(dtype=float32)`` so that
|
||||
all operations run in fp32 regardless of weight dtype or caller context.
|
||||
The BigVGAN v2 architecture passes signals through 108 sequential
|
||||
convolutions (18 AMPBlocks x 6 convs) plus 36 anti-aliased activations;
|
||||
bfloat16 accumulation errors compound through this chain and degrade
|
||||
spectral metrics by 40-90% while adding only ~70 MB peak VRAM and ~20 ms
|
||||
latency compared to native bf16. Weights may remain in bf16 for storage
|
||||
savings -- autocast upcasts them per-op at kernel level.
|
||||
The forward pass disables autocast and explicitly casts inputs to fp32
|
||||
so that all operations run in full precision regardless of any outer
|
||||
autocast context. The BigVGAN v2 architecture passes signals through
|
||||
108 sequential convolutions (18 AMPBlocks x 6 convs) plus 36
|
||||
anti-aliased activations; bfloat16 accumulation errors compound through
|
||||
this chain and degrade spectral metrics by 40-90% while adding only
|
||||
~70 MB peak VRAM and ~20 ms latency compared to native bf16.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
@ -707,7 +706,7 @@ class VocoderWithBWE(torch.nn.Module):
|
||||
|
||||
def forward(self, mel_spec):
|
||||
input_dtype = mel_spec.dtype
|
||||
with torch.autocast(device_type=mel_spec.device.type, dtype=torch.float32):
|
||||
with torch.autocast(device_type=mel_spec.device.type, enabled=False):
|
||||
x = self.vocoder(mel_spec.float())
|
||||
_, _, T_low = x.shape
|
||||
T_out = T_low * self.output_sample_rate // self.input_sample_rate
|
||||
|
||||
Loading…
Reference in New Issue
Block a user