import json
from dataclasses import dataclass
import math
import torch
import torchaudio

import comfy.model_management
import comfy.model_patcher
import comfy.utils as utils
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
    CausalityAxis,
    CausalAudioAutoencoder,
)
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder

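# Downsampling factor between the mel spectrogram and the VAE latent grid, applied
# on both the time axis (see target_shape_from_latents) and the mel-bin axis
# (see latent_frequency_bins).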
LATENT_DOWNSAMPLE_FACTOR = 4


@dataclass(frozen=True)
class AudioVAEComponentConfig:
    """Container for model component configuration extracted from metadata."""

    autoencoder: dict
    vocoder: dict

    @classmethod
    def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":
        assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE"

        raw_config = metadata["config"]
        if isinstance(raw_config, str):
            parsed_config = json.loads(raw_config)
        else:
            parsed_config = raw_config

        audio_config = parsed_config.get("audio_vae")
        vocoder_config = parsed_config.get("vocoder")

        assert audio_config is not None, "Audio VAE config is required for audio VAE"
        assert vocoder_config is not None, "Vocoder config is required for audio VAE"

        return cls(autoencoder=audio_config, vocoder=vocoder_config)


class ModelDeviceManager:
    """Manages device placement and GPU residency for the composed model."""

    def __init__(self, module: torch.nn.Module):
        load_device = comfy.model_management.get_torch_device()
        offload_device = comfy.model_management.vae_offload_device()
        self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)

    def ensure_model_loaded(self) -> None:
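        # Free enough room on the load device for this model, then let ComfyUI's
        # model management load the patcher onto the GPU.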
        comfy.model_management.free_memory(
            self.patcher.model_size(),
            self.patcher.load_device,
        )
        comfy.model_management.load_model_gpu(self.patcher)

    def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor.to(self.patcher.load_device)

    @property
    def load_device(self):
        return self.patcher.load_device


class AudioLatentNormalizer:
    """Applies per-channel statistics in patch space and restores the original layout."""

    def __init__(self, patchifier: AudioPatchifier, statistics_processor: torch.nn.Module):
        self.patchifier = patchifier
        self.statistics = statistics_processor

    def normalize(self, latents: torch.Tensor) -> torch.Tensor:
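        # Flatten (batch, channels, time, freq) latents into patch tokens, apply the
        # per-channel statistics, then restore the original layout.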
        channels = latents.shape[1]
        freq = latents.shape[3]
        patched, _ = self.patchifier.patchify(latents)
        normalized = self.statistics.normalize(patched)
        return self.patchifier.unpatchify(normalized, channels=channels, freq=freq)

    def denormalize(self, latents: torch.Tensor) -> torch.Tensor:
        channels = latents.shape[1]
        freq = latents.shape[3]
        patched, _ = self.patchifier.patchify(latents)
        denormalized = self.statistics.un_normalize(patched)
        return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq)


class AudioPreprocessor:
    """Prepares raw waveforms for the autoencoder by matching training conditions."""

    def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int):
        self.target_sample_rate = target_sample_rate
        self.mel_bins = mel_bins
        self.mel_hop_length = mel_hop_length
        self.n_fft = n_fft

    def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor:
        if source_rate == self.target_sample_rate:
            return waveform
        return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)

    @staticmethod
    def normalize_amplitude(
        waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
    ) -> torch.Tensor:
        waveform = waveform - waveform.mean(dim=2, keepdim=True)
        peak = torch.max(torch.abs(waveform)) + eps
        scale = peak.clamp(max=max_amplitude) / peak
        return waveform * scale

    def waveform_to_mel(
        self, waveform: torch.Tensor, waveform_sample_rate: int, device
    ) -> torch.Tensor:
        waveform = self.resample(waveform, waveform_sample_rate)
        waveform = self.normalize_amplitude(waveform)

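        # Magnitude mel spectrogram (power=1.0) with Slaney mel scale and norm,
        # configured to match the autoencoder's training conditions.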
        mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.target_sample_rate,
            n_fft=self.n_fft,
            win_length=self.n_fft,
            hop_length=self.mel_hop_length,
            f_min=0.0,
            f_max=self.target_sample_rate / 2.0,
            n_mels=self.mel_bins,
            window_fn=torch.hann_window,
            center=True,
            pad_mode="reflect",
            power=1.0,
            mel_scale="slaney",
            norm="slaney",
        ).to(device)

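        # Log-compress the magnitudes (with a small floor) and reorder to
        # (batch, channels, frames, mel_bins) for the encoder.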
        mel = mel_transform(waveform)
        mel = torch.log(torch.clamp(mel, min=1e-5))
        return mel.permute(0, 1, 3, 2).contiguous()


class AudioVAE(torch.nn.Module):
    """High-level Audio VAE wrapper exposing encode and decode entry points."""

    def __init__(self, state_dict: dict, metadata: dict):
        super().__init__()

        component_config = AudioVAEComponentConfig.from_metadata(metadata)

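        # Split the combined checkpoint: "audio_vae."-prefixed weights go to the
        # autoencoder, "vocoder."-prefixed weights go to the vocoder.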
        vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
        vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)

        self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
        self.vocoder = Vocoder(config=component_config.vocoder)

        self.autoencoder.load_state_dict(vae_sd, strict=False)
        self.vocoder.load_state_dict(vocoder_sd, strict=False)

        autoencoder_config = self.autoencoder.get_config()
        self.normalizer = AudioLatentNormalizer(
            AudioPatchifier(
                patch_size=1,
                audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
                sample_rate=autoencoder_config["sampling_rate"],
                hop_length=autoencoder_config["mel_hop_length"],
                is_causal=autoencoder_config["is_causal"],
            ),
            self.autoencoder.per_channel_statistics,
        )

        self.preprocessor = AudioPreprocessor(
            target_sample_rate=autoencoder_config["sampling_rate"],
            mel_bins=autoencoder_config["mel_bins"],
            mel_hop_length=autoencoder_config["mel_hop_length"],
            n_fft=autoencoder_config["n_fft"],
        )

        self.device_manager = ModelDeviceManager(self)

    def encode(self, audio: dict) -> torch.Tensor:
        """Encode a waveform dictionary into normalized latent tensors."""

        waveform = audio["waveform"]
        waveform_sample_rate = audio["sample_rate"]
        input_device = waveform.device
        # Ensure that Audio VAE is loaded on the correct device.
        self.device_manager.ensure_model_loaded()

        waveform = self.device_manager.move_to_load_device(waveform)
        expected_channels = self.autoencoder.encoder.in_channels
        if waveform.shape[1] != expected_channels:
            raise ValueError(
                f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
            )

        mel_spec = self.preprocessor.waveform_to_mel(
            waveform, waveform_sample_rate, device=self.device_manager.load_device
        )

        latents = self.autoencoder.encode(mel_spec)
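        # The encoder outputs diagonal Gaussian parameters; take the distribution's
        # mode (its mean) as a deterministic latent.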
        posterior = DiagonalGaussianDistribution(latents)
        latent_mode = posterior.mode()

        normalized = self.normalizer.normalize(latent_mode)
        return normalized.to(input_device)

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode normalized latent tensors into an audio waveform."""
        original_shape = latents.shape

        # Ensure that Audio VAE is loaded on the correct device.
        self.device_manager.ensure_model_loaded()

        latents = self.device_manager.move_to_load_device(latents)
        latents = self.normalizer.denormalize(latents)

        target_shape = self.target_shape_from_latents(original_shape)
        mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)

        waveform = self.run_vocoder(mel_spec)
        return self.device_manager.move_to_load_device(waveform)

    def target_shape_from_latents(self, latents_shape):
        batch, _, time, _ = latents_shape
        target_length = time * LATENT_DOWNSAMPLE_FACTOR
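        # Causal models map the first latent frame to a single mel frame rather than
        # a full block, so trim the extra LATENT_DOWNSAMPLE_FACTOR - 1 frames.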
        if self.autoencoder.causality_axis != CausalityAxis.NONE:
            target_length -= LATENT_DOWNSAMPLE_FACTOR - 1
        return (
            batch,
            self.autoencoder.decoder.out_ch,
            target_length,
            self.autoencoder.mel_bins,
        )

    def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int:
        return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second)

    def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor:
        audio_channels = self.autoencoder.decoder.out_ch
        vocoder_input = mel_spec.transpose(2, 3)

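        # Mono audio drops the channel dimension before the vocoder; stereo keeps it,
        # and any other channel count is rejected.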
        if audio_channels == 1:
            vocoder_input = vocoder_input.squeeze(1)
        elif audio_channels != 2:
            raise ValueError(f"Unsupported audio_channels: {audio_channels}")

        return self.vocoder(vocoder_input)

    @property
    def sample_rate(self) -> int:
        return int(self.autoencoder.sampling_rate)

    @property
    def mel_hop_length(self) -> int:
        return int(self.autoencoder.mel_hop_length)

    @property
    def mel_bins(self) -> int:
        return int(self.autoencoder.mel_bins)

    @property
    def latent_channels(self) -> int:
        return int(self.autoencoder.decoder.z_channels)

    @property
    def latent_frequency_bins(self) -> int:
        return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR)

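    # Latent frames per second: mel frames per second (sample_rate / mel_hop_length)
    # divided by the temporal downsampling factor.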
    @property
    def latents_per_second(self) -> float:
        return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR

    @property
    def output_sample_rate(self) -> int:
        output_rate = getattr(self.vocoder, "output_sample_rate", None)
        if output_rate is not None:
            return int(output_rate)
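        # Otherwise derive the rate from the vocoder: mel frames per second
        # (sample_rate / mel_hop_length) times its upsample factor.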
        upsample_factor = getattr(self.vocoder, "upsample_factor", None)
        if upsample_factor is None:
            raise AttributeError(
                "Vocoder is missing upsample_factor; cannot infer output sample rate"
            )
        return int(self.sample_rate * upsample_factor / self.mel_hop_length)

    def memory_required(self, input_shape):
        return self.device_manager.patcher.model_size()