fix(directml): guard opaque tensor storage and zero VRAM edge cases

Two runtime crashes affect AMD GPU users on Windows using torch-directml:

1. NotImplementedError in module_mmap_residency / cast_to_gathered
   DirectML tensors are opaque (OpaqueTensorImpl) and do not support
   untyped_storage(). Wrap both call sites in try/except so mmap
   tracking is skipped for DirectML tensors instead of crashing.

2. ZeroDivisionError in attention_split
   DirectML does not expose free VRAM via the standard query path,
   leaving mem_free_total as 0. Guard the math.log() call with a
   floor of 4 GB so split-attention steps are computed safely.

Tested on AMD RX 5600 XT (6 GB VRAM), Windows 11, torch-directml 0.2.5,
ComfyUI 0.21.1, DreamShaper 8 (SD 1.5).

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Emiliooooo 2026-05-14 19:20:09 -04:00
parent e860732dba
commit 93510fde17
2 changed files with 15 additions and 2 deletions

View File

@ -336,6 +336,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
steps = 1
if mem_free_total <= 0:
# DirectML doesn't expose free VRAM — assume 4GB free as a safe fallback for 6GB cards
mem_free_total = 4 * (1024 ** 3)
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "

View File

@ -544,7 +544,11 @@ def module_mmap_residency(module, free=False):
for k in sd:
t = sd[k]
module_mem += t.nbytes
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
try:
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
except NotImplementedError:
# DirectML (AMD) tensors are opaque — no host storage to inspect; skip mmap tracking
continue
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
continue
mmap_touched_mem += t.nbytes
@ -1328,7 +1332,12 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
continue
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
try:
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
except NotImplementedError:
# DirectML tensors are opaque — skip mmap marking, just copy
dest_view.copy_(tensor, non_blocking=non_blocking)
continue
if hasattr(storage, "_comfy_tensor_mmap_touched"):
storage._comfy_tensor_mmap_touched = True
dest_view.copy_(tensor, non_blocking=non_blocking)