R0CKSTAR 2026-01-27 08:54:00 +08:00 committed by GitHub
commit f96c2c4131
3 changed files with 49 additions and 33 deletions

View File

@@ -138,6 +138,12 @@ try:
 except:
     ixuca_available = False
 
+try:
+    import torchada # noqa: F401
+    musa_available = hasattr(torch, "musa") and torch.musa.is_available()
+except:
+    musa_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
@@ -145,27 +151,24 @@ def is_intel_xpu():
     global cpu_state
     global xpu_available
     if cpu_state == CPUState.GPU:
-        if xpu_available:
-            return True
+        return xpu_available
     return False
 
 def is_ascend_npu():
     global npu_available
-    if npu_available:
-        return True
-    return False
+    return npu_available
 
 def is_mlu():
     global mlu_available
-    if mlu_available:
-        return True
-    return False
+    return mlu_available
 
 def is_ixuca():
     global ixuca_available
-    if ixuca_available:
-        return True
-    return False
+    return ixuca_available
+
+def is_musa():
+    global musa_available
+    return musa_available
 
 def get_torch_device():
     global directml_enabled
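
The MUSA probe added above follows the same pattern as the existing NPU/MLU/ixuca probes: import the vendor bridge once at module load, cache the result in a module-level flag, and expose it through a tiny accessor. A minimal standalone sketch of that pattern, assuming (as the diff does) that torchada registers a torch.musa namespace with is_available():

import torch

musa_available = False
try:
    import torchada  # noqa: F401  # vendor bridge; expected to register torch.musa on MUSA systems
    musa_available = hasattr(torch, "musa") and torch.musa.is_available()
except Exception:
    musa_available = False  # no bridge or no MUSA runtime: fall through to the other backends

def is_musa():
    # Plain flag lookup, mirroring is_ascend_npu()/is_mlu()/is_ixuca(); the probe is not repeated per call.
    return musa_available
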
@@ -310,7 +313,7 @@ def amd_min_version(device=None, min_rdna_version=0):
     return False
 
 MIN_WEIGHT_MEMORY_RATIO = 0.4
-if is_nvidia():
+if is_nvidia() or is_musa():
     MIN_WEIGHT_MEMORY_RATIO = 0.0
 
 ENABLE_PYTORCH_ATTENTION = False
@@ -319,7 +322,7 @@ if args.use_pytorch_cross_attention:
 XFORMERS_IS_AVAILABLE = False
 
 try:
-    if is_nvidia():
+    if is_nvidia() or is_musa():
         if torch_version_numeric[0] >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
@@ -386,7 +389,7 @@ if ENABLE_PYTORCH_ATTENTION:
 PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
 
 try:
-    if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
+    if (is_nvidia() or is_amd() or is_musa()) and PerformanceFeature.Fp16Accumulation in args.fast:
         torch.backends.cuda.matmul.allow_fp16_accumulation = True
         PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
         logging.info("Enabled fp16 accumulation.")
@@ -1031,7 +1034,7 @@ if args.async_offload is not None:
     NUM_STREAMS = args.async_offload
 else:
     # Enable by default on Nvidia and AMD
-    if is_nvidia() or is_amd():
+    if is_nvidia() or is_amd() or is_musa():
         NUM_STREAMS = 2
 
 if args.disable_async_offload:
@@ -1128,7 +1131,7 @@ PINNED_MEMORY = {}
 TOTAL_PINNED_MEMORY = 0
 MAX_PINNED_MEMORY = -1
 if not args.disable_pinned_memory:
-    if is_nvidia() or is_amd():
+    if is_nvidia() or is_amd() or is_musa():
         if WINDOWS:
             MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
         else:
@@ -1272,6 +1275,8 @@ def pytorch_attention_flash_attention():
             return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
         if is_ixuca():
             return True
+        if is_musa():
+            return True
         return False
 
 def force_upcast_attention_dtype():
@@ -1403,6 +1408,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if torch.version.hip:
         return True
 
+    if is_musa():
+        return True
+
     props = torch.cuda.get_device_properties(device)
     if props.major >= 8:
         return True
@@ -1473,6 +1481,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
                 return True
             return False
 
+    if is_musa():
+        return True
+
     props = torch.cuda.get_device_properties(device)
     if is_mlu():
@@ -1495,25 +1506,27 @@ def supports_fp8_compute(device=None):
     if SUPPORT_FP8_OPS:
         return True
 
-    if not is_nvidia():
-        return False
-
     props = torch.cuda.get_device_properties(device)
-    if props.major >= 9:
-        return True
-    if props.major < 8:
-        return False
-    if props.minor < 9:
-        return False
-
-    if torch_version_numeric < (2, 3):
-        return False
-
-    if WINDOWS:
-        if torch_version_numeric < (2, 4):
-            return False
-
-    return True
+    if is_nvidia():
+        if props.major >= 9:
+            return True
+        if props.major < 8:
+            return False
+        if props.minor < 9:
+            return False
+
+        if torch_version_numeric < (2, 3):
+            return False
+
+        if WINDOWS:
+            if torch_version_numeric < (2, 4):
+                return False
+    elif is_musa():
+        if props.major >= 3:
+            return True
+
+    return False
 
 def supports_nvfp4_compute(device=None):
     if not is_nvidia():
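
On the MUSA side, the fp8 gate keys on the device's reported major compute version (props.major >= 3), read through the same torch.cuda.get_device_properties() call the NVIDIA branch uses. A small hedged check to see what a given card reports, assuming the MUSA device is visible through the CUDA device API once torchada is imported (as the code path above implies):

import torch
import torchada  # noqa: F401  # assumption: makes the MUSA device visible to torch.cuda

props = torch.cuda.get_device_properties(0)  # index 0 is arbitrary
print(props.name, props.major, props.minor)  # the check above expects major >= 3 for fp8 on MUSA
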
@@ -1564,7 +1577,7 @@ def unload_all_models():
     free_memory(1e30, get_torch_device())
 
 def debug_memory_summary():
-    if is_amd() or is_nvidia():
+    if is_amd() or is_nvidia() or is_musa():
         return torch.cuda.memory.memory_summary()
     return ""

View File

@@ -37,11 +37,13 @@ if __name__ == "__main__":
         devices = ','.join(map(str, devices))
         os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
         os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
+        os.environ['MUSA_VISIBLE_DEVICES'] = str(devices)
 
     if args.cuda_device is not None:
         os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
         os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
         os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
+        os.environ['MUSA_VISIBLE_DEVICES'] = str(args.cuda_device)
         logging.info("Set cuda device to: {}".format(args.cuda_device))
 
     if args.oneapi_device_selector is not None:
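
MUSA_VISIBLE_DEVICES is exported alongside CUDA_VISIBLE_DEVICES and HIP_VISIBLE_DEVICES, so device selection keeps the same contract: a comma-separated list of device indices set before torch (and the vendor bridge) is imported. A hedged sketch of what the block above amounts to when a single device is requested (the index is arbitrary):

import os

# Must run before `import torch` / `import torchada`, which is why main.py sets it at startup.
os.environ['MUSA_VISIBLE_DEVICES'] = '1'
os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # mirrors the diff; harmless on MUSA-only systems
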

View File

@@ -29,3 +29,4 @@ kornia>=0.7.1
 spandrel
 pydantic~=2.0
 pydantic-settings~=2.0
+torchada>=0.1.22
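
torchada is pinned as an ordinary requirement rather than an optional extra, so it is installed on every platform; on machines without a MUSA runtime the try/except probe in the model management module simply leaves musa_available as False (assuming the wheel itself installs cleanly there). A quick hedged sanity check after installing the requirements, not part of the repository:

import torch
import torchada  # noqa: F401

# Expected to print True only with a Moore Threads GPU and the MUSA runtime present.
print(hasattr(torch, "musa") and torch.musa.is_available())
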