commit f3dad735bf
R0CKSTAR, 2026-01-06 17:05:47 +08:00 (committed by GitHub)
3 changed files with 56 additions and 35 deletions

File 1 of 3

@@ -21,8 +21,15 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     else:
         device = pos.device
 
-    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
-    omega = 1.0 / (theta**scale)
+    if device.type == "musa":
+        # XXX (MUSA): Unsupported tensor dtype in Neg: Double
+        scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float32, device=device)
+        if not isinstance(theta, torch.Tensor):
+            theta = torch.tensor(theta, dtype=torch.float32, device=device)
+        omega = torch.exp(-scale * torch.log(theta + 1e-6))
+    else:
+        scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
+        omega = 1.0 / (theta**scale)
     out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
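
Note: the two branches compute the same frequency vector, since 1.0 / theta**scale == exp(-scale * log(theta)); the exp/log form avoids negating a float64 tensor, which the MUSA backend rejects (the "Unsupported tensor dtype in Neg: Double" in the XXX comment). The + 1e-6 only guards log() against theta == 0, at a tiny cost in accuracy. A standalone sketch, not part of the commit, comparing the two paths on CPU:

import torch

dim, theta = 128, 10000

# original path: float64, direct negative power
scale64 = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64)
omega_ref = 1.0 / (theta ** scale64)

# MUSA-friendly path: float32, same quantity via exp(-scale * log(theta))
scale32 = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float32)
theta32 = torch.tensor(theta, dtype=torch.float32)
omega_musa = torch.exp(-scale32 * torch.log(theta32 + 1e-6))

print((omega_musa.double() - omega_ref).abs().max())  # ~1e-7, float32 rounding level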

File 2 of 3

@@ -139,6 +139,12 @@ try:
 except:
     ixuca_available = False
 
+try:
+    import torchada # noqa: F401
+    musa_available = hasattr(torch, "musa") and torch.musa.is_available()
+except:
+    musa_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
@@ -146,27 +152,24 @@ def is_intel_xpu():
     global cpu_state
     global xpu_available
     if cpu_state == CPUState.GPU:
-        if xpu_available:
-            return True
+        return xpu_available
     return False
 
 def is_ascend_npu():
     global npu_available
-    if npu_available:
-        return True
-    return False
+    return npu_available
 
 def is_mlu():
     global mlu_available
-    if mlu_available:
-        return True
-    return False
+    return mlu_available
 
 def is_ixuca():
     global ixuca_available
-    if ixuca_available:
-        return True
-    return False
+    return ixuca_available
+
+def is_musa():
+    global musa_available
+    return musa_available
 
 def get_torch_device():
     global directml_enabled
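
Note: is_musa() follows the same probe-once, query-many shape as the predicates it joins. A minimal sketch of that pattern in isolation (illustrative only; the surrounding module wiring is assumed):

import torch

try:
    import torchada  # noqa: F401  (exposes the torch.musa backend when installed)
    musa_available = hasattr(torch, "musa") and torch.musa.is_available()
except Exception:
    musa_available = False

def is_musa() -> bool:
    return musa_available  # cached flag lookup, cheap enough for hot paths

device = torch.device("musa" if is_musa() else "cpu")
print(device)
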
@@ -311,7 +314,7 @@ def amd_min_version(device=None, min_rdna_version=0):
     return False
 
 MIN_WEIGHT_MEMORY_RATIO = 0.4
-if is_nvidia():
+if is_nvidia() or is_musa():
     MIN_WEIGHT_MEMORY_RATIO = 0.0
 
 ENABLE_PYTORCH_ATTENTION = False
@@ -320,7 +323,7 @@ if args.use_pytorch_cross_attention:
     XFORMERS_IS_AVAILABLE = False
 
 try:
-    if is_nvidia():
+    if is_nvidia() or is_musa():
         if torch_version_numeric[0] >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
@@ -375,7 +378,7 @@ if ENABLE_PYTORCH_ATTENTION:
 
 PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
 try:
-    if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
+    if (is_nvidia() or is_amd() or is_musa()) and PerformanceFeature.Fp16Accumulation in args.fast:
         torch.backends.cuda.matmul.allow_fp16_accumulation = True
         PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
         logging.info("Enabled fp16 accumulation.")
@@ -1020,7 +1023,7 @@ if args.async_offload is not None:
     NUM_STREAMS = args.async_offload
 else:
     # Enable by default on Nvidia and AMD
-    if is_nvidia() or is_amd():
+    if is_nvidia() or is_amd() or is_musa():
         NUM_STREAMS = 2
 
 if args.disable_async_offload:
@@ -1117,7 +1120,7 @@ PINNED_MEMORY = {}
 TOTAL_PINNED_MEMORY = 0
 MAX_PINNED_MEMORY = -1
 if not args.disable_pinned_memory:
-    if is_nvidia() or is_amd():
+    if is_nvidia() or is_amd() or is_musa():
         if WINDOWS:
             MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
         else:
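
Note: the two gates above work together: the async-offload streams (NUM_STREAMS) only pay off when host buffers are pinned, because non_blocking device copies overlap compute only from page-locked memory. A hedged illustration of the underlying PyTorch mechanism, not code from this file:

import torch

if torch.cuda.is_available():
    src = torch.randn(1024, 1024).pin_memory()  # page-locked host tensor
    copy_stream = torch.cuda.Stream()           # side stream, as with NUM_STREAMS = 2
    with torch.cuda.stream(copy_stream):
        dst = src.to("cuda", non_blocking=True) # overlaps work on the default stream
    torch.cuda.current_stream().wait_stream(copy_stream)  # sync before using dst
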
@@ -1261,6 +1264,8 @@ def pytorch_attention_flash_attention():
         return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
     if is_ixuca():
         return True
+    if is_musa():
+        return True
     return False
 
 def force_upcast_attention_dtype():
@@ -1392,6 +1397,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     if torch.version.hip:
         return True
 
+    if is_musa():
+        return True
+
     props = torch.cuda.get_device_properties(device)
     if props.major >= 8:
         return True
@@ -1462,6 +1470,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
             return True
         return False
 
+    if is_musa():
+        return True
+
     props = torch.cuda.get_device_properties(device)
 
     if is_mlu():
@@ -1484,25 +1495,27 @@ def supports_fp8_compute(device=None):
     if SUPPORT_FP8_OPS:
         return True
 
-    if not is_nvidia():
-        return False
-
     props = torch.cuda.get_device_properties(device)
-    if props.major >= 9:
-        return True
-    if props.major < 8:
-        return False
-    if props.minor < 9:
-        return False
+    if is_nvidia():
+        if props.major >= 9:
+            return True
+        if props.major < 8:
+            return False
+        if props.minor < 9:
+            return False
 
-    if torch_version_numeric < (2, 3):
-        return False
+        if torch_version_numeric < (2, 3):
+            return False
 
-    if WINDOWS:
-        if torch_version_numeric < (2, 4):
-            return False
+        if WINDOWS:
+            if torch_version_numeric < (2, 4):
+                return False
 
-    return True
+        return True
+    elif is_musa():
+        if props.major >= 3:
+            return True
+    return False
 
 def extended_fp16_support():
     # TODO: check why some models work with fp16 on newer torch versions but not on older
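
Note: after the restructure the NVIDIA policy is unchanged (9.x always passes; 8.9 additionally needs torch >= 2.3, or >= 2.4 on Windows), while MUSA devices report fp8 support from compute capability 3.x up. A hypothetical probe for the capability this gate keys on, assuming torchada routes torch.cuda device queries to MUSA as the elif branch implies:

import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(props.name, props.major, props.minor)
    # e.g. an RTX 4090 reports 8.9 (passes the Ada branch);
    # on MUSA the bar above is major >= 3
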
@@ -1543,7 +1556,7 @@ def unload_all_models():
     free_memory(1e30, get_torch_device())
 
 def debug_memory_summary():
-    if is_amd() or is_nvidia():
+    if is_amd() or is_nvidia() or is_musa():
         return torch.cuda.memory.memory_summary()
     return ""

File 3 of 3

@@ -28,3 +28,4 @@ kornia>=0.7.1
 spandrel
 pydantic~=2.0
 pydantic-settings~=2.0
+torchada>=0.1.11