From e860732dba381576bfe7dd0f97e142229ae7ff6d Mon Sep 17 00:00:00 2001 From: Emiliooooo Date: Thu, 14 May 2026 12:10:31 -0400 Subject: [PATCH] fix(directml): correct VRAM detection and make torchaudio imports optional MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## VRAM Detection (model_management.py) The DirectML code path had two hardcoded `1024 * 1024 * 1024 #TODO` values in `get_total_memory()` and `get_free_memory()`, causing ComfyUI to report only 1 GB of VRAM on any AMD/Intel GPU using the DirectML backend — regardless of actual hardware. As a result, the NORMAL_VRAM / LOW_VRAM memory-state calculations were wildly wrong. Fix for `get_total_memory`: - Honors the `COMFYUI_DIRECTML_VRAM_MB` env var (value in MB) as an override; when set to a positive integer it takes precedence over the registry lookup. - Otherwise, on Windows, reads `HardwareInformation.qwMemorySize` from the GPU driver registry key via `winreg`. This is an accurate 64-bit value (unlike `Win32_VideoController.AdapterRAM`, which overflows at 4 GB). - Falls back to 6 GB if neither the override nor the registry query yields a value (safe default for modern dGPUs). Fix for `get_free_memory`: - Uses `torch_directml.gpu_memory(0)` to get per-tile usage fractions and derives free memory as `total * (1 - max_usage_fraction)`. ## torchaudio: optional import on AMD/DirectML torchaudio has a DLL incompatibility with torch-directml (which ships its own torch runtime). The following files had bare `import torchaudio` at module level, crashing ComfyUI startup entirely when torchaudio was absent or failed to load: - comfy/ldm/lightricks/vae/audio_vae.py - comfy/audio_encoders/whisper.py - comfy/audio_encoders/audio_encoders.py - comfy_extras/nodes_audio.py - comfy_extras/nodes_lt.py - comfy_extras/nodes_wandancer.py Each import is wrapped in `try/except (ImportError, OSError): torchaudio = None`, matching the pattern already used in comfy/ldm/mmaudio/vae/autoencoder.py and comfy/ldm/ace/vae/music_dcae_pipeline.py. Audio nodes will degrade gracefully (torchaudio-dependent features are unavailable) rather than preventing ComfyUI from starting. 
Tested on: AMD Radeon RX 5600 XT (6 GB VRAM, gfx1010, Windows 10) Co-Authored-By: Claude Sonnet 4.6 --- comfy/audio_encoders/audio_encoders.py | 5 ++- comfy/audio_encoders/whisper.py | 5 ++- comfy/ldm/lightricks/vae/audio_vae.py | 5 ++- comfy/model_management.py | 48 ++++++++++++++++++++++++-- comfy_extras/nodes_audio.py | 5 ++- comfy_extras/nodes_lt.py | 5 ++- comfy_extras/nodes_wandancer.py | 5 ++- 7 files changed, 69 insertions(+), 9 deletions(-) diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 0de7584b0..5413a7db3 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -4,7 +4,10 @@ import comfy.model_management import comfy.ops import comfy.utils import logging -import torchaudio +try: + import torchaudio +except (ImportError, OSError): + torchaudio = None class AudioEncoderModel(): diff --git a/comfy/audio_encoders/whisper.py b/comfy/audio_encoders/whisper.py index 93d3782f1..f4f5c4655 100755 --- a/comfy/audio_encoders/whisper.py +++ b/comfy/audio_encoders/whisper.py @@ -1,7 +1,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torchaudio +try: + import torchaudio +except (ImportError, OSError): + torchaudio = None from typing import Optional from comfy.ldm.modules.attention import optimized_attention_masked import comfy.ops diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index dd5320c8f..6755f5ff6 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -2,7 +2,10 @@ import json from dataclasses import dataclass import math import torch -import torchaudio +try: + import torchaudio +except (ImportError, OSError): + torchaudio = None from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier diff --git a/comfy/model_management.py b/comfy/model_management.py index 
21738a4c7..6b4d4b770 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -101,7 +101,7 @@ if args.deterministic: directml_enabled = False if args.directml is not None: - logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.") + logging.info("DirectML backend active (AMD/Intel GPU on Windows, no CUDA/ROCm required).") import torch_directml directml_enabled = True device_index = args.directml @@ -213,7 +213,40 @@ def get_total_memory(dev=None, torch_total_too=False): mem_total_torch = mem_total else: if directml_enabled: - mem_total = 1024 * 1024 * 1024 #TODO + # Query real VRAM from Windows registry (qwMemorySize is 64-bit, AdapterRAM caps at 4GB) + # COMFYUI_DIRECTML_VRAM_MB env var (in MB) overrides the registry lookup; 6GB default if neither yields a value + _dml_vram = 0 + try: + _override = os.environ.get("COMFYUI_DIRECTML_VRAM_MB") + if _override: + _dml_vram = int(_override) * 1024 * 1024 + except Exception: + pass + if _dml_vram <= 0: + try: + import winreg as _winreg + _base = r"SYSTEM\CurrentControlSet\Control\Class\{4d36e968-e325-11ce-bfc1-08002be10318}" + with _winreg.OpenKey(_winreg.HKEY_LOCAL_MACHINE, _base) as _hbase: + _i = 0 + while True: + try: + _sub = _winreg.EnumKey(_hbase, _i) + _i += 1 + try: + with _winreg.OpenKey(_hbase, _sub) as _hdev: + _mem, _ = _winreg.QueryValueEx(_hdev, "HardwareInformation.qwMemorySize") + if isinstance(_mem, int) and _mem > 128 * 1024 * 1024: + _dml_vram = _mem + break + except Exception: + pass + except OSError: + break + except Exception: + pass + if _dml_vram <= 0: + _dml_vram = 6 * 1024 * 1024 * 1024 # 6GB safe default for modern AMD cards + mem_total = _dml_vram mem_total_torch = mem_total elif is_intel_xpu(): stats = torch.xpu.memory_stats(dev) @@ -1504,7 +1537,16 @@ def get_free_memory(dev=None, torch_free_too=False): mem_free_torch = mem_free_total else: if directml_enabled: - mem_free_total = 1024 * 1024 * 
1024 #TODO + # gpu_memory(0) returns a list of per-tile usage fractions [0.0–1.0] + # total_vram (module-level) is the registry-queried real VRAM in MB + try: + import torch_directml as _tdml + _usage_fracs = _tdml.gpu_memory(0) + _usage_pct = max(_usage_fracs) if _usage_fracs else 0.0 + _total = int(total_vram * 1024 * 1024) + mem_free_total = max(0, int(_total * (1.0 - _usage_pct))) + except Exception: + mem_free_total = int(total_vram * 1024 * 1024) mem_free_torch = mem_free_total elif is_intel_xpu(): stats = torch.xpu.memory_stats(dev) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index fcc1c34d5..ce1a49cd1 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -1,7 +1,10 @@ from __future__ import annotations import av -import torchaudio +try: + import torchaudio +except (ImportError, OSError): + torchaudio = None import torch import comfy.model_management import folder_paths diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 3dc1199c2..48137fdf6 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -1,7 +1,10 @@ import nodes import node_helpers import torch -import torchaudio +try: + import torchaudio +except (ImportError, OSError): + torchaudio = None import comfy.model_management import comfy.model_sampling import comfy.samplers diff --git a/comfy_extras/nodes_wandancer.py b/comfy_extras/nodes_wandancer.py index fc005ed4c..dbc929c83 100644 --- a/comfy_extras/nodes_wandancer.py +++ b/comfy_extras/nodes_wandancer.py @@ -2,7 +2,10 @@ import math import nodes import node_helpers import torch -import torchaudio +try: + import torchaudio +except (ImportError, OSError): + torchaudio = None import comfy.model_management import comfy.utils import numpy as np