import json
import math
from dataclasses import dataclass

import torch
import torchaudio

import comfy.model_management
import comfy.model_patcher
import comfy.utils as utils
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
    CausalityAxis,
    CausalAudioAutoencoder,
)
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder

# Each latent step covers this many mel frames (and mel bins) after encoding.
LATENT_DOWNSAMPLE_FACTOR = 4


@dataclass(frozen=True)
class AudioVAEComponentConfig:
    """Container for model component configuration extracted from metadata."""

    autoencoder: dict
    vocoder: dict

    @classmethod
    def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":
        assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE"
        raw_config = metadata["config"]
        # The config may be stored as a JSON string (e.g. safetensors metadata) or as a dict.
        if isinstance(raw_config, str):
            parsed_config = json.loads(raw_config)
        else:
            parsed_config = raw_config
        audio_config = parsed_config.get("audio_vae")
        vocoder_config = parsed_config.get("vocoder")
        assert audio_config is not None, "Audio VAE config is required for audio VAE"
        assert vocoder_config is not None, "Vocoder config is required for audio VAE"
        return cls(autoencoder=audio_config, vocoder=vocoder_config)


class ModelDeviceManager:
    """Manages device placement and GPU residency for the composed model."""

    def __init__(self, module: torch.nn.Module):
        load_device = comfy.model_management.get_torch_device()
        offload_device = comfy.model_management.vae_offload_device()
        self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)

    def ensure_model_loaded(self) -> None:
        # Free enough memory for this model, then let the patcher place it on the GPU.
        comfy.model_management.free_memory(
            self.patcher.model_size(),
            self.patcher.load_device,
        )
        comfy.model_management.load_model_gpu(self.patcher)

    def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor.to(self.patcher.load_device)

    @property
    def load_device(self):
        return self.patcher.load_device


class AudioLatentNormalizer:
    """Applies per-channel statistics in patch space and restores original layout."""

    def __init__(self, patchifier: AudioPatchifier, statistics_processor: torch.nn.Module):
        self.patchifier = patchifier
        self.statistics = statistics_processor

    def normalize(self, latents: torch.Tensor) -> torch.Tensor:
        channels = latents.shape[1]
        freq = latents.shape[3]
        patched, _ = self.patchifier.patchify(latents)
        normalized = self.statistics.normalize(patched)
        return self.patchifier.unpatchify(normalized, channels=channels, freq=freq)

    def denormalize(self, latents: torch.Tensor) -> torch.Tensor:
        channels = latents.shape[1]
        freq = latents.shape[3]
        patched, _ = self.patchifier.patchify(latents)
        denormalized = self.statistics.un_normalize(patched)
        return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq)
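
# Illustrative round trip (a sketch, not exercised by this module): normalize()
# and denormalize() are intended to be mutually inverse up to the stored
# per-channel statistics. The shapes below are hypothetical [B, C, T, F] latents.
#
#   normalizer = AudioLatentNormalizer(patchifier, per_channel_statistics)
#   latents = torch.randn(1, 8, 32, 16)
#   restored = normalizer.denormalize(normalizer.normalize(latents))
#   torch.testing.assert_close(restored, latents)
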
class AudioPreprocessor:
    """Prepares raw waveforms for the autoencoder by matching training conditions."""

    def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int):
        self.target_sample_rate = target_sample_rate
        self.mel_bins = mel_bins
        self.mel_hop_length = mel_hop_length
        self.n_fft = n_fft

    def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor:
        if source_rate == self.target_sample_rate:
            return waveform
        return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)

    @staticmethod
    def normalize_amplitude(
        waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
    ) -> torch.Tensor:
        # Remove DC offset, then attenuate (never amplify) so the peak is at most max_amplitude.
        waveform = waveform - waveform.mean(dim=2, keepdim=True)
        peak = torch.max(torch.abs(waveform)) + eps
        scale = peak.clamp(max=max_amplitude) / peak
        return waveform * scale

    def waveform_to_mel(
        self, waveform: torch.Tensor, waveform_sample_rate: int, device
    ) -> torch.Tensor:
        waveform = self.resample(waveform, waveform_sample_rate)
        waveform = self.normalize_amplitude(waveform)
        mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.target_sample_rate,
            n_fft=self.n_fft,
            win_length=self.n_fft,
            hop_length=self.mel_hop_length,
            f_min=0.0,
            f_max=self.target_sample_rate / 2.0,
            n_mels=self.mel_bins,
            window_fn=torch.hann_window,
            center=True,
            pad_mode="reflect",
            power=1.0,
            mel_scale="slaney",
            norm="slaney",
        ).to(device)
        mel = mel_transform(waveform)
        # Log-mel with a floor to avoid log(0).
        mel = torch.log(torch.clamp(mel, min=1e-5))
        # [B, C, n_mels, frames] -> [B, C, frames, n_mels]
        return mel.permute(0, 1, 3, 2).contiguous()


class AudioVAE(torch.nn.Module):
    """High-level Audio VAE wrapper exposing encode and decode entry points."""

    def __init__(self, state_dict: dict, metadata: dict):
        super().__init__()
        component_config = AudioVAEComponentConfig.from_metadata(metadata)

        vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
        vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)

        self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
        self.vocoder = Vocoder(config=component_config.vocoder)
        self.autoencoder.load_state_dict(vae_sd, strict=False)
        self.vocoder.load_state_dict(vocoder_sd, strict=False)

        autoencoder_config = self.autoencoder.get_config()
        self.normalizer = AudioLatentNormalizer(
            AudioPatchifier(
                patch_size=1,
                audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
                sample_rate=autoencoder_config["sampling_rate"],
                hop_length=autoencoder_config["mel_hop_length"],
                is_causal=autoencoder_config["is_causal"],
            ),
            self.autoencoder.per_channel_statistics,
        )
        self.preprocessor = AudioPreprocessor(
            target_sample_rate=autoencoder_config["sampling_rate"],
            mel_bins=autoencoder_config["mel_bins"],
            mel_hop_length=autoencoder_config["mel_hop_length"],
            n_fft=autoencoder_config["n_fft"],
        )
        self.device_manager = ModelDeviceManager(self)

    def encode(self, audio: dict) -> torch.Tensor:
        """Encode a waveform dictionary into normalized latent tensors."""
        waveform = audio["waveform"]
        waveform_sample_rate = audio["sample_rate"]
        input_device = waveform.device

        # Ensure that Audio VAE is loaded on the correct device.
        self.device_manager.ensure_model_loaded()
        waveform = self.device_manager.move_to_load_device(waveform)

        expected_channels = self.autoencoder.encoder.in_channels
        if waveform.shape[1] != expected_channels:
            raise ValueError(
                f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
            )

        mel_spec = self.preprocessor.waveform_to_mel(
            waveform, waveform_sample_rate, device=self.device_manager.load_device
        )
        latents = self.autoencoder.encode(mel_spec)
        # Use the deterministic mode of the posterior rather than sampling from it.
        posterior = DiagonalGaussianDistribution(latents)
        latent_mode = posterior.mode()
        normalized = self.normalizer.normalize(latent_mode)
        return normalized.to(input_device)
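
    # Illustrative usage of encode() (a sketch with hypothetical shapes; a mono
    # checkpoint with encoder.in_channels == 1 is assumed):
    #
    #   vae = AudioVAE(state_dict, metadata)
    #   audio = {"waveform": torch.zeros(1, 1, 44100), "sample_rate": 44100}
    #   latents = vae.encode(audio)  # roughly [1, latent_channels, T, mel_bins // 4]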

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode normalized latent tensors into an audio waveform."""
        input_device = latents.device
        original_shape = latents.shape

        # Ensure that Audio VAE is loaded on the correct device.
        self.device_manager.ensure_model_loaded()
        latents = self.device_manager.move_to_load_device(latents)

        latents = self.normalizer.denormalize(latents)
        target_shape = self.target_shape_from_latents(original_shape)
        mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
        waveform = self.run_vocoder(mel_spec)
        # Mirror encode() by returning the result on the caller's original device.
        return waveform.to(input_device)

    def target_shape_from_latents(self, latents_shape):
        batch, _, time, _ = latents_shape
        target_length = time * LATENT_DOWNSAMPLE_FACTOR
        # Causal upsampling yields (time - 1) * LATENT_DOWNSAMPLE_FACTOR + 1 frames,
        # i.e. LATENT_DOWNSAMPLE_FACTOR - 1 fewer than the non-causal case.
        if self.autoencoder.causality_axis != CausalityAxis.NONE:
            target_length -= LATENT_DOWNSAMPLE_FACTOR - 1
        return (
            batch,
            self.autoencoder.decoder.out_ch,
            target_length,
            self.autoencoder.mel_bins,
        )

    def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int:
        return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second)

    def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor:
        audio_channels = self.autoencoder.decoder.out_ch
        # [B, C, frames, n_mels] -> [B, C, n_mels, frames] as expected by the vocoder.
        vocoder_input = mel_spec.transpose(2, 3)
        if audio_channels == 1:
            vocoder_input = vocoder_input.squeeze(1)
        elif audio_channels != 2:
            raise ValueError(f"Unsupported audio_channels: {audio_channels}")
        return self.vocoder(vocoder_input)

    @property
    def sample_rate(self) -> int:
        return int(self.autoencoder.sampling_rate)

    @property
    def mel_hop_length(self) -> int:
        return int(self.autoencoder.mel_hop_length)

    @property
    def mel_bins(self) -> int:
        return int(self.autoencoder.mel_bins)

    @property
    def latent_channels(self) -> int:
        return int(self.autoencoder.decoder.z_channels)

    @property
    def latent_frequency_bins(self) -> int:
        return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR)

    @property
    def latents_per_second(self) -> float:
        return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR

    @property
    def output_sample_rate(self) -> int:
        output_rate = getattr(self.vocoder, "output_sample_rate", None)
        if output_rate is not None:
            return int(output_rate)
        # Fall back to inferring the rate from the vocoder's upsampling ratio.
        upsample_factor = getattr(self.vocoder, "upsample_factor", None)
        if upsample_factor is None:
            raise AttributeError(
                "Vocoder is missing upsample_factor; cannot infer output sample rate"
            )
        return int(self.sample_rate * upsample_factor / self.mel_hop_length)

    def memory_required(self, input_shape):
        return self.device_manager.patcher.model_size()
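
# Example wiring (a sketch; the checkpoint filename, its key prefixes, and the
# return_metadata behaviour of comfy.utils.load_torch_file are assumptions):
#
#   sd, metadata = utils.load_torch_file("ltx_audio_vae.safetensors", return_metadata=True)
#   vae = AudioVAE(sd, metadata)
#   latents = vae.encode({"waveform": waveform, "sample_rate": 44100})
#   restored = vae.decode(latents)  # waveform at vae.output_sample_rate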