mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-05 16:26:48 +08:00
VocoderWithBWE: run forward pass in fp32.
This commit is contained in:
parent
a500f1edac
commit
652778b7cc
@ -656,8 +656,17 @@ class MelSTFT(nn.Module):
|
|||||||
class VocoderWithBWE(torch.nn.Module):
|
class VocoderWithBWE(torch.nn.Module):
|
||||||
"""Vocoder with bandwidth extension (BWE) for higher sample rate output.
|
"""Vocoder with bandwidth extension (BWE) for higher sample rate output.
|
||||||
|
|
||||||
Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples
|
Chains a base vocoder (mel -> low-rate waveform) with a BWE stage that upsamples
|
||||||
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
|
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
|
||||||
|
|
||||||
|
The forward pass is wrapped in ``torch.autocast(dtype=float32)`` so that
|
||||||
|
all operations run in fp32 regardless of weight dtype or caller context.
|
||||||
|
The BigVGAN v2 architecture passes signals through 108 sequential
|
||||||
|
convolutions (18 AMPBlocks x 6 convs) plus 36 anti-aliased activations;
|
||||||
|
bfloat16 accumulation errors compound through this chain and degrade
|
||||||
|
spectral metrics by 40-90% while adding only ~70 MB peak VRAM and ~20 ms
|
||||||
|
latency compared to native bf16. Weights may remain in bf16 for storage
|
||||||
|
savings -- autocast upcasts them per-op at kernel level.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
@ -697,17 +706,19 @@ class VocoderWithBWE(torch.nn.Module):
|
|||||||
return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames)
|
return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames)
|
||||||
|
|
||||||
def forward(self, mel_spec):
|
def forward(self, mel_spec):
|
||||||
x = self.vocoder(mel_spec)
|
input_dtype = mel_spec.dtype
|
||||||
_, _, T_low = x.shape
|
with torch.autocast(device_type=mel_spec.device.type, dtype=torch.float32):
|
||||||
T_out = T_low * self.output_sample_rate // self.input_sample_rate
|
x = self.vocoder(mel_spec.float())
|
||||||
|
_, _, T_low = x.shape
|
||||||
|
T_out = T_low * self.output_sample_rate // self.input_sample_rate
|
||||||
|
|
||||||
remainder = T_low % self.hop_length
|
remainder = T_low % self.hop_length
|
||||||
if remainder != 0:
|
if remainder != 0:
|
||||||
x = F.pad(x, (0, self.hop_length - remainder))
|
x = F.pad(x, (0, self.hop_length - remainder))
|
||||||
|
|
||||||
mel = self._compute_mel(x)
|
mel = self._compute_mel(x)
|
||||||
residual = self.bwe_generator(mel)
|
residual = self.bwe_generator(mel)
|
||||||
skip = self.resampler(x)
|
skip = self.resampler(x)
|
||||||
assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
|
assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
|
||||||
|
|
||||||
return torch.clamp(residual + skip, -1, 1)[..., :T_out]
|
return torch.clamp(residual + skip, -1, 1)[..., :T_out].to(input_dtype)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user