mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-11 00:37:53 +08:00
Fix LTX audio VAE BF16 dtype mismatch
This commit is contained in:
parent
26515acd23
commit
6c9af68b08
@ -135,6 +135,16 @@ class AudioVAE(torch.nn.Module):
|
||||
n_fft=autoencoder_config["n_fft"],
|
||||
)
|
||||
|
||||
@staticmethod
def _module_device_dtype(module: torch.nn.Module) -> tuple[torch.device, torch.dtype]:
    """Return the device and dtype of a module's first floating-point tensor.

    Parameters are searched first; buffers are consulted only when no
    floating-point parameter exists. Non-floating-point tensors are
    skipped.

    Args:
        module: The module whose tensors are inspected.

    Returns:
        A ``(device, dtype)`` pair taken from the first floating-point
        tensor found, or ``(torch.device("cpu"), torch.float32)`` when the
        module holds no floating-point parameters or buffers.
    """
    # Keep the bound methods (not their results) so buffers() is only
    # called when the parameter scan comes up empty.
    for list_tensors in (module.parameters, module.buffers):
        reference = next(
            (candidate for candidate in list_tensors() if candidate.is_floating_point()),
            None,
        )
        if reference is not None:
            return reference.device, reference.dtype
    return torch.device("cpu"), torch.float32
|
||||
|
||||
def encode(self, audio, sample_rate=44100) -> torch.Tensor:
|
||||
"""Encode a waveform dictionary into normalized latent tensors."""
|
||||
|
||||
@ -154,6 +164,8 @@ class AudioVAE(torch.nn.Module):
|
||||
waveform, waveform_sample_rate, device=waveform.device
|
||||
)
|
||||
|
||||
autoencoder_device, autoencoder_dtype = self._module_device_dtype(self.autoencoder)
|
||||
mel_spec = mel_spec.to(device=autoencoder_device, dtype=autoencoder_dtype)
|
||||
latents = self.autoencoder.encode(mel_spec)
|
||||
posterior = DiagonalGaussianDistribution(latents)
|
||||
latent_mode = posterior.mode()
|
||||
@ -165,6 +177,8 @@ class AudioVAE(torch.nn.Module):
|
||||
"""Decode normalized latent tensors into an audio waveform."""
|
||||
original_shape = latents.shape
|
||||
|
||||
autoencoder_device, autoencoder_dtype = self._module_device_dtype(self.autoencoder)
|
||||
latents = latents.to(device=autoencoder_device, dtype=autoencoder_dtype)
|
||||
latents = self.normalizer.denormalize(latents)
|
||||
|
||||
target_shape = self.target_shape_from_latents(original_shape)
|
||||
@ -197,6 +211,8 @@ class AudioVAE(torch.nn.Module):
|
||||
elif audio_channels != 2:
|
||||
raise ValueError(f"Unsupported audio_channels: {audio_channels}")
|
||||
|
||||
vocoder_device, vocoder_dtype = self._module_device_dtype(self.vocoder)
|
||||
vocoder_input = vocoder_input.to(device=vocoder_device, dtype=vocoder_dtype)
|
||||
return self.vocoder(vocoder_input)
|
||||
|
||||
@property
|
||||
|
||||
Loading…
Reference in New Issue
Block a user