Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-05-17 04:27:27 +08:00)
## VRAM Detection (model_management.py)

The DirectML code path had two hardcoded `1024 * 1024 * 1024 #TODO` values in `get_total_memory()` and `get_free_memory()`, causing ComfyUI to report only 1 GB of VRAM on any AMD/Intel GPU using the DirectML backend, regardless of the actual hardware. This made the NORMAL_VRAM / LOW_VRAM calculations wildly wrong.

Fix for `get_total_memory`:

- On Windows, reads `HardwareInformation.qwMemorySize` from the GPU driver registry key via `winreg`. This is the accurate 64-bit value (unlike `Win32_VideoController.AdapterRAM`, which overflows at 4 GB).
- Allows an override via the `COMFYUI_DIRECTML_VRAM_MB` env var.
- Falls back to 6 GB if the registry query fails (a safe default for modern dGPUs).

Fix for `get_free_memory`:

- Uses `torch_directml.gpu_memory(0)` to get per-tile usage fractions and derives free memory as `total * (1 - max_usage_fraction)`.

A hedged sketch of both queries appears below, after this description.

## torchaudio: optional import on AMD/DirectML

torchaudio has a DLL incompatibility with torch-directml (which ships its own torch runtime). The following files had a bare `import torchaudio` at module level, which crashed ComfyUI startup entirely when torchaudio was absent:

- comfy/ldm/lightricks/vae/audio_vae.py
- comfy/audio_encoders/whisper.py
- comfy/audio_encoders/audio_encoders.py
- comfy_extras/nodes_audio.py
- comfy_extras/nodes_lt.py
- comfy_extras/nodes_wandancer.py

Each import is now wrapped in `try/except (ImportError, OSError): torchaudio = None`, matching the pattern already used in comfy/ldm/mmaudio/vae/autoencoder.py and comfy/ldm/ace/vae/music_dcae_pipeline.py. Audio nodes degrade gracefully instead of preventing ComfyUI from starting.

Tested on: AMD Radeon RX 5600 XT (6 GB VRAM, gfx1010, Windows 10)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
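The following is a minimal, hedged sketch of the two DirectML memory queries described above, not the actual model_management.py patch. The value name `HardwareInformation.qwMemorySize`, the `COMFYUI_DIRECTML_VRAM_MB` override, the 6 GB fallback, and the `torch_directml.gpu_memory(0)` call are taken from this description; the function names, the exact registry path iterated, and the assumption that `gpu_memory(0)` returns a list of per-tile usage fractions are illustrative assumptions.

```python
# Hedged sketch only; function names and the registry path are assumptions.
import os
import winreg

# Display-adapter device class key; each GPU driver instance is a numbered
# subkey (0000, 0001, ...) that can carry HardwareInformation.qwMemorySize.
GPU_CLASS_KEY = r"SYSTEM\CurrentControlSet\Control\Class\{4d36e968-e325-11ce-bfc1-08002be10318}"


def directml_total_memory():
    # Explicit override wins (value given in megabytes).
    override_mb = os.environ.get("COMFYUI_DIRECTML_VRAM_MB")
    if override_mb is not None:
        return int(override_mb) * 1024 * 1024

    # Read the 64-bit dedicated VRAM size registered by the driver; unlike the
    # 32-bit Win32_VideoController.AdapterRAM, this value does not wrap at 4 GB.
    try:
        with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, GPU_CLASS_KEY) as cls_key:
            index = 0
            while True:
                try:
                    sub_name = winreg.EnumKey(cls_key, index)
                except OSError:
                    break
                index += 1
                try:
                    with winreg.OpenKey(cls_key, sub_name) as dev_key:
                        size, _ = winreg.QueryValueEx(dev_key, "HardwareInformation.qwMemorySize")
                        if size:
                            return int(size)
                except OSError:
                    continue
    except OSError:
        pass

    # Safe default for modern discrete GPUs when the registry query fails.
    return 6 * 1024 * 1024 * 1024


def directml_free_memory(total_memory):
    import torch_directml
    # Assumed return shape: per-tile usage fractions in [0, 1]; the busiest
    # tile bounds how much of the card is still free.
    usage_fractions = torch_directml.gpu_memory(0)
    return int(total_memory * (1.0 - max(usage_fractions)))
```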
## comfy/audio_encoders/whisper.py (190 lines, 5.9 KiB, Python, executable file)
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    # torchaudio is optional: its DLLs can fail to load against the torch
    # runtime shipped by torch-directml, so fall back to None instead of
    # crashing the whole module import.
    import torchaudio
except (ImportError, OSError):
    torchaudio = None
from typing import Optional
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.ops

class WhisperFeatureExtractor(nn.Module):
    def __init__(self, n_mels=128, device=None):
        super().__init__()
        self.sample_rate = 16000
        self.n_fft = 400
        self.hop_length = 160
        self.n_mels = n_mels
        self.chunk_length = 30
        self.n_samples = 480000

        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            f_min=0,
            f_max=8000,
            norm="slaney",
            mel_scale="slaney",
        ).to(device)

    def __call__(self, audio):
        # Downmix to mono, then pad or trim each item to 30 s at 16 kHz (480000 samples).
        audio = torch.mean(audio, dim=1)
        batch_size = audio.shape[0]
        processed_audio = []

        for i in range(batch_size):
            aud = audio[i]
            if aud.shape[0] > self.n_samples:
                aud = aud[:self.n_samples]
            elif aud.shape[0] < self.n_samples:
                aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
            processed_audio.append(aud)

        audio = torch.stack(processed_audio)

        mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)

        # Whisper-style log-mel normalization.
        log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
        log_mel_spec = (log_mel_spec + 4.0) / 4.0

        return log_mel_spec

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
        self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
        self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
        self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, seq_len, _ = query.shape

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
        attn_output = self.out_proj(attn_output)

        return attn_output

class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
        super().__init__()

        self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
        self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)

        self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
        self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
        self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        residual = x
        x = self.self_attn_layer_norm(x)
        x = self.self_attn(x, x, x, attention_mask)
        x = residual + x

        residual = x
        x = self.final_layer_norm(x)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        x = residual + x

        return x

class AudioEncoder(nn.Module):
    def __init__(
        self,
        n_mels: int = 128,
        n_ctx: int = 1500,
        n_state: int = 1280,
        n_head: int = 20,
        n_layer: int = 32,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()

        self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
        self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)

        self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)

        self.layers = nn.ModuleList([
            EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
            for _ in range(n_layer)
        ])

        self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))

        x = x.transpose(1, 2)

        x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)

        all_x = ()
        for layer in self.layers:
            # Collect the hidden state before each layer (and after the final norm).
            all_x += (x,)
            x = layer(x)

        x = self.layer_norm(x)
        all_x += (x,)
        return x, all_x

class WhisperLargeV3(nn.Module):
    def __init__(
        self,
        n_mels: int = 128,
        n_audio_ctx: int = 1500,
        n_audio_state: int = 1280,
        n_audio_head: int = 20,
        n_audio_layer: int = 32,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()

        self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)

        self.encoder = AudioEncoder(
            n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
            dtype=dtype, device=device, operations=operations
        )

    def forward(self, audio):
        mel = self.feature_extractor(audio)
        x, all_x = self.encoder(mel)
        return x, all_x