From ddac8fd48e57a681c4a82de06ed2d63a459d436c Mon Sep 17 00:00:00 2001
From: jjdejong
Date: Wed, 13 May 2026 22:56:59 +0200
Subject: [PATCH] fix: cast mel_spec to waveform dtype in AudioVAE encode to support --bf16-vae

waveform_to_mel() performs mel filterbank computation in float32
regardless of input dtype, discarding the bfloat16 cast applied by the
VAE encode path in sd.py. The resulting float32 mel spectrogram is then
passed to the bfloat16 autoencoder encoder, causing a RuntimeError at
the first conv layer when --bf16-vae is active.

Fix by casting mel_spec to waveform.dtype (already set to vae_dtype by
the caller) before passing to self.autoencoder.encode(). This is a
no-op when --bf16-vae is not used.

Co-Authored-By: Claude Sonnet 4.6
---
 comfy/ldm/lightricks/vae/audio_vae.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py
index fa0a00748..4019e32fd 100644
--- a/comfy/ldm/lightricks/vae/audio_vae.py
+++ b/comfy/ldm/lightricks/vae/audio_vae.py
@@ -193,7 +193,7 @@ class AudioVAE(torch.nn.Module):
             waveform, waveform_sample_rate, device=self.device_manager.load_device
         )
 
-        latents = self.autoencoder.encode(mel_spec)
+        latents = self.autoencoder.encode(mel_spec.to(dtype=waveform.dtype))
         posterior = DiagonalGaussianDistribution(latents)
         latent_mode = posterior.mode()
 