Had to move cuDNN disablement up much higher

Christopher Anderson 2025-08-25 14:11:54 +10:00
parent cd3d60254b
commit 1b9a3b12c2


@@ -430,6 +430,14 @@ try:
     import triton
     import triton.language as tl
     print(" :: Triton core imported successfully")
+    # This needs to be up here, so it can disable cudnn before anything can even think about using it
+    torch.backends.cudnn.enabled = os.environ.get("TORCH_BACKENDS_CUDNN_ENABLED", "1").strip().lower() not in {"0", "off", "false", "disable", "disabled", "no"}
+    if torch.backends.cudnn.enabled:
+        print(" :: Enabled cuDNN")
+    else:
+        print(" :: Disabled cuDNN")
+    torch.backends.cudnn.benchmark = False
     @triton.jit
     def _zluda_kernel_test(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
@@ -676,12 +684,6 @@ def do_hijack():
     print(" :: Configuring PyTorch backends...")
     torch.backends.cuda.enable_mem_efficient_sdp(False)
     torch.backends.cuda.enable_mem_efficient_sdp = do_nothing
-    torch.backends.cudnn.enabled = os.environ.get("TORCH_BACKENDS_CUDNN_ENABLED", "1").strip().lower() not in {"0", "off", "false", "disable", "disabled", "no"}
-    if torch.backends.cudnn.enabled:
-        print(" :: Enabled cuDNN")
-    else:
-        print(" :: Disabled cuDNN")
-    torch.backends.cudnn.benchmark = False
     if hasattr(torch.backends.cuda, "enable_flash_sdp"):
         torch.backends.cuda.enable_flash_sdp(True)
         print(" :: Disabled CUDA flash attention")