ComfyUI/comfy/audio_encoders/audio_encoders.py
Emiliooooo e860732dba fix(directml): correct VRAM detection and make torchaudio imports optional
## VRAM Detection (model_management.py)

The DirectML code path had two hardcoded `1024 * 1024 * 1024 #TODO` values
in `get_total_memory()` and `get_free_memory()`, so ComfyUI reported only
1 GB of VRAM on any AMD/Intel GPU using the DirectML backend, regardless of
the actual hardware. As a result, the NORMAL_VRAM / LOW_VRAM calculations
were wildly wrong.

Fix for `get_total_memory`:
- On Windows, reads `HardwareInformation.qwMemorySize` from the GPU driver
  registry key via `winreg`. This is an accurate 64-bit value (unlike
  `Win32_VideoController.AdapterRAM`, which overflows at 4 GB).
- Allows override via `COMFYUI_DIRECTML_VRAM_MB` env var.
- Falls back to 6 GB if registry query fails (safe default for modern dGPUs).
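
The lookup order above (env override, then registry, then fallback) can be sketched roughly as follows. This is a minimal illustration, not the code in `model_management.py`: the function names, the display-adapter class key path, and the `0000` adapter subkey are assumptions, since the commit message does not say which key is read.

```python
import os
import sys

FALLBACK_VRAM_BYTES = 6 * 1024 ** 3  # the 6 GB fallback described above

# Assumption: the display-adapter device class key; "0000" is the first adapter.
DISPLAY_CLASS_KEY = (
    r"SYSTEM\CurrentControlSet\Control\Class"
    r"\{4d36e968-e325-11ce-bfc1-08002be10318}"
)


def read_registry_vram_bytes(adapter_subkey="0000"):
    """Return qwMemorySize in bytes, or None if it cannot be read."""
    if sys.platform != "win32":
        return None
    try:
        import winreg  # Windows-only standard library module
        path = DISPLAY_CLASS_KEY + "\\" + adapter_subkey
        with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, path) as key:
            value, _kind = winreg.QueryValueEx(
                key, "HardwareInformation.qwMemorySize"
            )
    except OSError:
        return None
    if isinstance(value, bytes):
        # Some drivers may store the value as raw little-endian bytes.
        value = int.from_bytes(value, "little")
    return value if isinstance(value, int) and value > 0 else None


def directml_total_vram_bytes(environ=os.environ):
    """Hypothetical helper: env override, then registry, then 6 GB fallback."""
    override = environ.get("COMFYUI_DIRECTML_VRAM_MB")
    if override:
        try:
            return int(override) * 1024 * 1024
        except ValueError:
            pass  # ignore a malformed override and keep going
    detected = read_registry_vram_bytes()
    return detected if detected is not None else FALLBACK_VRAM_BYTES
```

Passing the environment as a parameter keeps the override easy to test without touching the real process environment.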

Fix for `get_free_memory`:
- Uses `torch_directml.gpu_memory(0)` to get per-tile usage fractions and
  derives free memory as `total * (1 - max_usage_fraction)`.
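
The free-memory arithmetic can be illustrated in isolation. This sketch assumes the per-tile values are fractions between 0 and 1 and takes them as a plain list instead of calling `torch_directml`; the function name and the clamping are illustrative choices, not taken from the commit.

```python
def free_memory_bytes(total_bytes, usage_fractions):
    """Estimate free VRAM as total * (1 - max_usage_fraction).

    usage_fractions is assumed to be a list of per-tile usage values in
    the range [0, 1]; an empty list is treated as "nothing in use".
    """
    if not usage_fractions:
        return total_bytes
    # Clamp so that an out-of-range reading cannot produce a negative result.
    peak = min(max(max(usage_fractions), 0.0), 1.0)
    return int(total_bytes * (1.0 - peak))
```

Using the maximum fraction rather than the mean is the conservative choice: the estimate never claims more free memory than the busiest tile allows.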

## torchaudio: optional import on AMD/DirectML

torchaudio has a DLL incompatibility with torch-directml (which ships its own
torch runtime). The following files had bare `import torchaudio` at module
level, crashing ComfyUI startup entirely when torchaudio was absent or failed
to load:

- comfy/ldm/lightricks/vae/audio_vae.py
- comfy/audio_encoders/whisper.py
- comfy/audio_encoders/audio_encoders.py
- comfy_extras/nodes_audio.py
- comfy_extras/nodes_lt.py
- comfy_extras/nodes_wandancer.py

Each import is wrapped in `try/except (ImportError, OSError): torchaudio = None`,
matching the pattern already used in comfy/ldm/mmaudio/vae/autoencoder.py and
comfy/ldm/ace/vae/music_dcae_pipeline.py. Audio nodes will degrade gracefully
rather than preventing ComfyUI from starting.
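
A minimal sketch of the optional-import pattern, together with one way a caller might guard against `torchaudio` being `None`. The helper `resample_or_fail` is hypothetical and not part of this commit; it only shows where a clear error could be raised once the module-level import no longer fails.

```python
try:
    import torchaudio
except (ImportError, OSError):
    # ImportError: package not installed.
    # OSError: package present but its native libraries failed to load.
    torchaudio = None


def resample_or_fail(audio, orig_rate, target_rate):
    """Hypothetical helper: resample only when needed, with a clear error."""
    if orig_rate == target_rate:
        return audio  # nothing to do, torchaudio is not required
    if torchaudio is None:
        raise RuntimeError(
            "torchaudio is unavailable; cannot resample from "
            "{} Hz to {} Hz".format(orig_rate, target_rate)
        )
    return torchaudio.functional.resample(audio, orig_rate, target_rate)
```

With this shape, audio that already matches the target rate passes through unchanged even on systems where torchaudio cannot be loaded.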

Tested on: AMD Radeon RX 5600 XT (6 GB VRAM, gfx1010, Windows 10)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-14 12:10:31 -04:00


from .wav2vec2 import Wav2Vec2Model
from .whisper import WhisperLargeV3
import comfy.model_management
import comfy.ops
import comfy.utils
import logging
try:
    import torchaudio
except (ImportError, OSError):
    torchaudio = None


class AudioEncoderModel():
    def __init__(self, config):
        self.load_device = comfy.model_management.text_encoder_device()
        offload_device = comfy.model_management.text_encoder_offload_device()
        self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
        model_type = config.pop("model_type")
        model_config = dict(config)
        model_config.update({
            "dtype": self.dtype,
            "device": offload_device,
            "operations": comfy.ops.manual_cast
        })

        if model_type == "wav2vec2":
            self.model = Wav2Vec2Model(**model_config)
        elif model_type == "whisper3":
            self.model = WhisperLargeV3(**model_config)
        self.model.eval()
        self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
        self.model_sample_rate = 16000
        comfy.model_management.archive_model_dtypes(self.model)

    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

    def get_sd(self):
        return self.model.state_dict()
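
    def encode_audio(self, audio, sample_rate):
        # torchaudio is an optional import; fail with a clear message rather
        # than an AttributeError on None when it could not be loaded.
        if torchaudio is None:
            raise RuntimeError("torchaudio is required to encode audio but could not be imported.")
        comfy.model_management.load_model_gpu(self.patcher)
        audio = torchaudio.functional.resample(audio, sample_rate, self.model_sample_rate)
        out, all_layers = self.model(audio.to(self.load_device))
        outputs = {}
        outputs["encoded_audio"] = out
        outputs["encoded_audio_all_layers"] = all_layers
        outputs["audio_samples"] = audio.shape[2]
        return outputs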


def load_audio_encoder_from_sd(sd, prefix=""):
    sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
    if "encoder.layer_norm.bias" in sd: #wav2vec2
        embed_dim = sd["encoder.layer_norm.bias"].shape[0]
        if embed_dim == 1024: # large
            config = {
                "model_type": "wav2vec2",
                "embed_dim": 1024,
                "num_heads": 16,
                "num_layers": 24,
                "conv_norm": True,
                "conv_bias": True,
                "do_normalize": True,
                "do_stable_layer_norm": True
            }
        elif embed_dim == 768: # base
            config = {
                "model_type": "wav2vec2",
                "embed_dim": 768,
                "num_heads": 12,
                "num_layers": 12,
                "conv_norm": False,
                "conv_bias": False,
                "do_normalize": False, # chinese-wav2vec2-base has this False
                "do_stable_layer_norm": False
            }
        else:
            raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
    elif "model.encoder.embed_positions.weight" in sd:
        sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
        config = {
            "model_type": "whisper3",
        }
    else:
        raise RuntimeError("ERROR: audio encoder not supported.")

    audio_encoder = AudioEncoderModel(config)
    m, u = audio_encoder.load_sd(sd)
    if len(m) > 0:
        logging.warning("missing audio encoder: {}".format(m))
    if len(u) > 0:
        logging.warning("unexpected audio encoder: {}".format(u))

    return audio_encoder