mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 01:52:59 +08:00
Only disable cudnn on newer AMD GPUs. (#10437)
This commit is contained in:
parent
560b1bdfca
commit
9cdc64998f
@ -330,15 +330,21 @@ except:
|
|||||||
|
|
||||||
|
|
||||||
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
||||||
|
|
||||||
|
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if is_amd():
|
if is_amd():
|
||||||
|
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||||
|
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||||
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||||
except:
|
except:
|
||||||
rocm_version = (6, -1)
|
rocm_version = (6, -1)
|
||||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
|
||||||
logging.info("AMD arch: {}".format(arch))
|
logging.info("AMD arch: {}".format(arch))
|
||||||
logging.info("ROCm version: {}".format(rocm_version))
|
logging.info("ROCm version: {}".format(rocm_version))
|
||||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
@ -1331,7 +1337,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
|
|
||||||
if is_amd():
|
if is_amd():
|
||||||
arch = torch.cuda.get_device_properties(device).gcnArchName
|
arch = torch.cuda.get_device_properties(device).gcnArchName
|
||||||
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
|
if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
|
||||||
if manual_cast:
|
if manual_cast:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user