mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-08 21:30:50 +08:00
184 lines
5.9 KiB
Python
184 lines
5.9 KiB
Python
import folder_paths
|
|
import comfy.utils
|
|
import comfy.model_management
|
|
import torch
|
|
|
|
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
|
|
from comfy_api.latest import ComfyExtension, io
|
|
|
|
|
|
class LTXVAudioVAELoader(io.ComfyNode):
    """Node that builds an ``AudioVAE`` from a file in the "checkpoints" folder."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: one checkpoint combo input and one VAE output."""
        # The selectable options are whatever folder_paths lists under "checkpoints".
        checkpoint_choice = io.Combo.Input(
            "ckpt_name",
            options=folder_paths.get_filename_list("checkpoints"),
            tooltip="Audio VAE checkpoint to load.",
        )
        vae_output = io.Vae.Output(display_name="Audio VAE")

        return io.Schema(
            node_id="LTXVAudioVAELoader",
            display_name="LTXV Audio VAE Loader",
            category="audio",
            inputs=[checkpoint_choice],
            outputs=[vae_output],
        )

    @classmethod
    def execute(cls, ckpt_name: str) -> io.NodeOutput:
        """Load the named checkpoint and wrap it in an ``AudioVAE``.

        Args:
            ckpt_name: File name to resolve inside the "checkpoints" folder.

        Returns:
            A node output holding the constructed ``AudioVAE``.

        Raises:
            Whatever ``folder_paths.get_full_path_or_raise`` raises when the
            name cannot be resolved.
        """
        full_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
        state_dict, metadata = comfy.utils.load_torch_file(
            full_path, return_metadata=True
        )
        vae = AudioVAE(state_dict, metadata)
        return io.NodeOutput(vae)
|
|
|
|
|
|
class LTXVAudioVAEEncode(io.ComfyNode):
    """Node that encodes an audio input into a latent using an audio VAE."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: audio + audio VAE inputs, one latent output."""
        audio_input = io.Audio.Input("audio", tooltip="The audio to be encoded.")
        vae_input = io.Vae.Input(
            id="audio_vae",
            display_name="Audio VAE",
            tooltip="The Audio VAE model to use for encoding.",
        )
        latent_output = io.Latent.Output(display_name="Audio Latent")

        return io.Schema(
            node_id="LTXVAudioVAEEncode",
            display_name="LTXV Audio VAE Encode",
            category="audio",
            inputs=[audio_input, vae_input],
            outputs=[latent_output],
        )

    @classmethod
    def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
        """Encode ``audio`` with ``audio_vae`` and package the result as a latent.

        Args:
            audio: The audio value, passed unchanged to ``audio_vae.encode``.
            audio_vae: The VAE providing ``encode`` and ``sample_rate``.

        Returns:
            A node output wrapping a dict with the encoded ``"samples"``, the
            VAE's ``"sample_rate"`` as an int, and ``"type"`` set to "audio".
        """
        latent = {
            "samples": audio_vae.encode(audio),
            "sample_rate": int(audio_vae.sample_rate),
            "type": "audio",
        }
        return io.NodeOutput(latent)
|
|
|
|
|
|
class LTXVAudioVAEDecode(io.ComfyNode):
    """Node that decodes an audio latent back into a waveform."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: latent + audio VAE inputs, one audio output."""
        latent_input = io.Latent.Input("samples", tooltip="The latent to be decoded.")
        vae_input = io.Vae.Input(
            id="audio_vae",
            display_name="Audio VAE",
            tooltip="The Audio VAE model used for decoding the latent.",
        )
        audio_output = io.Audio.Output(display_name="Audio")

        return io.Schema(
            node_id="LTXVAudioVAEDecode",
            display_name="LTXV Audio VAE Decode",
            category="audio",
            inputs=[latent_input, vae_input],
            outputs=[audio_output],
        )

    @classmethod
    def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
        """Decode ``samples["samples"]`` with ``audio_vae``.

        Args:
            samples: Latent dict; only its ``"samples"`` entry is read.
            audio_vae: The VAE providing ``decode`` and ``output_sample_rate``.

        Returns:
            A node output wrapping a dict with the decoded ``"waveform"``
            (moved to the latent's device) and the VAE's output
            ``"sample_rate"`` as an int.
        """
        latent = samples["samples"]
        # A nested tensor is split apart and only its last component is decoded.
        # NOTE(review): presumably the last component is the audio part of a
        # combined latent - confirm against the producer of nested latents.
        latent = latent.unbind()[-1] if latent.is_nested else latent

        waveform = audio_vae.decode(latent).to(latent.device)
        decoded = {
            "waveform": waveform,
            "sample_rate": int(audio_vae.output_sample_rate),
        }
        return io.NodeOutput(decoded)
