mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-05 08:16:47 +08:00
Revert "VocoderWithBWE: use autocast(enabled=False) instead of dtype=float32."
This reverts commit 7d2b6f74bf.
This commit is contained in:
parent
7d2b6f74bf
commit
87e2477a94
@ -659,13 +659,14 @@ class VocoderWithBWE(torch.nn.Module):
|
|||||||
Chains a base vocoder (mel -> low-rate waveform) with a BWE stage that upsamples
|
Chains a base vocoder (mel -> low-rate waveform) with a BWE stage that upsamples
|
||||||
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
|
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
|
||||||
|
|
||||||
The forward pass disables autocast and explicitly casts inputs to fp32
|
The forward pass is wrapped in ``torch.autocast(dtype=float32)`` so that
|
||||||
so that all operations run in full precision regardless of any outer
|
all operations run in fp32 regardless of weight dtype or caller context.
|
||||||
autocast context. The BigVGAN v2 architecture passes signals through
|
The BigVGAN v2 architecture passes signals through 108 sequential
|
||||||
108 sequential convolutions (18 AMPBlocks x 6 convs) plus 36
|
convolutions (18 AMPBlocks x 6 convs) plus 36 anti-aliased activations;
|
||||||
anti-aliased activations; bfloat16 accumulation errors compound through
|
bfloat16 accumulation errors compound through this chain and degrade
|
||||||
this chain and degrade spectral metrics by 40-90% while adding only
|
spectral metrics by 40-90% while adding only ~70 MB peak VRAM and ~20 ms
|
||||||
~70 MB peak VRAM and ~20 ms latency compared to native bf16.
|
latency compared to native bf16. Weights may remain in bf16 for storage
|
||||||
|
savings -- autocast upcasts them per-op at kernel level.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
@ -706,7 +707,7 @@ class VocoderWithBWE(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, mel_spec):
|
def forward(self, mel_spec):
|
||||||
input_dtype = mel_spec.dtype
|
input_dtype = mel_spec.dtype
|
||||||
with torch.autocast(device_type=mel_spec.device.type, enabled=False):
|
with torch.autocast(device_type=mel_spec.device.type, dtype=torch.float32):
|
||||||
x = self.vocoder(mel_spec.float())
|
x = self.vocoder(mel_spec.float())
|
||||||
_, _, T_low = x.shape
|
_, _, T_low = x.shape
|
||||||
T_out = T_low * self.output_sample_rate // self.input_sample_rate
|
T_out = T_low * self.output_sample_rate // self.input_sample_rate
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user