mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-23 08:52:32 +08:00
Make the ltx audio vae more native.
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
This commit is contained in:
parent
c514890325
commit
bfc70a7735
@ -4,9 +4,6 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
import comfy.model_management
|
|
||||||
import comfy.model_patcher
|
|
||||||
import comfy.utils as utils
|
|
||||||
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
|
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
|
||||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||||
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
||||||
@ -43,30 +40,6 @@ class AudioVAEComponentConfig:
|
|||||||
|
|
||||||
return cls(autoencoder=audio_config, vocoder=vocoder_config)
|
return cls(autoencoder=audio_config, vocoder=vocoder_config)
|
||||||
|
|
||||||
|
|
||||||
class ModelDeviceManager:
|
|
||||||
"""Manages device placement and GPU residency for the composed model."""
|
|
||||||
|
|
||||||
def __init__(self, module: torch.nn.Module):
|
|
||||||
load_device = comfy.model_management.get_torch_device()
|
|
||||||
offload_device = comfy.model_management.vae_offload_device()
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
|
|
||||||
|
|
||||||
def ensure_model_loaded(self) -> None:
|
|
||||||
comfy.model_management.free_memory(
|
|
||||||
self.patcher.model_size(),
|
|
||||||
self.patcher.load_device,
|
|
||||||
)
|
|
||||||
comfy.model_management.load_model_gpu(self.patcher)
|
|
||||||
|
|
||||||
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
|
|
||||||
return tensor.to(self.patcher.load_device)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def load_device(self):
|
|
||||||
return self.patcher.load_device
|
|
||||||
|
|
||||||
|
|
||||||
class AudioLatentNormalizer:
|
class AudioLatentNormalizer:
|
||||||
"""Applies per-channel statistics in patch space and restores original layout."""
|
"""Applies per-channel statistics in patch space and restores original layout."""
|
||||||
|
|
||||||
@ -132,23 +105,17 @@ class AudioPreprocessor:
|
|||||||
class AudioVAE(torch.nn.Module):
|
class AudioVAE(torch.nn.Module):
|
||||||
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
|
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
|
||||||
|
|
||||||
def __init__(self, state_dict: dict, metadata: dict):
|
def __init__(self, metadata: dict):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
component_config = AudioVAEComponentConfig.from_metadata(metadata)
|
component_config = AudioVAEComponentConfig.from_metadata(metadata)
|
||||||
|
|
||||||
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
|
|
||||||
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
|
||||||
|
|
||||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||||
if "bwe" in component_config.vocoder:
|
if "bwe" in component_config.vocoder:
|
||||||
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
|
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
|
||||||
else:
|
else:
|
||||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||||
|
|
||||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
|
||||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
|
||||||
|
|
||||||
autoencoder_config = self.autoencoder.get_config()
|
autoencoder_config = self.autoencoder.get_config()
|
||||||
self.normalizer = AudioLatentNormalizer(
|
self.normalizer = AudioLatentNormalizer(
|
||||||
AudioPatchifier(
|
AudioPatchifier(
|
||||||
@ -168,18 +135,12 @@ class AudioVAE(torch.nn.Module):
|
|||||||
n_fft=autoencoder_config["n_fft"],
|
n_fft=autoencoder_config["n_fft"],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.device_manager = ModelDeviceManager(self)
|
def encode(self, audio, sample_rate=44100) -> torch.Tensor:
|
||||||
|
|
||||||
def encode(self, audio: dict) -> torch.Tensor:
|
|
||||||
"""Encode a waveform dictionary into normalized latent tensors."""
|
"""Encode a waveform dictionary into normalized latent tensors."""
|
||||||
|
|
||||||
waveform = audio["waveform"]
|
waveform = audio
|
||||||
waveform_sample_rate = audio["sample_rate"]
|
waveform_sample_rate = sample_rate
|
||||||
input_device = waveform.device
|
input_device = waveform.device
|
||||||
# Ensure that Audio VAE is loaded on the correct device.
|
|
||||||
self.device_manager.ensure_model_loaded()
|
|
||||||
|
|
||||||
waveform = self.device_manager.move_to_load_device(waveform)
|
|
||||||
expected_channels = self.autoencoder.encoder.in_channels
|
expected_channels = self.autoencoder.encoder.in_channels
|
||||||
if waveform.shape[1] != expected_channels:
|
if waveform.shape[1] != expected_channels:
|
||||||
if waveform.shape[1] == 1:
|
if waveform.shape[1] == 1:
|
||||||
@ -190,7 +151,7 @@ class AudioVAE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
mel_spec = self.preprocessor.waveform_to_mel(
|
mel_spec = self.preprocessor.waveform_to_mel(
|
||||||
waveform, waveform_sample_rate, device=self.device_manager.load_device
|
waveform, waveform_sample_rate, device=waveform.device
|
||||||
)
|
)
|
||||||
|
|
||||||
latents = self.autoencoder.encode(mel_spec)
|
latents = self.autoencoder.encode(mel_spec)
|
||||||
@ -204,17 +165,13 @@ class AudioVAE(torch.nn.Module):
|
|||||||
"""Decode normalized latent tensors into an audio waveform."""
|
"""Decode normalized latent tensors into an audio waveform."""
|
||||||
original_shape = latents.shape
|
original_shape = latents.shape
|
||||||
|
|
||||||
# Ensure that Audio VAE is loaded on the correct device.
|
|
||||||
self.device_manager.ensure_model_loaded()
|
|
||||||
|
|
||||||
latents = self.device_manager.move_to_load_device(latents)
|
|
||||||
latents = self.normalizer.denormalize(latents)
|
latents = self.normalizer.denormalize(latents)
|
||||||
|
|
||||||
target_shape = self.target_shape_from_latents(original_shape)
|
target_shape = self.target_shape_from_latents(original_shape)
|
||||||
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
|
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
|
||||||
|
|
||||||
waveform = self.run_vocoder(mel_spec)
|
waveform = self.run_vocoder(mel_spec)
|
||||||
return self.device_manager.move_to_load_device(waveform)
|
return waveform
|
||||||
|
|
||||||
def target_shape_from_latents(self, latents_shape):
|
def target_shape_from_latents(self, latents_shape):
|
||||||
batch, _, time, _ = latents_shape
|
batch, _, time, _ = latents_shape
|
||||||
|
|||||||
18
comfy/sd.py
18
comfy/sd.py
@ -12,6 +12,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder
|
|||||||
from .ldm.audio.autoencoder import AudioOobleckVAE
|
from .ldm.audio.autoencoder import AudioOobleckVAE
|
||||||
import comfy.ldm.genmo.vae.model
|
import comfy.ldm.genmo.vae.model
|
||||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||||
|
import comfy.ldm.lightricks.vae.audio_vae
|
||||||
import comfy.ldm.cosmos.vae
|
import comfy.ldm.cosmos.vae
|
||||||
import comfy.ldm.wan.vae
|
import comfy.ldm.wan.vae
|
||||||
import comfy.ldm.wan.vae2_2
|
import comfy.ldm.wan.vae2_2
|
||||||
@ -805,6 +806,23 @@ class VAE:
|
|||||||
self.downscale_index_formula = (4, 8, 8)
|
self.downscale_index_formula = (4, 8, 8)
|
||||||
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
|
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
|
||||||
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
||||||
|
elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio
|
||||||
|
self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata)
|
||||||
|
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
|
||||||
|
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
|
||||||
|
self.latent_channels = self.first_stage_model.latent_channels
|
||||||
|
self.audio_sample_rate_output = self.first_stage_model.output_sample_rate
|
||||||
|
self.autoencoder = self.first_stage_model.autoencoder # TODO: remove hack for ltxv custom nodes
|
||||||
|
self.output_channels = 2
|
||||||
|
self.pad_channel_value = "replicate"
|
||||||
|
self.upscale_ratio = 4096
|
||||||
|
self.downscale_ratio = 4096
|
||||||
|
self.latent_dim = 2
|
||||||
|
self.process_output = lambda audio: audio
|
||||||
|
self.process_input = lambda audio: audio
|
||||||
|
self.working_dtypes = [torch.float32]
|
||||||
|
self.disable_offload = True
|
||||||
|
self.extra_1d_channel = 16
|
||||||
else:
|
else:
|
||||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||||
self.first_stage_model = None
|
self.first_stage_model = None
|
||||||
|
|||||||
@ -104,7 +104,7 @@ def vae_decode_audio(vae, samples, tile=None, overlap=None):
|
|||||||
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
|
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
|
||||||
std[std < 1.0] = 1.0
|
std[std < 1.0] = 1.0
|
||||||
audio /= std
|
audio /= std
|
||||||
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
|
vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100))
|
||||||
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
|
return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,9 +3,8 @@ import comfy.utils
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
|
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy_extras.nodes_audio import VAEEncodeAudio
|
||||||
|
|
||||||
class LTXVAudioVAELoader(io.ComfyNode):
|
class LTXVAudioVAELoader(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -28,10 +27,14 @@ class LTXVAudioVAELoader(io.ComfyNode):
|
|||||||
def execute(cls, ckpt_name: str) -> io.NodeOutput:
|
def execute(cls, ckpt_name: str) -> io.NodeOutput:
|
||||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||||
return io.NodeOutput(AudioVAE(sd, metadata))
|
sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder.", "vocoder.": "vocoder."}, filter_keys=True)
|
||||||
|
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
||||||
|
vae.throw_exception_if_invalid()
|
||||||
|
|
||||||
|
return io.NodeOutput(vae)
|
||||||
|
|
||||||
|
|
||||||
class LTXVAudioVAEEncode(io.ComfyNode):
|
class LTXVAudioVAEEncode(VAEEncodeAudio):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
@ -50,15 +53,8 @@ class LTXVAudioVAEEncode(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
|
def execute(cls, audio, audio_vae) -> io.NodeOutput:
|
||||||
audio_latents = audio_vae.encode(audio)
|
return super().execute(audio_vae, audio)
|
||||||
return io.NodeOutput(
|
|
||||||
{
|
|
||||||
"samples": audio_latents,
|
|
||||||
"sample_rate": int(audio_vae.sample_rate),
|
|
||||||
"type": "audio",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LTXVAudioVAEDecode(io.ComfyNode):
|
class LTXVAudioVAEDecode(io.ComfyNode):
|
||||||
@ -80,12 +76,12 @@ class LTXVAudioVAEDecode(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
|
def execute(cls, samples, audio_vae) -> io.NodeOutput:
|
||||||
audio_latent = samples["samples"]
|
audio_latent = samples["samples"]
|
||||||
if audio_latent.is_nested:
|
if audio_latent.is_nested:
|
||||||
audio_latent = audio_latent.unbind()[-1]
|
audio_latent = audio_latent.unbind()[-1]
|
||||||
audio = audio_vae.decode(audio_latent).to(audio_latent.device)
|
audio = audio_vae.decode(audio_latent).movedim(-1, 1).to(audio_latent.device)
|
||||||
output_audio_sample_rate = audio_vae.output_sample_rate
|
output_audio_sample_rate = audio_vae.first_stage_model.output_sample_rate
|
||||||
return io.NodeOutput(
|
return io.NodeOutput(
|
||||||
{
|
{
|
||||||
"waveform": audio,
|
"waveform": audio,
|
||||||
@ -143,17 +139,17 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
|
|||||||
frames_number: int,
|
frames_number: int,
|
||||||
frame_rate: int,
|
frame_rate: int,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
audio_vae: AudioVAE,
|
audio_vae,
|
||||||
) -> io.NodeOutput:
|
) -> io.NodeOutput:
|
||||||
"""Generate empty audio latents matching the reference pipeline structure."""
|
"""Generate empty audio latents matching the reference pipeline structure."""
|
||||||
|
|
||||||
assert audio_vae is not None, "Audio VAE model is required"
|
assert audio_vae is not None, "Audio VAE model is required"
|
||||||
|
|
||||||
z_channels = audio_vae.latent_channels
|
z_channels = audio_vae.latent_channels
|
||||||
audio_freq = audio_vae.latent_frequency_bins
|
audio_freq = audio_vae.first_stage_model.latent_frequency_bins
|
||||||
sampling_rate = int(audio_vae.sample_rate)
|
sampling_rate = int(audio_vae.first_stage_model.sample_rate)
|
||||||
|
|
||||||
num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)
|
num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate)
|
||||||
|
|
||||||
audio_latents = torch.zeros(
|
audio_latents = torch.zeros(
|
||||||
(batch_size, z_channels, num_audio_latents, audio_freq),
|
(batch_size, z_channels, num_audio_latents, audio_freq),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user