diff --git a/cuda_malloc.py b/cuda_malloc.py index 7450c12d3..ee2bc4b69 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -56,14 +56,11 @@ def cuda_malloc_supported(): except: names = set() for x in names: - if "AMD" in x: - return False - elif "NVIDIA" in x: + if "NVIDIA" in x: for b in blacklist: if b in x: return False - return False -#We don't need malloc at all with amd gpu's. So disabling all together. + return True version = ""