diff --git a/comfy/ldm/ace/vae/music_dcae_pipeline.py b/comfy/ldm/ace/vae/music_dcae_pipeline.py index af81280eb..3c8830c17 100644 --- a/comfy/ldm/ace/vae/music_dcae_pipeline.py +++ b/comfy/ldm/ace/vae/music_dcae_pipeline.py @@ -23,8 +23,6 @@ class MusicDCAE(torch.nn.Module): else: self.source_sample_rate = source_sample_rate - # self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100) - self.transform = transforms.Compose([ transforms.Normalize(0.5, 0.5), ]) @@ -37,10 +35,6 @@ class MusicDCAE(torch.nn.Module): self.scale_factor = 0.1786 self.shift_factor = -1.9091 - def load_audio(self, audio_path): - audio, sr = torchaudio.load(audio_path) - return audio, sr - def forward_mel(self, audios): mels = [] for i in range(len(audios)): @@ -73,10 +67,8 @@ class MusicDCAE(torch.nn.Module): latent = self.dcae.encoder(mel.unsqueeze(0)) latents.append(latent) latents = torch.cat(latents, dim=0) - # latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long() latents = (latents - self.shift_factor) * self.scale_factor return latents - # return latents, latent_lengths @torch.no_grad() def decode(self, latents, audio_lengths=None, sr=None): @@ -91,9 +83,7 @@ class MusicDCAE(torch.nn.Module): wav = self.vocoder.decode(mels[0]).squeeze(1) if sr is not None: - # resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype) wav = torchaudio.functional.resample(wav, 44100, sr) - # wav = resampler(wav) else: sr = 44100 pred_wavs.append(wav) @@ -101,7 +91,6 @@ class MusicDCAE(torch.nn.Module): if audio_lengths is not None: pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)] return torch.stack(pred_wavs) - # return sr, pred_wavs def forward(self, audios, audio_lengths=None, sr=None): latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)