inserted missing is_amd() check

This commit is contained in:
Christopher Anderson 2025-10-29 23:48:26 +11:00
parent 995c07341e
commit a4eb32a4ff

View File

@ -334,6 +334,7 @@ SUPPORT_FP8_OPS = args.supports_fp8_compute
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"] AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
try: try:
if is_amd():
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
torch.backends.cudnn.enabled = os.environ.get("TORCH_AMD_CUDNN_ENABLED", "0").strip().lower() not in { torch.backends.cudnn.enabled = os.environ.get("TORCH_AMD_CUDNN_ENABLED", "0").strip().lower() not in {