Merge f0caa15a17 into 3cd7b32f1b

This commit is contained in: commit 19c2d83ecb
@@ -21,8 +21,15 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     else:
         device = pos.device
 
-    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
-    omega = 1.0 / (theta**scale)
+    if device.type == "musa":
+        # XXX (MUSA): Unsupported tensor dtype in Neg: Double
+        scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float32, device=device)
+        if not isinstance(theta, torch.Tensor):
+            theta = torch.tensor(theta, dtype=torch.float32, device=device)
+        omega = torch.exp(-scale * torch.log(theta + 1e-6))
+    else:
+        scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
+        omega = 1.0 / (theta**scale)
     out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
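The MUSA branch above sidesteps the missing float64 Neg kernel by computing omega in float32 as exp(-scale * log(theta)). A minimal sketch, not part of the commit and runnable on CPU, comparing the two formulations; the dim and theta values are illustrative only:

import torch

# Reference: the existing float64 path.
dim, theta = 128, 10000
scale64 = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64)
omega64 = 1.0 / (theta ** scale64)

# MUSA-style fallback: float32 everywhere, log/exp instead of a negative power,
# with the 1e-6 offset keeping log() finite.
scale32 = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float32)
theta32 = torch.tensor(float(theta), dtype=torch.float32)
omega32 = torch.exp(-scale32 * torch.log(theta32 + 1e-6))

# Expect agreement to roughly float32 precision.
print(torch.max(torch.abs(omega64.to(torch.float32) - omega32)))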
@@ -139,6 +139,12 @@ try:
 except:
     ixuca_available = False
 
+try:
+    import torchada # noqa: F401
+    musa_available = hasattr(torch, "musa") and torch.musa.is_available()
+except:
+    musa_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
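The detection block keeps MUSA optional: if the torchada shim or the torch.musa backend is missing, the flag simply stays False. A hedged, standalone restatement of that probe (torchada is the optional bridge package added to the requirements hunk further down):

import torch

try:
    import torchada  # noqa: F401  # optional MUSA/PyTorch bridge
    musa_available = hasattr(torch, "musa") and torch.musa.is_available()
except Exception:
    # Anything missing leaves the flag False, same as the commit's bare except.
    musa_available = False

print("MUSA available:", musa_available)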
@@ -146,27 +152,24 @@ def is_intel_xpu():
     global cpu_state
     global xpu_available
     if cpu_state == CPUState.GPU:
-        if xpu_available:
-            return True
+        return xpu_available
     return False
 
 def is_ascend_npu():
     global npu_available
-    if npu_available:
-        return True
-    return False
+    return npu_available
 
 def is_mlu():
     global mlu_available
-    if mlu_available:
-        return True
-    return False
+    return mlu_available
 
 def is_ixuca():
     global ixuca_available
-    if ixuca_available:
-        return True
-    return False
+    return ixuca_available
+
+def is_musa():
+    global musa_available
+    return musa_available
 
 def get_torch_device():
     global directml_enabled
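Each accessor now simply returns its module-level availability flag, and is_musa() joins the family. A short usage sketch, assuming ComfyUI with this commit applied is importable; the print messages are placeholders, not ComfyUI behaviour:

import comfy.model_management as mm

if mm.is_musa():
    print("dispatching to MUSA-specific code paths")
elif mm.is_ascend_npu():
    print("dispatching to Ascend NPU code paths")
else:
    print("default CUDA/CPU handling")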
@@ -311,7 +314,7 @@ def amd_min_version(device=None, min_rdna_version=0):
     return False
 
 MIN_WEIGHT_MEMORY_RATIO = 0.4
-if is_nvidia():
+if is_nvidia() or is_musa():
     MIN_WEIGHT_MEMORY_RATIO = 0.0
 
 ENABLE_PYTORCH_ATTENTION = False
@@ -320,7 +323,7 @@ if args.use_pytorch_cross_attention:
 XFORMERS_IS_AVAILABLE = False
 
 try:
-    if is_nvidia():
+    if is_nvidia() or is_musa():
         if torch_version_numeric[0] >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
@@ -375,7 +378,7 @@ if ENABLE_PYTORCH_ATTENTION:
 
 PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
 try:
-    if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
+    if (is_nvidia() or is_amd() or is_musa()) and PerformanceFeature.Fp16Accumulation in args.fast:
         torch.backends.cuda.matmul.allow_fp16_accumulation = True
         PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
         logging.info("Enabled fp16 accumulation.")
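allow_fp16_accumulation only exists on newer PyTorch builds, which is why the assignment sits inside a try block. A hedged sketch of the same guard written with hasattr instead:

import torch

# Probe before setting: older torch builds do not expose this attribute.
if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"):
    torch.backends.cuda.matmul.allow_fp16_accumulation = True
    print("fp16 accumulation enabled")
else:
    print("this PyTorch build does not expose allow_fp16_accumulation")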
@@ -1020,7 +1023,7 @@ if args.async_offload is not None:
     NUM_STREAMS = args.async_offload
 else:
     # Enable by default on Nvidia and AMD
-    if is_nvidia() or is_amd():
+    if is_nvidia() or is_amd() or is_musa():
         NUM_STREAMS = 2
 
 if args.disable_async_offload:
@@ -1117,7 +1120,7 @@ PINNED_MEMORY = {}
 TOTAL_PINNED_MEMORY = 0
 MAX_PINNED_MEMORY = -1
 if not args.disable_pinned_memory:
-    if is_nvidia() or is_amd():
+    if is_nvidia() or is_amd() or is_musa():
         if WINDOWS:
             MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
         else:
@@ -1261,6 +1264,8 @@ def pytorch_attention_flash_attention():
             return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
         if is_ixuca():
             return True
+        if is_musa():
+            return True
     return False
 
 def force_upcast_attention_dtype():
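With is_musa() now reporting flash-attention support, the PyTorch attention path is used on MUSA just as it is on the other accelerators. A minimal, CPU-runnable sketch of the scaled-dot-product-attention call that path ultimately relies on; shapes are illustrative only:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, heads, tokens, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])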
@@ -1392,6 +1397,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if torch.version.hip:
         return True
 
+    if is_musa():
+        return True
+
     props = torch.cuda.get_device_properties(device)
     if props.major >= 8:
         return True
@@ -1462,6 +1470,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
                 return True
             return False
 
+    if is_musa():
+        return True
+
     props = torch.cuda.get_device_properties(device)
 
     if is_mlu():
@@ -1484,25 +1495,27 @@ def supports_fp8_compute(device=None):
     if SUPPORT_FP8_OPS:
         return True
 
-    if not is_nvidia():
-        return False
-
     props = torch.cuda.get_device_properties(device)
-    if props.major >= 9:
-        return True
-    if props.major < 8:
-        return False
-    if props.minor < 9:
-        return False
-
-    if torch_version_numeric < (2, 3):
-        return False
-
-    if WINDOWS:
-        if torch_version_numeric < (2, 4):
-            return False
-
-    return True
+    if is_nvidia():
+        if props.major >= 9:
+            return True
+        if props.major < 8:
+            return False
+        if props.minor < 9:
+            return False
+
+        if torch_version_numeric < (2, 3):
+            return False
+
+        if WINDOWS:
+            if torch_version_numeric < (2, 4):
+                return False
+
+    elif is_musa():
+        if props.major >= 3:
+            return True
+
+    return False
 
 def supports_nvfp4_compute(device=None):
     if not is_nvidia():
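The rewritten supports_fp8_compute() reads the device properties once and then branches per backend, reporting fp8 on MUSA for compute capability major >= 3. A hedged, standalone restatement of that branch shape with made-up property values, purely for illustration; fp8_decision and the vendor strings are stand-ins for the is_nvidia()/is_musa() checks, not ComfyUI functions:

from types import SimpleNamespace

def fp8_decision(vendor, props, torch_version=(2, 4), windows=False, support_fp8_ops=False):
    # Mirrors the structure of the hunk above.
    if support_fp8_ops:
        return True
    if vendor == "nvidia":
        if props.major >= 9:
            return True
        if props.major < 8:
            return False
        if props.minor < 9:
            return False
        if torch_version < (2, 3):
            return False
        if windows and torch_version < (2, 4):
            return False
    elif vendor == "musa":
        if props.major >= 3:
            return True
    return False

print(fp8_decision("nvidia", SimpleNamespace(major=9, minor=0)))  # True
print(fp8_decision("musa", SimpleNamespace(major=3, minor=1)))    # True
print(fp8_decision("musa", SimpleNamespace(major=2, minor=0)))    # False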
@@ -1553,7 +1566,7 @@ def unload_all_models():
     free_memory(1e30, get_torch_device())
 
 def debug_memory_summary():
-    if is_amd() or is_nvidia():
+    if is_amd() or is_nvidia() or is_musa():
         return torch.cuda.memory.memory_summary()
     return ""
@@ -28,3 +28,4 @@ kornia>=0.7.1
 spandrel
 pydantic~=2.0
 pydantic-settings~=2.0
+torchada>=0.1.11
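The new pinned dependency is apparently the torchada bridge used by the MUSA probe above. A hedged, standard-library-only check that an installed copy meets the new floor:

from importlib.metadata import PackageNotFoundError, version

try:
    print("torchada", version("torchada"))
except PackageNotFoundError:
    print("torchada not installed; pip install -r requirements.txt pulls >=0.1.11")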