From 93510fde17786759bdc06845e5eb1195fca6eae8 Mon Sep 17 00:00:00 2001 From: Emiliooooo Date: Thu, 14 May 2026 19:20:09 -0400 Subject: [PATCH] fix(directml): guard opaque tensor storage and zero VRAM edge cases Two runtime crashes affect AMD GPU users on Windows using torch-directml: 1. NotImplementedError in module_mmap_residency / cast_to_gathered DirectML tensors are opaque (OpaqueTensorImpl) and do not support untyped_storage(). Wrap both call sites in try/except so mmap tracking is skipped for DirectML tensors instead of crashing. 2. ZeroDivisionError in attention_split DirectML does not expose free VRAM via the standard query path, leaving mem_free_total as 0. Guard the math.log() call with a floor of 4 GB so split-attention steps are computed safely. Tested on AMD RX 5600 XT (6 GB VRAM), Windows 11, torch-directml 0.2.5, ComfyUI 0.21.1, DreamShaper 8 (SD 1.5). Co-Authored-By: Claude Sonnet 4.5 --- comfy/ldm/modules/attention.py | 4 ++++ comfy/model_management.py | 13 +++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index a68cb8439..37b2a8ee3 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -336,6 +336,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape steps = 1 + if mem_free_total <= 0: + # DirectML doesn't expose free VRAM — assume 4GB free as a safe fallback for 6GB cards + mem_free_total = 4 * (1024 ** 3) + if mem_required > mem_free_total: steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " diff --git a/comfy/model_management.py b/comfy/model_management.py index 6b4d4b770..a14627842 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -544,7 +544,11 @@ def module_mmap_residency(module, free=False): for k in sd: t = sd[k] module_mem += t.nbytes - storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage() + try: + storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage() + except NotImplementedError: + # DirectML (AMD) tensors are opaque — no host storage to inspect; skip mmap tracking + continue if not getattr(storage, "_comfy_tensor_mmap_touched", False): continue mmap_touched_mem += t.nbytes @@ -1328,7 +1332,12 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None): continue if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view): continue - storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage() + try: + storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage() + except NotImplementedError: + # DirectML tensors are opaque — skip mmap marking, just copy + dest_view.copy_(tensor, non_blocking=non_blocking) + continue if hasattr(storage, "_comfy_tensor_mmap_touched"): storage._comfy_tensor_mmap_touched = True dest_view.copy_(tensor, non_blocking=non_blocking)