From bfc70a77356bacf33ad8e49ac3d783db43cafddc Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 20 Apr 2026 22:57:07 -0400 Subject: [PATCH] Make the ltx audio vae more native. --- comfy/ldm/lightricks/vae/audio_vae.py | 55 +++------------------------ comfy/sd.py | 18 +++++++++ comfy_extras/nodes_audio.py | 2 +- comfy_extras/nodes_lt_audio.py | 36 ++++++++---------- 4 files changed, 41 insertions(+), 70 deletions(-) diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index fa0a00748..dd5320c8f 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -4,9 +4,6 @@ import math import torch import torchaudio -import comfy.model_management -import comfy.model_patcher -import comfy.utils as utils from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( @@ -43,30 +40,6 @@ class AudioVAEComponentConfig: return cls(autoencoder=audio_config, vocoder=vocoder_config) - -class ModelDeviceManager: - """Manages device placement and GPU residency for the composed model.""" - - def __init__(self, module: torch.nn.Module): - load_device = comfy.model_management.get_torch_device() - offload_device = comfy.model_management.vae_offload_device() - self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device) - - def ensure_model_loaded(self) -> None: - comfy.model_management.free_memory( - self.patcher.model_size(), - self.patcher.load_device, - ) - comfy.model_management.load_model_gpu(self.patcher) - - def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor: - return tensor.to(self.patcher.load_device) - - @property - def load_device(self): - return self.patcher.load_device - - class AudioLatentNormalizer: """Applies per-channel statistics in patch space and restores original layout.""" @@ -132,23 +105,17 @@ class AudioPreprocessor: class AudioVAE(torch.nn.Module): """High-level Audio VAE wrapper exposing encode and decode entry points.""" - def __init__(self, state_dict: dict, metadata: dict): + def __init__(self, metadata: dict): super().__init__() component_config = AudioVAEComponentConfig.from_metadata(metadata) - vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True) - vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) - self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) if "bwe" in component_config.vocoder: self.vocoder = VocoderWithBWE(config=component_config.vocoder) else: self.vocoder = Vocoder(config=component_config.vocoder) - self.autoencoder.load_state_dict(vae_sd, strict=False) - self.vocoder.load_state_dict(vocoder_sd, strict=False) - autoencoder_config = self.autoencoder.get_config() self.normalizer = AudioLatentNormalizer( AudioPatchifier( @@ -168,18 +135,12 @@ class AudioVAE(torch.nn.Module): n_fft=autoencoder_config["n_fft"], ) - self.device_manager = ModelDeviceManager(self) - - def encode(self, audio: dict) -> torch.Tensor: + def encode(self, audio, sample_rate=44100) -> torch.Tensor: """Encode a waveform dictionary into normalized latent tensors.""" - waveform = audio["waveform"] - waveform_sample_rate = audio["sample_rate"] + waveform = audio + waveform_sample_rate = sample_rate input_device = waveform.device - # Ensure that Audio VAE is loaded on the correct device. - self.device_manager.ensure_model_loaded() - - waveform = self.device_manager.move_to_load_device(waveform) expected_channels = self.autoencoder.encoder.in_channels if waveform.shape[1] != expected_channels: if waveform.shape[1] == 1: @@ -190,7 +151,7 @@ class AudioVAE(torch.nn.Module): ) mel_spec = self.preprocessor.waveform_to_mel( - waveform, waveform_sample_rate, device=self.device_manager.load_device + waveform, waveform_sample_rate, device=waveform.device ) latents = self.autoencoder.encode(mel_spec) @@ -204,17 +165,13 @@ class AudioVAE(torch.nn.Module): """Decode normalized latent tensors into an audio waveform.""" original_shape = latents.shape - # Ensure that Audio VAE is loaded on the correct device. - self.device_manager.ensure_model_loaded() - - latents = self.device_manager.move_to_load_device(latents) latents = self.normalizer.denormalize(latents) target_shape = self.target_shape_from_latents(original_shape) mel_spec = self.autoencoder.decode(latents, target_shape=target_shape) waveform = self.run_vocoder(mel_spec) - return self.device_manager.move_to_load_device(waveform) + return waveform def target_shape_from_latents(self, latents_shape): batch, _, time, _ = latents_shape diff --git a/comfy/sd.py b/comfy/sd.py index e573804a5..a4d3ee269 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -12,6 +12,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder from .ldm.audio.autoencoder import AudioOobleckVAE import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder +import comfy.ldm.lightricks.vae.audio_vae import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 @@ -805,6 +806,23 @@ class VAE: self.downscale_index_formula = (4, 8, 8) self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)) self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) + elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio + self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata) + self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype) + self.latent_channels = self.first_stage_model.latent_channels + self.audio_sample_rate_output = self.first_stage_model.output_sample_rate + self.autoencoder = self.first_stage_model.autoencoder # TODO: remove hack for ltxv custom nodes + self.output_channels = 2 + self.pad_channel_value = "replicate" + self.upscale_ratio = 4096 + self.downscale_ratio = 4096 + self.latent_dim = 2 + self.process_output = lambda audio: audio + self.process_input = lambda audio: audio + self.working_dtypes = [torch.float32] + self.disable_offload = True + self.extra_1d_channel = 16 else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index a395392d8..5f514716f 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -104,7 +104,7 @@ def vae_decode_audio(vae, samples, tile=None, overlap=None): std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100)) return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]} diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 3e4222264..3ec635c75 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -3,9 +3,8 @@ import comfy.utils import comfy.model_management import torch -from comfy.ldm.lightricks.vae.audio_vae import AudioVAE from comfy_api.latest import ComfyExtension, io - +from comfy_extras.nodes_audio import VAEEncodeAudio class LTXVAudioVAELoader(io.ComfyNode): @classmethod @@ -28,10 +27,14 @@ class LTXVAudioVAELoader(io.ComfyNode): def execute(cls, ckpt_name: str) -> io.NodeOutput: ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) - return io.NodeOutput(AudioVAE(sd, metadata)) + sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder.", "vocoder.": "vocoder."}, filter_keys=True) + vae = comfy.sd.VAE(sd=sd, metadata=metadata) + vae.throw_exception_if_invalid() + + return io.NodeOutput(vae) -class LTXVAudioVAEEncode(io.ComfyNode): +class LTXVAudioVAEEncode(VAEEncodeAudio): @classmethod def define_schema(cls) -> io.Schema: return io.Schema( @@ -50,15 +53,8 @@ class LTXVAudioVAEEncode(io.ComfyNode): ) @classmethod - def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput: - audio_latents = audio_vae.encode(audio) - return io.NodeOutput( - { - "samples": audio_latents, - "sample_rate": int(audio_vae.sample_rate), - "type": "audio", - } - ) + def execute(cls, audio, audio_vae) -> io.NodeOutput: + return super().execute(audio_vae, audio) class LTXVAudioVAEDecode(io.ComfyNode): @@ -80,12 +76,12 @@ class LTXVAudioVAEDecode(io.ComfyNode): ) @classmethod - def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput: + def execute(cls, samples, audio_vae) -> io.NodeOutput: audio_latent = samples["samples"] if audio_latent.is_nested: audio_latent = audio_latent.unbind()[-1] - audio = audio_vae.decode(audio_latent).to(audio_latent.device) - output_audio_sample_rate = audio_vae.output_sample_rate + audio = audio_vae.decode(audio_latent).movedim(-1, 1).to(audio_latent.device) + output_audio_sample_rate = audio_vae.first_stage_model.output_sample_rate return io.NodeOutput( { "waveform": audio, @@ -143,17 +139,17 @@ class LTXVEmptyLatentAudio(io.ComfyNode): frames_number: int, frame_rate: int, batch_size: int, - audio_vae: AudioVAE, + audio_vae, ) -> io.NodeOutput: """Generate empty audio latents matching the reference pipeline structure.""" assert audio_vae is not None, "Audio VAE model is required" z_channels = audio_vae.latent_channels - audio_freq = audio_vae.latent_frequency_bins - sampling_rate = int(audio_vae.sample_rate) + audio_freq = audio_vae.first_stage_model.latent_frequency_bins + sampling_rate = int(audio_vae.first_stage_model.sample_rate) - num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate) + num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate) audio_latents = torch.zeros( (batch_size, z_channels, num_audio_latents, audio_freq),