Fix LTX audio VAE BF16 dtype mismatch

This commit is contained in:
jjdejong 2026-05-16 22:43:27 +02:00
parent 26515acd23
commit 6c9af68b08

View File

@ -135,6 +135,16 @@ class AudioVAE(torch.nn.Module):
n_fft=autoencoder_config["n_fft"],
)
@staticmethod
def _module_device_dtype(module: torch.nn.Module) -> tuple[torch.device, torch.dtype]:
for tensor in module.parameters():
if tensor.is_floating_point():
return tensor.device, tensor.dtype
for tensor in module.buffers():
if tensor.is_floating_point():
return tensor.device, tensor.dtype
return torch.device("cpu"), torch.float32
def encode(self, audio, sample_rate=44100) -> torch.Tensor:
"""Encode a waveform dictionary into normalized latent tensors."""
@ -154,6 +164,8 @@ class AudioVAE(torch.nn.Module):
waveform, waveform_sample_rate, device=waveform.device
)
autoencoder_device, autoencoder_dtype = self._module_device_dtype(self.autoencoder)
mel_spec = mel_spec.to(device=autoencoder_device, dtype=autoencoder_dtype)
latents = self.autoencoder.encode(mel_spec)
posterior = DiagonalGaussianDistribution(latents)
latent_mode = posterior.mode()
@ -165,6 +177,8 @@ class AudioVAE(torch.nn.Module):
"""Decode normalized latent tensors into an audio waveform."""
original_shape = latents.shape
autoencoder_device, autoencoder_dtype = self._module_device_dtype(self.autoencoder)
latents = latents.to(device=autoencoder_device, dtype=autoencoder_dtype)
latents = self.normalizer.denormalize(latents)
target_shape = self.target_shape_from_latents(original_shape)
@ -197,6 +211,8 @@ class AudioVAE(torch.nn.Module):
elif audio_channels != 2:
raise ValueError(f"Unsupported audio_channels: {audio_channels}")
vocoder_device, vocoder_dtype = self._module_device_dtype(self.vocoder)
vocoder_input = vocoder_input.to(device=vocoder_device, dtype=vocoder_dtype)
return self.vocoder(vocoder_input)
@property