ComfyUI/comfy_extras/nodes_lt_audio.py
Jedrzej Kosinski aa464b36b3
Multi-GPU device selection for loader nodes + CUDA context fixes (#13483)
* Fix Hunyuan 3D 2.1 multi-GPU worksplit: use cond_or_uncond instead of hardcoded chunk(2)

Amp-Thread-ID: https://ampcode.com/threads/T-019da964-2cc8-77f9-9aae-23f65da233db
Co-authored-by: Amp <amp@ampcode.com>

* Add GPU device selection to all loader nodes

- Add get_gpu_device_options() and resolve_gpu_device_option() helpers
  in model_management.py for vendor-agnostic GPU device selection
- Add device widget to CheckpointLoaderSimple, UNETLoader, VAELoader
- Expand device options in CLIPLoader, DualCLIPLoader, LTXAVTextEncoderLoader
  from [default, cpu] to include gpu:0, gpu:1, etc. on multi-GPU systems
- Wire load_diffusion_model_state_dict and load_state_dict_guess_config
  to respect model_options['load_device']
- Graceful fallback: unrecognized devices (e.g. gpu:1 on single-GPU)
  silently fall back to default
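
The helper pair might look roughly like this — a minimal sketch only: the real helpers in model_management.py query torch for the visible device count and return torch.device objects, while here the GPU count is passed in and plain strings are returned for illustration:

```python
def get_gpu_device_options(num_gpus):
    # "default" lets model management pick; "cpu" forces CPU.
    # gpu:N entries are offered only when multiple GPUs are visible.
    options = ["default", "cpu"]
    if num_gpus > 1:
        options += [f"gpu:{i}" for i in range(num_gpus)]
    return options

def resolve_gpu_device_option(device, num_gpus):
    # Returns a concrete device string, or None meaning "use the default".
    if device == "cpu":
        return "cpu"
    if device.startswith("gpu:"):
        try:
            idx = int(device.split(":", 1)[1])
        except ValueError:
            return None
        # Negative and out-of-range indices silently fall back to default,
        # so a workflow saved with gpu:1 still loads on a single-GPU box.
        if 0 <= idx < num_gpus:
            return f"cuda:{idx}"
    return None
```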

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>

* Add VALIDATE_INPUTS to skip device combo validation for workflow portability

When a workflow saved on a 2-GPU machine (with device=gpu:1) is loaded
on a 1-GPU machine, the combo validation would reject the unknown value.
VALIDATE_INPUTS with the device parameter bypasses combo validation for
that input only, allowing resolve_gpu_device_option to handle the
graceful fallback at runtime.
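
The bypass can be sketched as follows, assuming ComfyUI's convention that a VALIDATE_INPUTS classmethod which names an input parameter takes over validation for that input (the class here is hypothetical, for illustration only):

```python
class DeviceAwareLoaderSketch:
    # Hypothetical node class sketching the VALIDATE_INPUTS bypass.
    @classmethod
    def VALIDATE_INPUTS(cls, device):
        # By naming "device", this method claims validation for that input,
        # so the built-in combo-membership check is skipped for it (and only
        # it). Returning True accepts any saved value, e.g. gpu:1 from a
        # 2-GPU machine; resolve_gpu_device_option handles the fallback
        # at runtime.
        return True
```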

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>

* Set CUDA device context in outer_sample to match model load_device

Custom CUDA kernels (comfy_kitchen fp8 quantization) use
torch.cuda.current_device() for DLPack tensor export. When a model is
loaded on a non-default GPU (e.g. cuda:1), the CUDA context must match
or the kernel fails with 'Can't export tensors on a different CUDA
device index'. Save and restore the previous device around sampling.

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>

* Fix code review bugs: negative index guard, CPU offload_device, checkpoint te_model_options

- resolve_gpu_device_option: reject negative indices (gpu:-1)
- UNETLoader: set offload_device when cpu is selected
- CheckpointLoaderSimple: pass te_model_options for CLIP device,
  set offload_device for cpu, pass load_device to VAE
- load_diffusion_model_state_dict: respect offload_device from model_options
- load_state_dict_guess_config: respect offload_device, pass load_device to VAE

Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a
Co-authored-by: Amp <amp@ampcode.com>

* Fix CUDA device context for CLIP encoding and VAE encode/decode

Add torch.cuda.set_device() calls to match model's load device in:
- CLIP.encode_from_tokens: fixes 'Can't export tensors on a different
  CUDA device index' when CLIP is loaded on a non-default GPU
- CLIP.encode_from_tokens_scheduled: same fix for the hooks code path
- CLIP.generate: same fix for text generation
- VAE.decode: fixes VAE decoding on non-default GPU
- VAE.encode: fixes VAE encoding on non-default GPU

Same pattern as the existing outer_sample fix in samplers.py: save and
restore the previous CUDA device in a try/finally block.

Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>

* Extract cuda_device_context manager, fix tiled VAE methods

Add model_management.cuda_device_context() — a context manager that
saves/restores torch.cuda.current_device when operating on a non-default
GPU. Replaces 6 copies of the manual save/set/restore boilerplate.

Refactored call sites:
- CLIP.encode_from_tokens
- CLIP.encode_from_tokens_scheduled (hooks path)
- CLIP.generate
- VAE.decode
- VAE.encode
- samplers.outer_sample

Bug fixes (newly wrapped):
- VAE.decode_tiled: was missing device context entirely, would fail
  on non-default GPU when called from 'VAE Decode (Tiled)' node
- VAE.encode_tiled: same issue for 'VAE Encode (Tiled)' node
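
The extracted manager might look roughly like this — a dependency-injected sketch so it runs without a GPU; the real cuda_device_context presumably calls torch.cuda.current_device() and torch.cuda.set_device() directly:

```python
import contextlib

@contextlib.contextmanager
def cuda_device_context(device, get_current, set_current):
    # Sketch: get_current/set_current stand in for
    # torch.cuda.current_device / torch.cuda.set_device.
    index = getattr(device, "index", None)
    if getattr(device, "type", None) != "cuda" or index is None:
        # Not a CUDA device with an explicit index: nothing to switch.
        yield
        return
    prev = get_current()
    if prev == index:
        yield
        return
    set_current(index)
    try:
        yield
    finally:
        # Restore the previous device even if the wrapped work raises.
        set_current(prev)
```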

Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>

* Restore CheckpointLoaderSimple, add CheckpointLoaderDevice

Revert CheckpointLoaderSimple to its original form (no device input)
so it remains the simple default loader.

Add new CheckpointLoaderDevice node (advanced/loaders) with separate
model_device, clip_device, and vae_device inputs for per-component
GPU placement in multi-GPU setups.

Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57
Co-authored-by: Amp <amp@ampcode.com>

---------

Co-authored-by: Amp <amp@ampcode.com>
2026-04-23 19:10:33 -07:00


import folder_paths
import comfy.sd
import comfy.utils
import comfy.model_management
import torch

from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
from comfy_api.latest import ComfyExtension, io

class LTXVAudioVAELoader(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="LTXVAudioVAELoader",
            display_name="LTXV Audio VAE Loader",
            category="audio",
            inputs=[
                io.Combo.Input(
                    "ckpt_name",
                    options=folder_paths.get_filename_list("checkpoints"),
                    tooltip="Audio VAE checkpoint to load.",
                )
            ],
            outputs=[io.Vae.Output(display_name="Audio VAE")],
        )

    @classmethod
    def execute(cls, ckpt_name: str) -> io.NodeOutput:
        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
        sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
        return io.NodeOutput(AudioVAE(sd, metadata))

class LTXVAudioVAEEncode(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="LTXVAudioVAEEncode",
            display_name="LTXV Audio VAE Encode",
            category="audio",
            inputs=[
                io.Audio.Input("audio", tooltip="The audio to be encoded."),
                io.Vae.Input(
                    id="audio_vae",
                    display_name="Audio VAE",
                    tooltip="The Audio VAE model to use for encoding.",
                ),
            ],
            outputs=[io.Latent.Output(display_name="Audio Latent")],
        )

    @classmethod
    def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
        audio_latents = audio_vae.encode(audio)
        return io.NodeOutput(
            {
                "samples": audio_latents,
                "sample_rate": int(audio_vae.sample_rate),
                "type": "audio",
            }
        )

class LTXVAudioVAEDecode(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="LTXVAudioVAEDecode",
            display_name="LTXV Audio VAE Decode",
            category="audio",
            inputs=[
                io.Latent.Input("samples", tooltip="The latent to be decoded."),
                io.Vae.Input(
                    id="audio_vae",
                    display_name="Audio VAE",
                    tooltip="The Audio VAE model used for decoding the latent.",
                ),
            ],
            outputs=[io.Audio.Output(display_name="Audio")],
        )

    @classmethod
    def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
        audio_latent = samples["samples"]
        if audio_latent.is_nested:
            audio_latent = audio_latent.unbind()[-1]
        audio = audio_vae.decode(audio_latent).to(audio_latent.device)
        output_audio_sample_rate = audio_vae.output_sample_rate
        return io.NodeOutput(
            {
                "waveform": audio,
                "sample_rate": int(output_audio_sample_rate),
            }
        )

class LTXVEmptyLatentAudio(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="LTXVEmptyLatentAudio",
            display_name="LTXV Empty Latent Audio",
            category="latent/audio",
            inputs=[
                io.Int.Input(
                    "frames_number",
                    default=97,
                    min=1,
                    max=1000,
                    step=1,
                    display_mode=io.NumberDisplay.number,
                    tooltip="Number of frames.",
                ),
                io.Int.Input(
                    "frame_rate",
                    default=25,
                    min=1,
                    max=1000,
                    step=1,
                    display_mode=io.NumberDisplay.number,
                    tooltip="Number of frames per second.",
                ),
                io.Int.Input(
                    "batch_size",
                    default=1,
                    min=1,
                    max=4096,
                    display_mode=io.NumberDisplay.number,
                    tooltip="The number of latent audio samples in the batch.",
                ),
                io.Vae.Input(
                    id="audio_vae",
                    display_name="Audio VAE",
                    tooltip="The Audio VAE model to get configuration from.",
                ),
            ],
            outputs=[io.Latent.Output(display_name="Latent")],
        )

    @classmethod
    def execute(
        cls,
        frames_number: int,
        frame_rate: int,
        batch_size: int,
        audio_vae: AudioVAE,
    ) -> io.NodeOutput:
        """Generate empty audio latents matching the reference pipeline structure."""
        assert audio_vae is not None, "Audio VAE model is required"
        z_channels = audio_vae.latent_channels
        audio_freq = audio_vae.latent_frequency_bins
        sampling_rate = int(audio_vae.sample_rate)
        num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)
        audio_latents = torch.zeros(
            (batch_size, z_channels, num_audio_latents, audio_freq),
            device=comfy.model_management.intermediate_device(),
        )
        return io.NodeOutput(
            {
                "samples": audio_latents,
                "sample_rate": sampling_rate,
                "type": "audio",
            }
        )

class LTXAVTextEncoderLoader(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="LTXAVTextEncoderLoader",
            display_name="LTXV Audio Text Encoder Loader",
            category="advanced/loaders",
            description="[Recipes]\n\nltxav: gemma 3 12B",
            inputs=[
                io.Combo.Input(
                    "text_encoder",
                    options=folder_paths.get_filename_list("text_encoders"),
                ),
                io.Combo.Input(
                    "ckpt_name",
                    options=folder_paths.get_filename_list("checkpoints"),
                ),
                io.Combo.Input(
                    "device",
                    options=comfy.model_management.get_gpu_device_options(),
                    advanced=True,
                ),
            ],
            outputs=[io.Clip.Output()],
        )

    @classmethod
    def execute(cls, text_encoder, ckpt_name, device="default"):
        clip_type = comfy.sd.CLIPType.LTXV
        clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder)
        clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
        model_options = {}
        resolved = comfy.model_management.resolve_gpu_device_option(device)
        if resolved is not None:
            if resolved.type == "cpu":
                model_options["load_device"] = model_options["offload_device"] = resolved
            else:
                model_options["load_device"] = resolved
        clip = comfy.sd.load_clip(
            ckpt_paths=[clip_path1, clip_path2],
            embedding_directory=folder_paths.get_folder_paths("embeddings"),
            clip_type=clip_type,
            model_options=model_options,
        )
        return io.NodeOutput(clip)

class LTXVAudioExtension(ComfyExtension):
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            LTXVAudioVAELoader,
            LTXVAudioVAEEncode,
            LTXVAudioVAEDecode,
            LTXVEmptyLatentAudio,
            LTXAVTextEncoderLoader,
        ]


async def comfy_entrypoint() -> ComfyExtension:
    return LTXVAudioExtension()