From e5914409d32f94ea65474b33615695c56a090bf9 Mon Sep 17 00:00:00 2001
From: Xiaodong Ye
Date: Mon, 26 Jan 2026 10:36:37 +0800
Subject: [PATCH] Support MThreads (MUSA) GPU

Signed-off-by: Xiaodong Ye
---
 comfy/model_management.py | 80 ++++++++++++++++++++++++-----------------
 main.py                   |  2 ++
 requirements.txt          |  1 +
 3 files changed, 50 insertions(+), 33 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 9d39be7b2..2a9acd788 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -138,6 +138,12 @@ try:
 except:
     ixuca_available = False
 
+try:
+    import torchada # noqa: F401
+    musa_available = hasattr(torch, "musa") and torch.musa.is_available()
+except:
+    musa_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
@@ -145,27 +151,24 @@ def is_intel_xpu():
     global cpu_state
     global xpu_available
     if cpu_state == CPUState.GPU:
-        if xpu_available:
-            return True
+        return xpu_available
     return False
 
 def is_ascend_npu():
     global npu_available
-    if npu_available:
-        return True
-    return False
+    return npu_available
 
 def is_mlu():
     global mlu_available
-    if mlu_available:
-        return True
-    return False
+    return mlu_available
 
 def is_ixuca():
     global ixuca_available
-    if ixuca_available:
-        return True
-    return False
+    return ixuca_available
+
+def is_musa():
+    global musa_available
+    return musa_available
 
 def get_torch_device():
     global directml_enabled
@@ -310,7 +313,7 @@ def amd_min_version(device=None, min_rdna_version=0):
     return False
 
 MIN_WEIGHT_MEMORY_RATIO = 0.4
-if is_nvidia():
+if is_nvidia() or is_musa():
     MIN_WEIGHT_MEMORY_RATIO = 0.0
 
 ENABLE_PYTORCH_ATTENTION = False
@@ -319,7 +322,7 @@ if args.use_pytorch_cross_attention:
     XFORMERS_IS_AVAILABLE = False
 
 try:
-    if is_nvidia():
+    if is_nvidia() or is_musa():
         if torch_version_numeric[0] >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
@@ -386,7 +389,7 @@ if ENABLE_PYTORCH_ATTENTION:
 
 PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
 try:
-    if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
+    if (is_nvidia() or is_amd() or is_musa()) and PerformanceFeature.Fp16Accumulation in args.fast:
         torch.backends.cuda.matmul.allow_fp16_accumulation = True
         PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
         logging.info("Enabled fp16 accumulation.")
@@ -1031,7 +1034,7 @@ if args.async_offload is not None:
     NUM_STREAMS = args.async_offload
 else:
     # Enable by default on Nvidia and AMD
-    if is_nvidia() or is_amd():
+    if is_nvidia() or is_amd() or is_musa():
         NUM_STREAMS = 2
 
 if args.disable_async_offload:
@@ -1128,7 +1131,7 @@ PINNED_MEMORY = {}
 TOTAL_PINNED_MEMORY = 0
 MAX_PINNED_MEMORY = -1
 if not args.disable_pinned_memory:
-    if is_nvidia() or is_amd():
+    if is_nvidia() or is_amd() or is_musa():
         if WINDOWS:
             MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
         else:
@@ -1272,6 +1275,8 @@ def pytorch_attention_flash_attention():
         return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
     if is_ixuca():
         return True
+    if is_musa():
+        return True
     return False
 
 def force_upcast_attention_dtype():
@@ -1403,6 +1408,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if torch.version.hip:
         return True
 
+    if is_musa():
+        return True
+
     props = torch.cuda.get_device_properties(device)
     if props.major >= 8:
         return True
@@ -1473,6 +1481,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
             return True
         return False
 
+    if is_musa():
+        return True
+
     props = torch.cuda.get_device_properties(device)
 
     if is_mlu():
@@ -1495,25 +1506,28 @@ def supports_fp8_compute(device=None):
     if SUPPORT_FP8_OPS:
         return True
 
-    if not is_nvidia():
-        return False
-
     props = torch.cuda.get_device_properties(device)
-    if props.major >= 9:
-        return True
-    if props.major < 8:
-        return False
-    if props.minor < 9:
-        return False
-
-    if torch_version_numeric < (2, 3):
-        return False
-
-    if WINDOWS:
-        if torch_version_numeric < (2, 4):
+    if is_nvidia():
+        if props.major >= 9:
+            return True
+        if props.major < 8:
+            return False
+        if props.minor < 9:
             return False
 
-    return True
+        if torch_version_numeric < (2, 3):
+            return False
+
+        if WINDOWS:
+            if torch_version_numeric < (2, 4):
+                return False
+
+        return True
+    elif is_musa():
+        if props.major >= 3:
+            return True
+
+    return False
 
 def supports_nvfp4_compute(device=None):
     if not is_nvidia():
@@ -1564,7 +1578,7 @@ def unload_all_models():
     free_memory(1e30, get_torch_device())
 
 def debug_memory_summary():
-    if is_amd() or is_nvidia():
+    if is_amd() or is_nvidia() or is_musa():
         return torch.cuda.memory.memory_summary()
     return ""
 
diff --git a/main.py b/main.py
index 37b06c1fa..f5139a3b9 100644
--- a/main.py
+++ b/main.py
@@ -37,11 +37,13 @@ if __name__ == "__main__":
         devices = ','.join(map(str, devices))
         os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
         os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
+        os.environ['MUSA_VISIBLE_DEVICES'] = str(devices)
 
     if args.cuda_device is not None:
         os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
         os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
         os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
+        os.environ['MUSA_VISIBLE_DEVICES'] = str(args.cuda_device)
         logging.info("Set cuda device to: {}".format(args.cuda_device))
 
     if args.oneapi_device_selector is not None:
diff --git a/requirements.txt b/requirements.txt
index ec89dccd2..5003a0bc0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,3 +28,4 @@ kornia>=0.7.1
 spandrel
 pydantic~=2.0
 pydantic-settings~=2.0
+torchada>=0.1.22
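
--
A quick smoke test for the new detection path, kept out of the patch body: a
minimal sketch assuming torchada and a MUSA-enabled PyTorch build are
installed. It mirrors the probe added to comfy/model_management.py above.

    import torch
    # Imported for its side effects only, exactly as in model_management.py.
    import torchada # noqa: F401

    # Same condition the new is_musa() helper ends up reporting: the MUSA
    # backend is registered on torch and a device is usable.
    print(hasattr(torch, "musa") and torch.musa.is_available())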