fix: cast mel_spec to waveform dtype in AudioVAE encode to support --bf16-vae

waveform_to_mel() performs mel filterbank computation in float32
regardless of input dtype, discarding the bfloat16 cast applied by
the VAE encode path in sd.py. The resulting float32 mel spectrogram
is then passed to the bfloat16 autoencoder encoder, causing a
RuntimeError at the first conv layer when --bf16-vae is active.

Fix by casting mel_spec to waveform.dtype (already set to vae_dtype
by the caller) before passing to self.autoencoder.encode(). This is
a no-op when --bf16-vae is not used.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
jjdejong 2026-05-13 22:56:59 +02:00
parent e4b0bb8305
commit ddac8fd48e

View File

@ -193,7 +193,7 @@ class AudioVAE(torch.nn.Module):
waveform, waveform_sample_rate, device=self.device_manager.load_device
)
latents = self.autoencoder.encode(mel_spec)
latents = self.autoencoder.encode(mel_spec.to(dtype=waveform.dtype))
posterior = DiagonalGaussianDistribution(latents)
latent_mode = posterior.mode()