For LTX Audio VAE, remove normalization of audio during mel spectrogram creation. This aligns inference with training and prevents loud audio from being attenuated.
import json
from dataclasses import dataclass
import math
import torch
import torchaudio

import comfy.model_management
import comfy.model_patcher
import comfy.utils as utils
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
    CausalityAxis,
    CausalAudioAutoencoder,
)
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder

LATENT_DOWNSAMPLE_FACTOR = 4
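# Each latent time step covers LATENT_DOWNSAMPLE_FACTOR mel frames, so the latent
# sequence runs at 1/4 of the mel spectrogram's temporal resolution (see
# latents_per_second and target_shape_from_latents below).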


@dataclass(frozen=True)
class AudioVAEComponentConfig:
    """Container for model component configuration extracted from metadata."""

    autoencoder: dict
    vocoder: dict

    @classmethod
    def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":
        assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE"

        raw_config = metadata["config"]
        if isinstance(raw_config, str):
            parsed_config = json.loads(raw_config)
        else:
            parsed_config = raw_config

        audio_config = parsed_config.get("audio_vae")
        vocoder_config = parsed_config.get("vocoder")

        assert audio_config is not None, "Audio VAE config is required for audio VAE"
        assert vocoder_config is not None, "Vocoder config is required for audio VAE"

        return cls(autoencoder=audio_config, vocoder=vocoder_config)
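    # Rough sketch of the metadata this expects (illustrative; the concrete keys inside
    # "audio_vae" and "vocoder" depend on the checkpoint and are not reproduced here):
    #
    #     {"config": '{"audio_vae": {...}, "vocoder": {...}}'}
    #
    # "config" may also already be a parsed dict rather than a JSON string.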


class ModelDeviceManager:
    """Manages device placement and GPU residency for the composed model."""

    def __init__(self, module: torch.nn.Module):
        load_device = comfy.model_management.get_torch_device()
        offload_device = comfy.model_management.vae_offload_device()
        self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)

    def ensure_model_loaded(self) -> None:
        comfy.model_management.free_memory(
            self.patcher.model_size(),
            self.patcher.load_device,
        )
        comfy.model_management.load_model_gpu(self.patcher)

    def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor.to(self.patcher.load_device)

    @property
    def load_device(self):
        return self.patcher.load_device
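    # The ModelPatcher parks the module on the VAE offload device and only moves it to
    # the load device when ensure_model_loaded() is called, which first frees enough
    # memory for the model and then loads it through ComfyUI's model management.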


class AudioLatentNormalizer:
    """Applies per-channel statistics in patch space and restores original layout."""

    def __init__(self, patchifier: AudioPatchifier, statistics_processor: torch.nn.Module):
        self.patchifier = patchifier
        self.statistics = statistics_processor

    def normalize(self, latents: torch.Tensor) -> torch.Tensor:
        channels = latents.shape[1]
        freq = latents.shape[3]
        patched, _ = self.patchifier.patchify(latents)
        normalized = self.statistics.normalize(patched)
        return self.patchifier.unpatchify(normalized, channels=channels, freq=freq)

    def denormalize(self, latents: torch.Tensor) -> torch.Tensor:
        channels = latents.shape[1]
        freq = latents.shape[3]
        patched, _ = self.patchifier.patchify(latents)
        denormalized = self.statistics.un_normalize(patched)
        return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq)
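    # Both directions use the same round trip: patchify the (batch, channels, time, freq)
    # latents into the layout the per-channel statistics were computed in, apply
    # normalize()/un_normalize(), then unpatchify back to the original channel and
    # frequency dimensions.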


class AudioPreprocessor:
    """Prepares raw waveforms for the autoencoder by matching training conditions."""

    def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int):
        self.target_sample_rate = target_sample_rate
        self.mel_bins = mel_bins
        self.mel_hop_length = mel_hop_length
        self.n_fft = n_fft

    def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor:
        if source_rate == self.target_sample_rate:
            return waveform
        return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)

    def waveform_to_mel(
        self, waveform: torch.Tensor, waveform_sample_rate: int, device
    ) -> torch.Tensor:
        waveform = self.resample(waveform, waveform_sample_rate)

        mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.target_sample_rate,
            n_fft=self.n_fft,
            win_length=self.n_fft,
            hop_length=self.mel_hop_length,
            f_min=0.0,
            f_max=self.target_sample_rate / 2.0,
            n_mels=self.mel_bins,
            window_fn=torch.hann_window,
            center=True,
            pad_mode="reflect",
            power=1.0,
            mel_scale="slaney",
            norm="slaney",
        ).to(device)

        mel = mel_transform(waveform)
        mel = torch.log(torch.clamp(mel, min=1e-5))
        return mel.permute(0, 1, 3, 2).contiguous()
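    # The magnitude mel spectrogram (power=1.0) is clamped to 1e-5 before the log so
    # silent frames do not become -inf, and no loudness normalization is applied to the
    # waveform, matching training (see the note at the top of this file). The final
    # permute yields (batch, channels, time, mel_bins).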


class AudioVAE(torch.nn.Module):
    """High-level Audio VAE wrapper exposing encode and decode entry points."""

    def __init__(self, state_dict: dict, metadata: dict):
        super().__init__()

        component_config = AudioVAEComponentConfig.from_metadata(metadata)

        vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
        vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)

        self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
        self.vocoder = Vocoder(config=component_config.vocoder)

        self.autoencoder.load_state_dict(vae_sd, strict=False)
        self.vocoder.load_state_dict(vocoder_sd, strict=False)

        autoencoder_config = self.autoencoder.get_config()
        self.normalizer = AudioLatentNormalizer(
            AudioPatchifier(
                patch_size=1,
                audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
                sample_rate=autoencoder_config["sampling_rate"],
                hop_length=autoencoder_config["mel_hop_length"],
                is_causal=autoencoder_config["is_causal"],
            ),
            self.autoencoder.per_channel_statistics,
        )

        self.preprocessor = AudioPreprocessor(
            target_sample_rate=autoencoder_config["sampling_rate"],
            mel_bins=autoencoder_config["mel_bins"],
            mel_hop_length=autoencoder_config["mel_hop_length"],
            n_fft=autoencoder_config["n_fft"],
        )

        self.device_manager = ModelDeviceManager(self)

    def encode(self, audio: dict) -> torch.Tensor:
        """Encode a waveform dictionary into normalized latent tensors."""

        waveform = audio["waveform"]
        waveform_sample_rate = audio["sample_rate"]
        input_device = waveform.device
        # Ensure that Audio VAE is loaded on the correct device.
        self.device_manager.ensure_model_loaded()

        waveform = self.device_manager.move_to_load_device(waveform)
        expected_channels = self.autoencoder.encoder.in_channels
        if waveform.shape[1] != expected_channels:
            if waveform.shape[1] == 1:
                waveform = waveform.expand(-1, expected_channels, *waveform.shape[2:])
            else:
                raise ValueError(
                    f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
                )

        mel_spec = self.preprocessor.waveform_to_mel(
            waveform, waveform_sample_rate, device=self.device_manager.load_device
        )

        latents = self.autoencoder.encode(mel_spec)
        posterior = DiagonalGaussianDistribution(latents)
        latent_mode = posterior.mode()
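        # mode() of a diagonal Gaussian is its mean, so encoding is deterministic here
        # rather than sampled from the posterior.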

        normalized = self.normalizer.normalize(latent_mode)
        return normalized.to(input_device)

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode normalized latent tensors into an audio waveform."""
        original_shape = latents.shape

        # Ensure that Audio VAE is loaded on the correct device.
        self.device_manager.ensure_model_loaded()

        latents = self.device_manager.move_to_load_device(latents)
        latents = self.normalizer.denormalize(latents)

        target_shape = self.target_shape_from_latents(original_shape)
        mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)

        waveform = self.run_vocoder(mel_spec)
        return self.device_manager.move_to_load_device(waveform)

    def target_shape_from_latents(self, latents_shape):
        batch, _, time, _ = latents_shape
        target_length = time * LATENT_DOWNSAMPLE_FACTOR
        if self.autoencoder.causality_axis != CausalityAxis.NONE:
            target_length -= LATENT_DOWNSAMPLE_FACTOR - 1
        return (
            batch,
            self.autoencoder.decoder.out_ch,
            target_length,
            self.autoencoder.mel_bins,
        )
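    # Worked example (illustrative numbers): with a causal autoencoder, 10 latent time
    # steps decode to 10 * 4 - (4 - 1) = 37 mel frames; without causality it would be
    # 10 * 4 = 40.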

    def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int:
        return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second)
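    # Illustrative example (assumed values, not taken from any particular checkpoint):
    # with sampling_rate=44100 and mel_hop_length=512, latents_per_second is
    # 44100 / 512 / 4 ≈ 21.53, so 121 frames at 24 fps (≈ 5.04 s) require
    # ceil(5.0417 * 21.533) = 109 latents.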

    def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor:
        audio_channels = self.autoencoder.decoder.out_ch
        vocoder_input = mel_spec.transpose(2, 3)

        if audio_channels == 1:
            vocoder_input = vocoder_input.squeeze(1)
        elif audio_channels != 2:
            raise ValueError(f"Unsupported audio_channels: {audio_channels}")

        return self.vocoder(vocoder_input)
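    # transpose(2, 3) swaps the time and mel axes so the vocoder receives
    # (batch, channels, mel_bins, time); mono decoders get the channel dimension
    # squeezed away, and only mono and stereo are supported.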

    @property
    def sample_rate(self) -> int:
        return int(self.autoencoder.sampling_rate)

    @property
    def mel_hop_length(self) -> int:
        return int(self.autoencoder.mel_hop_length)

    @property
    def mel_bins(self) -> int:
        return int(self.autoencoder.mel_bins)

    @property
    def latent_channels(self) -> int:
        return int(self.autoencoder.decoder.z_channels)

    @property
    def latent_frequency_bins(self) -> int:
        return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR)

    @property
    def latents_per_second(self) -> float:
        return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR

    @property
    def output_sample_rate(self) -> int:
        output_rate = getattr(self.vocoder, "output_sample_rate", None)
        if output_rate is not None:
            return int(output_rate)
        upsample_factor = getattr(self.vocoder, "upsample_factor", None)
        if upsample_factor is None:
            raise AttributeError(
                "Vocoder is missing upsample_factor; cannot infer output sample rate"
            )
        return int(self.sample_rate * upsample_factor / self.mel_hop_length)
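    # Fallback example (assumed values): sample_rate=44100, upsample_factor=1024 and
    # mel_hop_length=512 give int(44100 * 1024 / 512) = 88200 Hz when the vocoder does
    # not expose output_sample_rate directly.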

    def memory_required(self, input_shape):
        return self.device_manager.patcher.model_size()
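
# Minimal usage sketch (illustrative; the variable names and checkpoint loading are
# assumptions, not part of this module):
#
#     vae = AudioVAE(state_dict, metadata)
#     latents = vae.encode({"waveform": waveform, "sample_rate": 44100})
#     restored = vae.decode(latents)  # waveform at vae.output_sample_rate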