|
|
|
|
|
|
class LTXVEmptyLatentAudio(io.ComfyNode):
    """Node that creates an all-zero audio latent sized from an audio VAE."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: frame count, frame rate, batch size and VAE inputs."""
        return io.Schema(
            node_id="LTXVEmptyLatentAudio",
            display_name="LTXV Empty Latent Audio",
            category="latent/audio",
            inputs=[
                io.Int.Input(
                    "frames_number",
                    default=97,
                    min=1,
                    max=1000,
                    step=1,
                    display_mode=io.NumberDisplay.number,
                    tooltip="Number of frames.",
                ),
                io.Int.Input(
                    "frame_rate",
                    default=25,
                    min=1,
                    max=1000,
                    step=1,
                    display_mode=io.NumberDisplay.number,
                    tooltip="Number of frames per second.",
                ),
                io.Int.Input(
                    "batch_size",
                    default=1,
                    min=1,
                    max=4096,
                    display_mode=io.NumberDisplay.number,
                    tooltip="The number of latent audio samples in the batch.",
                ),
                io.Vae.Input(
                    id="audio_vae",
                    display_name="Audio VAE",
                    tooltip="The Audio VAE model to get configuration from.",
                ),
            ],
            outputs=[io.Latent.Output(display_name="Latent")],
        )

    @classmethod
    def execute(
        cls,
        frames_number: int,
        frame_rate: int,
        batch_size: int,
        audio_vae: AudioVAE,
    ) -> io.NodeOutput:
        """Generate empty audio latents matching the reference pipeline structure.

        Args:
            frames_number: Frame count forwarded to
                ``audio_vae.num_of_latents_from_frames``.
            frame_rate: Frames per second, forwarded to the same call.
            batch_size: Size of the first dimension of the latent tensor.
            audio_vae: VAE supplying ``latent_channels``,
                ``latent_frequency_bins``, ``sample_rate`` and
                ``num_of_latents_from_frames``.

        Returns:
            A node output wrapping a dict with zero-filled ``"samples"`` of
            shape ``(batch_size, latent_channels, num_latents, frequency_bins)``,
            the VAE's ``"sample_rate"`` as an int, and ``"type"`` set to "audio".

        Raises:
            ValueError: If ``audio_vae`` is ``None``.
        """
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if audio_vae is None:
            raise ValueError("Audio VAE model is required")

        z_channels = audio_vae.latent_channels
        audio_freq = audio_vae.latent_frequency_bins
        sampling_rate = int(audio_vae.sample_rate)

        num_audio_latents = audio_vae.num_of_latents_from_frames(
            frames_number, frame_rate
        )

        audio_latents = torch.zeros(
            (batch_size, z_channels, num_audio_latents, audio_freq),
            device=comfy.model_management.intermediate_device(),
        )

        return io.NodeOutput(
            {
                "samples": audio_latents,
                "sample_rate": sampling_rate,
                "type": "audio",
            }
        )
|
|
|
|
|
|
class LTXVAudioExtension(ComfyExtension):
    """Extension that exposes the LTXV audio nodes defined in this module."""

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes = [
            LTXVAudioVAELoader,
            LTXVAudioVAEEncode,
            LTXVAudioVAEDecode,
            LTXVEmptyLatentAudio,
        ]
        return node_classes
|
|
|
|
|
|
async def comfy_entrypoint() -> ComfyExtension:
    """Create and return the ``LTXVAudioExtension`` instance for this module."""
    extension = LTXVAudioExtension()
    return extension
|