This commit is contained in:
Ahmad Afzal 2026-03-28 19:36:53 +00:00 committed by GitHub
commit 90a6f9856c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -47,13 +47,14 @@ blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeFor
"Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
"Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
"GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
"GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60"
"GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60",
"Tesla P40", "Tesla P4", "NVIDIA L4", "NVIDIA A10G", "GRID A800D",
}
def cuda_malloc_supported():
try:
names = get_gpu_names()
except:
except Exception:
names = set()
for x in names:
if "NVIDIA" in x:
@ -74,7 +75,7 @@ try:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
version = module.__version__
except:
except Exception:
pass
if not args.cuda_malloc:
@ -82,7 +83,7 @@ if not args.cuda_malloc:
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
args.cuda_malloc = cuda_malloc_supported()
except:
except Exception:
pass
if args.disable_cuda_malloc: