From f0caa15a17d8fb87d9fa7aada40c017060d2582e Mon Sep 17 00:00:00 2001
From: Xiaodong Ye
Date: Sun, 4 Jan 2026 17:55:04 +0800
Subject: [PATCH] Support MThreads (MUSA) GPU

Signed-off-by: Xiaodong Ye
---
 comfy/ldm/flux/math.py    | 11 +++++-
 comfy/model_management.py | 79 +++++++++++++++++++++++----------------
 requirements.txt          |  1 +
 3 files changed, 56 insertions(+), 35 deletions(-)

diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index 6a22df8bc..615b4b357 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -21,8 +21,15 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     else:
         device = pos.device
 
-    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
-    omega = 1.0 / (theta**scale)
+    if device.type == "musa":
+        # XXX (MUSA): Unsupported tensor dtype in Neg: Double
+        scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float32, device=device)
+        if not isinstance(theta, torch.Tensor):
+            theta = torch.tensor(theta, dtype=torch.float32, device=device)
+        omega = torch.exp(-scale * torch.log(theta + 1e-6))
+    else:
+        scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
+        omega = 1.0 / (theta**scale)
     out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 2501cecb7..8b6529c95 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -139,6 +139,12 @@ try:
 except:
     ixuca_available = False
 
+try:
+    import torchada # noqa: F401
+    musa_available = hasattr(torch, "musa") and torch.musa.is_available()
+except:
+    musa_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
@@ -146,27 +152,24 @@ def is_intel_xpu():
     global cpu_state
     global xpu_available
     if cpu_state == CPUState.GPU:
-        if xpu_available:
-            return True
+        return xpu_available
     return False
 
 def is_ascend_npu():
     global npu_available
-    if npu_available:
-        return True
-    return False
+    return npu_available
 
 def is_mlu():
     global mlu_available
-    if mlu_available:
-        return True
-    return False
+    return mlu_available
 
 def is_ixuca():
     global ixuca_available
-    if ixuca_available:
-        return True
-    return False
+    return ixuca_available
+
+def is_musa():
+    global musa_available
+    return musa_available
 
 def get_torch_device():
     global directml_enabled
@@ -311,7 +314,7 @@ def amd_min_version(device=None, min_rdna_version=0):
     return False
 
 MIN_WEIGHT_MEMORY_RATIO = 0.4
-if is_nvidia():
+if is_nvidia() or is_musa():
     MIN_WEIGHT_MEMORY_RATIO = 0.0
 
 ENABLE_PYTORCH_ATTENTION = False
@@ -320,7 +323,7 @@ if args.use_pytorch_cross_attention:
     XFORMERS_IS_AVAILABLE = False
 
 try:
-    if is_nvidia():
+    if is_nvidia() or is_musa():
         if torch_version_numeric[0] >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
@@ -375,7 +378,7 @@ if ENABLE_PYTORCH_ATTENTION:
 
 PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
 try:
-    if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
+    if (is_nvidia() or is_amd() or is_musa()) and PerformanceFeature.Fp16Accumulation in args.fast:
         torch.backends.cuda.matmul.allow_fp16_accumulation = True
         PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
         logging.info("Enabled fp16 accumulation.")
@@ -1020,7 +1023,7 @@ if args.async_offload is not None:
     NUM_STREAMS = args.async_offload
 else:
     # Enable by default on Nvidia and AMD
-    if is_nvidia() or is_amd():
+    if is_nvidia() or is_amd() or is_musa():
         NUM_STREAMS = 2
 
 if args.disable_async_offload:
@@ -1117,7 +1120,7 @@ PINNED_MEMORY = {}
 TOTAL_PINNED_MEMORY = 0
 MAX_PINNED_MEMORY = -1
 if not args.disable_pinned_memory:
-    if is_nvidia() or is_amd():
+    if is_nvidia() or is_amd() or is_musa():
         if WINDOWS:
             MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
         else:
@@ -1261,6 +1264,8 @@ def pytorch_attention_flash_attention():
             return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
         if is_ixuca():
             return True
+        if is_musa():
+            return True
     return False
 
 def force_upcast_attention_dtype():
@@ -1392,6 +1397,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if torch.version.hip:
         return True
 
+    if is_musa():
+        return True
+
     props = torch.cuda.get_device_properties(device)
     if props.major >= 8:
         return True
@@ -1462,6 +1470,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
                 return True
             return False
 
+    if is_musa():
+        return True
+
     props = torch.cuda.get_device_properties(device)
 
     if is_mlu():
@@ -1484,25 +1495,27 @@ def supports_fp8_compute(device=None):
     if SUPPORT_FP8_OPS:
         return True
 
-    if not is_nvidia():
-        return False
-
     props = torch.cuda.get_device_properties(device)
-    if props.major >= 9:
-        return True
-    if props.major < 8:
-        return False
-    if props.minor < 9:
-        return False
-
-    if torch_version_numeric < (2, 3):
-        return False
-
-    if WINDOWS:
-        if torch_version_numeric < (2, 4):
+    if is_nvidia():
+        if props.major >= 9:
+            return True
+        if props.major < 8:
+            return False
+        if props.minor < 9:
             return False
-
-    return True
+        if torch_version_numeric < (2, 3):
+            return False
+
+        if WINDOWS:
+            if torch_version_numeric < (2, 4):
+                return False
+        return True
+
+    elif is_musa():
+        if props.major >= 3:
+            return True
+
+    return False
 
 def extended_fp16_support():
     # TODO: check why some models work with fp16 on newer torch versions but not on older
@@ -1543,7 +1556,7 @@ def unload_all_models():
     free_memory(1e30, get_torch_device())
 
 def debug_memory_summary():
-    if is_amd() or is_nvidia():
+    if is_amd() or is_nvidia() or is_musa():
         return torch.cuda.memory.memory_summary()
     return ""
 
diff --git a/requirements.txt b/requirements.txt
index 3a05799eb..42e1c3709 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,3 +27,4 @@ kornia>=0.7.1
 spandrel
 pydantic~=2.0
 pydantic-settings~=2.0
+torchada>=0.1.11
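
Note (aside, not part of the patch to apply): the MUSA branch of rope() replaces 1.0 / (theta ** scale) in float64 with exp(-scale * log(theta)) in float32, because MUSA cannot negate double tensors. Below is a minimal sketch, with illustrative values and run on CPU, for sanity-checking that the two forms agree numerically:

    import torch

    # Illustrative values; rope() receives these as arguments.
    dim, theta = 128, 10000

    # Original float64 path.
    scale64 = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64)
    ref = 1.0 / (theta ** scale64)

    # MUSA-style float32 path from the patch.
    scale32 = scale64.to(torch.float32)
    theta32 = torch.tensor(float(theta), dtype=torch.float32)
    musa_style = torch.exp(-scale32 * torch.log(theta32 + 1e-6))

    # The difference should be on the order of float32 rounding error (~1e-7).
    print((ref.to(torch.float32) - musa_style).abs().max())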