"""ComfyUI nodes for the LTXV audio VAE.

Provides nodes to load an audio VAE checkpoint, encode AUDIO into latents,
decode latents back into AUDIO, and create empty audio latents whose length
is derived from a video frame count and frame rate.
"""

import folder_paths
import comfy.utils
import comfy.model_management
import torch
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
from comfy_api.latest import ComfyExtension, io


class LTXVAudioVAELoader(io.ComfyNode):
    """Load an LTXV ``AudioVAE`` from a file in the ``checkpoints`` model folder."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="LTXVAudioVAELoader",
            display_name="LTXV Audio VAE Loader",
            category="audio",
            inputs=[
                io.Combo.Input(
                    "ckpt_name",
                    options=folder_paths.get_filename_list("checkpoints"),
                    tooltip="Audio VAE checkpoint to load.",
                )
            ],
            outputs=[io.Vae.Output(display_name="Audio VAE")],
        )

    @classmethod
    def execute(cls, ckpt_name: str) -> io.NodeOutput:
        """Build an ``AudioVAE`` from the selected checkpoint.

        Args:
            ckpt_name: File name inside the ``checkpoints`` folder.

        Returns:
            A node output holding the constructed ``AudioVAE``.
        """
        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
        # Both the state dict and the file metadata are handed to AudioVAE.
        sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
        return io.NodeOutput(AudioVAE(sd, metadata))


class LTXVAudioVAEEncode(io.ComfyNode):
    """Encode an AUDIO input into an audio LATENT with an ``AudioVAE``."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="LTXVAudioVAEEncode",
            display_name="LTXV Audio VAE Encode",
            category="audio",
            inputs=[
                io.Audio.Input("audio", tooltip="The audio to be encoded."),
                io.Vae.Input(
                    id="audio_vae",
                    display_name="Audio VAE",
                    tooltip="The Audio VAE model to use for encoding.",
                ),
            ],
            outputs=[io.Latent.Output(display_name="Audio Latent")],
        )

    @classmethod
    def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
        """Encode ``audio`` and wrap the result as a LATENT dict.

        Returns:
            A node output holding ``{"samples", "sample_rate", "type"}``, where
            ``sample_rate`` is the VAE's ``sample_rate`` and ``type`` is
            ``"audio"``.
        """
        audio_latents = audio_vae.encode(audio)
        return io.NodeOutput(
            {
                "samples": audio_latents,
                "sample_rate": int(audio_vae.sample_rate),
                "type": "audio",
            }
        )


class LTXVAudioVAEDecode(io.ComfyNode):
    """Decode an audio LATENT back into an AUDIO waveform with an ``AudioVAE``."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="LTXVAudioVAEDecode",
            display_name="LTXV Audio VAE Decode",
            category="audio",
            inputs=[
                io.Latent.Input("samples", tooltip="The latent to be decoded."),
                io.Vae.Input(
                    id="audio_vae",
                    display_name="Audio VAE",
                    tooltip="The Audio VAE model used for decoding the latent.",
                ),
            ],
            outputs=[io.Audio.Output(display_name="Audio")],
        )

    @classmethod
    def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
        """Decode ``samples["samples"]`` into an AUDIO dict.

        Returns:
            A node output holding ``{"waveform", "sample_rate"}``, where
            ``sample_rate`` is the VAE's ``output_sample_rate``.
        """
        audio_latent = samples["samples"]
        if audio_latent.is_nested:
            # Only the last component of a nested latent is decoded.
            # NOTE(review): presumably the audio part of a combined latent;
            # confirm against whichever node produces nested latents.
            audio_latent = audio_latent.unbind()[-1]
        # Return the waveform on the same device as the input latent.
        audio = audio_vae.decode(audio_latent).to(audio_latent.device)
        output_audio_sample_rate = audio_vae.output_sample_rate
        return io.NodeOutput(
            {
                "waveform": audio,
                "sample_rate": int(output_audio_sample_rate),
            }
        )


class LTXVEmptyLatentAudio(io.ComfyNode):
    """Create an all-zero audio LATENT sized from a frame count and frame rate."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="LTXVEmptyLatentAudio",
            display_name="LTXV Empty Latent Audio",
            category="latent/audio",
            inputs=[
                io.Int.Input(
                    "frames_number",
                    default=97,
                    min=1,
                    max=1000,
                    step=1,
                    display_mode=io.NumberDisplay.number,
                    tooltip="Number of frames.",
                ),
                io.Int.Input(
                    "frame_rate",
                    default=25,
                    min=1,
                    max=1000,
                    step=1,
                    display_mode=io.NumberDisplay.number,
                    tooltip="Number of frames per second.",
                ),
                io.Int.Input(
                    "batch_size",
                    default=1,
                    min=1,
                    max=4096,
                    display_mode=io.NumberDisplay.number,
                    tooltip="The number of latent audio samples in the batch.",
                ),
                io.Vae.Input(
                    id="audio_vae",
                    display_name="Audio VAE",
                    tooltip="The Audio VAE model to get configuration from.",
                ),
            ],
            outputs=[io.Latent.Output(display_name="Latent")],
        )

    @classmethod
    def execute(
        cls,
        frames_number: int,
        frame_rate: int,
        batch_size: int,
        audio_vae: AudioVAE,
    ) -> io.NodeOutput:
        """Generate empty audio latents matching the reference pipeline structure.

        The latent tensor has shape
        ``(batch_size, audio_vae.latent_channels, num_latents,
        audio_vae.latent_frequency_bins)``. Here ``num_latents`` comes from
        ``audio_vae.num_of_latents_from_frames(frames_number, frame_rate)``.
        The tensor is allocated on ComfyUI's intermediate device.

        Raises:
            ValueError: If ``audio_vae`` is ``None``.
        """
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would turn this into an opaque AttributeError.
        if audio_vae is None:
            raise ValueError("Audio VAE model is required")

        z_channels = audio_vae.latent_channels
        audio_freq = audio_vae.latent_frequency_bins
        sampling_rate = int(audio_vae.sample_rate)

        num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)

        audio_latents = torch.zeros(
            (batch_size, z_channels, num_audio_latents, audio_freq),
            device=comfy.model_management.intermediate_device(),
        )

        return io.NodeOutput(
            {
                "samples": audio_latents,
                "sample_rate": sampling_rate,
                "type": "audio",
            }
        )


class LTXVAudioExtension(ComfyExtension):
    """Extension that registers the LTXV audio VAE nodes."""

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this module."""
        return [
            LTXVAudioVAELoader,
            LTXVAudioVAEEncode,
            LTXVAudioVAEDecode,
            LTXVEmptyLatentAudio,
        ]


async def comfy_entrypoint() -> ComfyExtension:
    """Return the extension instance that exposes this module's nodes."""
    return LTXVAudioExtension()