This commit is contained in:
R0CKSTAR 2026-01-07 18:56:56 +08:00 committed by GitHub
commit 19c2d83ecb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 56 additions and 35 deletions

View File

@ -21,8 +21,15 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
else: else:
device = pos.device device = pos.device
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device) if device.type == "musa":
omega = 1.0 / (theta**scale) # XXX (MUSA): Unsupported tensor dtype in Neg: Double
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float32, device=device)
if not isinstance(theta, torch.Tensor):
theta = torch.tensor(theta, dtype=torch.float32, device=device)
omega = torch.exp(-scale * torch.log(theta + 1e-6))
else:
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega) out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)

View File

@ -139,6 +139,12 @@ try:
except: except:
ixuca_available = False ixuca_available = False
try:
import torchada # noqa: F401
musa_available = hasattr(torch, "musa") and torch.musa.is_available()
except:
musa_available = False
if args.cpu: if args.cpu:
cpu_state = CPUState.CPU cpu_state = CPUState.CPU
@ -146,27 +152,24 @@ def is_intel_xpu():
global cpu_state global cpu_state
global xpu_available global xpu_available
if cpu_state == CPUState.GPU: if cpu_state == CPUState.GPU:
if xpu_available: return xpu_available
return True
return False return False
def is_ascend_npu(): def is_ascend_npu():
global npu_available global npu_available
if npu_available: return npu_available
return True
return False
def is_mlu(): def is_mlu():
global mlu_available global mlu_available
if mlu_available: return mlu_available
return True
return False
def is_ixuca(): def is_ixuca():
global ixuca_available global ixuca_available
if ixuca_available: return ixuca_available
return True
return False def is_musa():
global musa_available
return musa_available
def get_torch_device(): def get_torch_device():
global directml_enabled global directml_enabled
@ -311,7 +314,7 @@ def amd_min_version(device=None, min_rdna_version=0):
return False return False
MIN_WEIGHT_MEMORY_RATIO = 0.4 MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia(): if is_nvidia() or is_musa():
MIN_WEIGHT_MEMORY_RATIO = 0.0 MIN_WEIGHT_MEMORY_RATIO = 0.0
ENABLE_PYTORCH_ATTENTION = False ENABLE_PYTORCH_ATTENTION = False
@ -320,7 +323,7 @@ if args.use_pytorch_cross_attention:
XFORMERS_IS_AVAILABLE = False XFORMERS_IS_AVAILABLE = False
try: try:
if is_nvidia(): if is_nvidia() or is_musa():
if torch_version_numeric[0] >= 2: if torch_version_numeric[0] >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
@ -375,7 +378,7 @@ if ENABLE_PYTORCH_ATTENTION:
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
try: try:
if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast: if (is_nvidia() or is_amd() or is_musa()) and PerformanceFeature.Fp16Accumulation in args.fast:
torch.backends.cuda.matmul.allow_fp16_accumulation = True torch.backends.cuda.matmul.allow_fp16_accumulation = True
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
logging.info("Enabled fp16 accumulation.") logging.info("Enabled fp16 accumulation.")
@ -1020,7 +1023,7 @@ if args.async_offload is not None:
NUM_STREAMS = args.async_offload NUM_STREAMS = args.async_offload
else: else:
# Enable by default on Nvidia and AMD # Enable by default on Nvidia and AMD
if is_nvidia() or is_amd(): if is_nvidia() or is_amd() or is_musa():
NUM_STREAMS = 2 NUM_STREAMS = 2
if args.disable_async_offload: if args.disable_async_offload:
@ -1117,7 +1120,7 @@ PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0 TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1 MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory: if not args.disable_pinned_memory:
if is_nvidia() or is_amd(): if is_nvidia() or is_amd() or is_musa():
if WINDOWS: if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
else: else:
@ -1261,6 +1264,8 @@ def pytorch_attention_flash_attention():
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
if is_ixuca(): if is_ixuca():
return True return True
if is_musa():
return True
return False return False
def force_upcast_attention_dtype(): def force_upcast_attention_dtype():
@ -1392,6 +1397,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if torch.version.hip: if torch.version.hip:
return True return True
if is_musa():
return True
props = torch.cuda.get_device_properties(device) props = torch.cuda.get_device_properties(device)
if props.major >= 8: if props.major >= 8:
return True return True
@ -1462,6 +1470,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return True return True
return False return False
if is_musa():
return True
props = torch.cuda.get_device_properties(device) props = torch.cuda.get_device_properties(device)
if is_mlu(): if is_mlu():
@ -1484,25 +1495,27 @@ def supports_fp8_compute(device=None):
if SUPPORT_FP8_OPS: if SUPPORT_FP8_OPS:
return True return True
if not is_nvidia():
return False
props = torch.cuda.get_device_properties(device) props = torch.cuda.get_device_properties(device)
if props.major >= 9: if is_nvidia():
return True if props.major >= 9:
if props.major < 8: return True
return False if props.major < 8:
if props.minor < 9: return False
return False if props.minor < 9:
if torch_version_numeric < (2, 3):
return False
if WINDOWS:
if torch_version_numeric < (2, 4):
return False return False
return True if torch_version_numeric < (2, 3):
return False
if WINDOWS:
if torch_version_numeric < (2, 4):
return False
elif is_musa():
if props.major >= 3:
return True
return False
def supports_nvfp4_compute(device=None): def supports_nvfp4_compute(device=None):
if not is_nvidia(): if not is_nvidia():
@ -1553,7 +1566,7 @@ def unload_all_models():
free_memory(1e30, get_torch_device()) free_memory(1e30, get_torch_device())
def debug_memory_summary(): def debug_memory_summary():
if is_amd() or is_nvidia(): if is_amd() or is_nvidia() or is_musa():
return torch.cuda.memory.memory_summary() return torch.cuda.memory.memory_summary()
return "" return ""

View File

@ -28,3 +28,4 @@ kornia>=0.7.1
spandrel spandrel
pydantic~=2.0 pydantic~=2.0
pydantic-settings~=2.0 pydantic-settings~=2.0
torchada>=0.1.11