diff --git a/cuda_malloc.py b/cuda_malloc.py index 11bf7c039..595af2ffe 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -56,12 +56,14 @@ def cuda_malloc_supported(): except: names = set() for x in names: - if "NVIDIA" in x: + if "AMD" in x: + return False + elif "NVIDIA" in x: for b in blacklist: if b in x: return False - return False -#We don't need malloc at all with amd gpu's. So disabling all together. + return True +#returned to older version. if not args.cuda_malloc: