Make the ltx audio vae more native.
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled

This commit is contained in:
comfyanonymous 2026-04-20 22:57:07 -04:00
parent c514890325
commit bfc70a7735
4 changed files with 41 additions and 70 deletions

View File

@ -4,9 +4,6 @@ import math
import torch import torch
import torchaudio import torchaudio
import comfy.model_management
import comfy.model_patcher
import comfy.utils as utils
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
@ -43,30 +40,6 @@ class AudioVAEComponentConfig:
return cls(autoencoder=audio_config, vocoder=vocoder_config) return cls(autoencoder=audio_config, vocoder=vocoder_config)
class ModelDeviceManager:
"""Manages device placement and GPU residency for the composed model."""
def __init__(self, module: torch.nn.Module):
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.vae_offload_device()
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
def ensure_model_loaded(self) -> None:
comfy.model_management.free_memory(
self.patcher.model_size(),
self.patcher.load_device,
)
comfy.model_management.load_model_gpu(self.patcher)
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self.patcher.load_device)
@property
def load_device(self):
return self.patcher.load_device
class AudioLatentNormalizer: class AudioLatentNormalizer:
"""Applies per-channel statistics in patch space and restores original layout.""" """Applies per-channel statistics in patch space and restores original layout."""
@ -132,23 +105,17 @@ class AudioPreprocessor:
class AudioVAE(torch.nn.Module): class AudioVAE(torch.nn.Module):
"""High-level Audio VAE wrapper exposing encode and decode entry points.""" """High-level Audio VAE wrapper exposing encode and decode entry points."""
def __init__(self, state_dict: dict, metadata: dict): def __init__(self, metadata: dict):
super().__init__() super().__init__()
component_config = AudioVAEComponentConfig.from_metadata(metadata) component_config = AudioVAEComponentConfig.from_metadata(metadata)
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
if "bwe" in component_config.vocoder: if "bwe" in component_config.vocoder:
self.vocoder = VocoderWithBWE(config=component_config.vocoder) self.vocoder = VocoderWithBWE(config=component_config.vocoder)
else: else:
self.vocoder = Vocoder(config=component_config.vocoder) self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
autoencoder_config = self.autoencoder.get_config() autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer( self.normalizer = AudioLatentNormalizer(
AudioPatchifier( AudioPatchifier(
@ -168,18 +135,12 @@ class AudioVAE(torch.nn.Module):
n_fft=autoencoder_config["n_fft"], n_fft=autoencoder_config["n_fft"],
) )
self.device_manager = ModelDeviceManager(self) def encode(self, audio, sample_rate=44100) -> torch.Tensor:
def encode(self, audio: dict) -> torch.Tensor:
"""Encode a waveform dictionary into normalized latent tensors.""" """Encode a waveform dictionary into normalized latent tensors."""
waveform = audio["waveform"] waveform = audio
waveform_sample_rate = audio["sample_rate"] waveform_sample_rate = sample_rate
input_device = waveform.device input_device = waveform.device
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
waveform = self.device_manager.move_to_load_device(waveform)
expected_channels = self.autoencoder.encoder.in_channels expected_channels = self.autoencoder.encoder.in_channels
if waveform.shape[1] != expected_channels: if waveform.shape[1] != expected_channels:
if waveform.shape[1] == 1: if waveform.shape[1] == 1:
@ -190,7 +151,7 @@ class AudioVAE(torch.nn.Module):
) )
mel_spec = self.preprocessor.waveform_to_mel( mel_spec = self.preprocessor.waveform_to_mel(
waveform, waveform_sample_rate, device=self.device_manager.load_device waveform, waveform_sample_rate, device=waveform.device
) )
latents = self.autoencoder.encode(mel_spec) latents = self.autoencoder.encode(mel_spec)
@ -204,17 +165,13 @@ class AudioVAE(torch.nn.Module):
"""Decode normalized latent tensors into an audio waveform.""" """Decode normalized latent tensors into an audio waveform."""
original_shape = latents.shape original_shape = latents.shape
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
latents = self.device_manager.move_to_load_device(latents)
latents = self.normalizer.denormalize(latents) latents = self.normalizer.denormalize(latents)
target_shape = self.target_shape_from_latents(original_shape) target_shape = self.target_shape_from_latents(original_shape)
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape) mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
waveform = self.run_vocoder(mel_spec) waveform = self.run_vocoder(mel_spec)
return self.device_manager.move_to_load_device(waveform) return waveform
def target_shape_from_latents(self, latents_shape): def target_shape_from_latents(self, latents_shape):
batch, _, time, _ = latents_shape batch, _, time, _ = latents_shape

View File

@ -12,6 +12,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.lightricks.vae.audio_vae
import comfy.ldm.cosmos.vae import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2 import comfy.ldm.wan.vae2_2
@ -805,6 +806,23 @@ class VAE:
self.downscale_index_formula = (4, 8, 8) self.downscale_index_formula = (4, 8, 8)
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)) self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio
self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = self.first_stage_model.latent_channels
self.audio_sample_rate_output = self.first_stage_model.output_sample_rate
self.autoencoder = self.first_stage_model.autoencoder # TODO: remove hack for ltxv custom nodes
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 2
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
else: else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.") logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None self.first_stage_model = None

View File

@ -104,7 +104,7 @@ def vae_decode_audio(vae, samples, tile=None, overlap=None):
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
std[std < 1.0] = 1.0 std[std < 1.0] = 1.0
audio /= std audio /= std
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100))
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]} return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}

View File

@ -3,9 +3,8 @@ import comfy.utils
import comfy.model_management import comfy.model_management
import torch import torch
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_audio import VAEEncodeAudio
class LTXVAudioVAELoader(io.ComfyNode): class LTXVAudioVAELoader(io.ComfyNode):
@classmethod @classmethod
@ -28,10 +27,14 @@ class LTXVAudioVAELoader(io.ComfyNode):
def execute(cls, ckpt_name: str) -> io.NodeOutput: def execute(cls, ckpt_name: str) -> io.NodeOutput:
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
return io.NodeOutput(AudioVAE(sd, metadata)) sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder.", "vocoder.": "vocoder."}, filter_keys=True)
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
vae.throw_exception_if_invalid()
return io.NodeOutput(vae)
class LTXVAudioVAEEncode(io.ComfyNode): class LTXVAudioVAEEncode(VAEEncodeAudio):
@classmethod @classmethod
def define_schema(cls) -> io.Schema: def define_schema(cls) -> io.Schema:
return io.Schema( return io.Schema(
@ -50,15 +53,8 @@ class LTXVAudioVAEEncode(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput: def execute(cls, audio, audio_vae) -> io.NodeOutput:
audio_latents = audio_vae.encode(audio) return super().execute(audio_vae, audio)
return io.NodeOutput(
{
"samples": audio_latents,
"sample_rate": int(audio_vae.sample_rate),
"type": "audio",
}
)
class LTXVAudioVAEDecode(io.ComfyNode): class LTXVAudioVAEDecode(io.ComfyNode):
@ -80,12 +76,12 @@ class LTXVAudioVAEDecode(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput: def execute(cls, samples, audio_vae) -> io.NodeOutput:
audio_latent = samples["samples"] audio_latent = samples["samples"]
if audio_latent.is_nested: if audio_latent.is_nested:
audio_latent = audio_latent.unbind()[-1] audio_latent = audio_latent.unbind()[-1]
audio = audio_vae.decode(audio_latent).to(audio_latent.device) audio = audio_vae.decode(audio_latent).movedim(-1, 1).to(audio_latent.device)
output_audio_sample_rate = audio_vae.output_sample_rate output_audio_sample_rate = audio_vae.first_stage_model.output_sample_rate
return io.NodeOutput( return io.NodeOutput(
{ {
"waveform": audio, "waveform": audio,
@ -143,17 +139,17 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
frames_number: int, frames_number: int,
frame_rate: int, frame_rate: int,
batch_size: int, batch_size: int,
audio_vae: AudioVAE, audio_vae,
) -> io.NodeOutput: ) -> io.NodeOutput:
"""Generate empty audio latents matching the reference pipeline structure.""" """Generate empty audio latents matching the reference pipeline structure."""
assert audio_vae is not None, "Audio VAE model is required" assert audio_vae is not None, "Audio VAE model is required"
z_channels = audio_vae.latent_channels z_channels = audio_vae.latent_channels
audio_freq = audio_vae.latent_frequency_bins audio_freq = audio_vae.first_stage_model.latent_frequency_bins
sampling_rate = int(audio_vae.sample_rate) sampling_rate = int(audio_vae.first_stage_model.sample_rate)
num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate) num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate)
audio_latents = torch.zeros( audio_latents = torch.zeros(
(batch_size, z_channels, num_audio_latents, audio_freq), (batch_size, z_channels, num_audio_latents, audio_freq),