diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py
index dd5320c8f..77a951eb4 100644
--- a/comfy/ldm/lightricks/vae/audio_vae.py
+++ b/comfy/ldm/lightricks/vae/audio_vae.py
@@ -135,6 +135,19 @@ class AudioVAE(torch.nn.Module):
             n_fft=autoencoder_config["n_fft"],
         )
 
+    @staticmethod
+    def _module_device_dtype(module: torch.nn.Module) -> tuple[torch.device, torch.dtype]:
+        """Return the device and dtype of the first floating-point tensor in ``module``.
+
+        Parameters are checked before buffers. A module holding no
+        floating-point tensor yields ``(torch.device("cpu"), torch.float32)``.
+        """
+        for tensors in (module.parameters(), module.buffers()):
+            for tensor in tensors:
+                if tensor.is_floating_point():
+                    return tensor.device, tensor.dtype
+        return torch.device("cpu"), torch.float32
+
     def encode(self, audio, sample_rate=44100) -> torch.Tensor:
         """Encode a waveform dictionary into normalized latent tensors."""
 
@@ -154,6 +167,9 @@ class AudioVAE(torch.nn.Module):
             waveform, waveform_sample_rate, device=waveform.device
         )
 
+        # Match the autoencoder's device and dtype before encoding.
+        autoencoder_device, autoencoder_dtype = self._module_device_dtype(self.autoencoder)
+        mel_spec = mel_spec.to(device=autoencoder_device, dtype=autoencoder_dtype)
         latents = self.autoencoder.encode(mel_spec)
         posterior = DiagonalGaussianDistribution(latents)
         latent_mode = posterior.mode()
@@ -165,6 +181,9 @@ class AudioVAE(torch.nn.Module):
         """Decode normalized latent tensors into an audio waveform."""
         original_shape = latents.shape
 
+        # Move the latents to the autoencoder's device and dtype before denormalizing.
+        autoencoder_device, autoencoder_dtype = self._module_device_dtype(self.autoencoder)
+        latents = latents.to(device=autoencoder_device, dtype=autoencoder_dtype)
         latents = self.normalizer.denormalize(latents)
 
         target_shape = self.target_shape_from_latents(original_shape)
@@ -197,6 +216,9 @@ class AudioVAE(torch.nn.Module):
         elif audio_channels != 2:
             raise ValueError(f"Unsupported audio_channels: {audio_channels}")
 
+        # Match the vocoder's device and dtype before running it.
+        vocoder_device, vocoder_dtype = self._module_device_dtype(self.vocoder)
+        vocoder_input = vocoder_input.to(device=vocoder_device, dtype=vocoder_dtype)
         return self.vocoder(vocoder_input)
 
     @property