mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-07 21:00:49 +08:00
Merge f0caa15a17 into e14f3b6610
This commit is contained in:
commit
f3dad735bf
@ -21,8 +21,15 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
else:
|
||||
device = pos.device
|
||||
|
||||
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
|
||||
omega = 1.0 / (theta**scale)
|
||||
if device.type == "musa":
|
||||
# XXX (MUSA): Unsupported tensor dtype in Neg: Double
|
||||
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float32, device=device)
|
||||
if not isinstance(theta, torch.Tensor):
|
||||
theta = torch.tensor(theta, dtype=torch.float32, device=device)
|
||||
omega = torch.exp(-scale * torch.log(theta + 1e-6))
|
||||
else:
|
||||
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
|
||||
omega = 1.0 / (theta**scale)
|
||||
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
|
||||
@ -139,6 +139,12 @@ try:
|
||||
except:
|
||||
ixuca_available = False
|
||||
|
||||
try:
|
||||
import torchada # noqa: F401
|
||||
musa_available = hasattr(torch, "musa") and torch.musa.is_available()
|
||||
except:
|
||||
musa_available = False
|
||||
|
||||
if args.cpu:
|
||||
cpu_state = CPUState.CPU
|
||||
|
||||
@ -146,27 +152,24 @@ def is_intel_xpu():
|
||||
global cpu_state
|
||||
global xpu_available
|
||||
if cpu_state == CPUState.GPU:
|
||||
if xpu_available:
|
||||
return True
|
||||
return xpu_available
|
||||
return False
|
||||
|
||||
def is_ascend_npu():
|
||||
global npu_available
|
||||
if npu_available:
|
||||
return True
|
||||
return False
|
||||
return npu_available
|
||||
|
||||
def is_mlu():
|
||||
global mlu_available
|
||||
if mlu_available:
|
||||
return True
|
||||
return False
|
||||
return mlu_available
|
||||
|
||||
def is_ixuca():
|
||||
global ixuca_available
|
||||
if ixuca_available:
|
||||
return True
|
||||
return False
|
||||
return ixuca_available
|
||||
|
||||
def is_musa():
|
||||
global musa_available
|
||||
return musa_available
|
||||
|
||||
def get_torch_device():
|
||||
global directml_enabled
|
||||
@ -311,7 +314,7 @@ def amd_min_version(device=None, min_rdna_version=0):
|
||||
return False
|
||||
|
||||
MIN_WEIGHT_MEMORY_RATIO = 0.4
|
||||
if is_nvidia():
|
||||
if is_nvidia() or is_musa():
|
||||
MIN_WEIGHT_MEMORY_RATIO = 0.0
|
||||
|
||||
ENABLE_PYTORCH_ATTENTION = False
|
||||
@ -320,7 +323,7 @@ if args.use_pytorch_cross_attention:
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
|
||||
try:
|
||||
if is_nvidia():
|
||||
if is_nvidia() or is_musa():
|
||||
if torch_version_numeric[0] >= 2:
|
||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
@ -375,7 +378,7 @@ if ENABLE_PYTORCH_ATTENTION:
|
||||
|
||||
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
|
||||
try:
|
||||
if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
if (is_nvidia() or is_amd() or is_musa()) and PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
||||
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
|
||||
logging.info("Enabled fp16 accumulation.")
|
||||
@ -1020,7 +1023,7 @@ if args.async_offload is not None:
|
||||
NUM_STREAMS = args.async_offload
|
||||
else:
|
||||
# Enable by default on Nvidia and AMD
|
||||
if is_nvidia() or is_amd():
|
||||
if is_nvidia() or is_amd() or is_musa():
|
||||
NUM_STREAMS = 2
|
||||
|
||||
if args.disable_async_offload:
|
||||
@ -1117,7 +1120,7 @@ PINNED_MEMORY = {}
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
MAX_PINNED_MEMORY = -1
|
||||
if not args.disable_pinned_memory:
|
||||
if is_nvidia() or is_amd():
|
||||
if is_nvidia() or is_amd() or is_musa():
|
||||
if WINDOWS:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
|
||||
else:
|
||||
@ -1261,6 +1264,8 @@ def pytorch_attention_flash_attention():
|
||||
return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
|
||||
if is_ixuca():
|
||||
return True
|
||||
if is_musa():
|
||||
return True
|
||||
return False
|
||||
|
||||
def force_upcast_attention_dtype():
|
||||
@ -1392,6 +1397,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if torch.version.hip:
|
||||
return True
|
||||
|
||||
if is_musa():
|
||||
return True
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major >= 8:
|
||||
return True
|
||||
@ -1462,6 +1470,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return True
|
||||
return False
|
||||
|
||||
if is_musa():
|
||||
return True
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
|
||||
if is_mlu():
|
||||
@ -1484,25 +1495,27 @@ def supports_fp8_compute(device=None):
|
||||
if SUPPORT_FP8_OPS:
|
||||
return True
|
||||
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major >= 9:
|
||||
return True
|
||||
if props.major < 8:
|
||||
return False
|
||||
if props.minor < 9:
|
||||
return False
|
||||
|
||||
if torch_version_numeric < (2, 3):
|
||||
return False
|
||||
|
||||
if WINDOWS:
|
||||
if torch_version_numeric < (2, 4):
|
||||
if is_nvidia():
|
||||
if props.major >= 9:
|
||||
return True
|
||||
if props.major < 8:
|
||||
return False
|
||||
if props.minor < 9:
|
||||
return False
|
||||
|
||||
return True
|
||||
if torch_version_numeric < (2, 3):
|
||||
return False
|
||||
|
||||
if WINDOWS:
|
||||
if torch_version_numeric < (2, 4):
|
||||
return False
|
||||
|
||||
elif is_musa():
|
||||
if props.major >= 3:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def extended_fp16_support():
|
||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||
@ -1543,7 +1556,7 @@ def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
def debug_memory_summary():
|
||||
if is_amd() or is_nvidia():
|
||||
if is_amd() or is_nvidia() or is_musa():
|
||||
return torch.cuda.memory.memory_summary()
|
||||
return ""
|
||||
|
||||
|
||||
@ -28,3 +28,4 @@ kornia>=0.7.1
|
||||
spandrel
|
||||
pydantic~=2.0
|
||||
pydantic-settings~=2.0
|
||||
torchada>=0.1.11
|
||||
|
||||
Loading…
Reference in New Issue
Block a user