rebased onto upstream master (woops)

Christopher Anderson 2025-08-25 09:32:34 +10:00
parent 7eda4587be
commit 3504eeeb4a


@@ -676,13 +676,12 @@ def do_hijack():
     print(" :: Configuring PyTorch backends...")
     torch.backends.cuda.enable_mem_efficient_sdp(False)
     torch.backends.cuda.enable_mem_efficient_sdp = do_nothing
-    torch.backends.cudnn.enabled = False
-    torch.backends.cudnn.benchmark = False
     torch.backends.cudnn.enabled = os.environ.get("TORCH_BACKENDS_CUDNN_ENABLED", "1").strip().lower() not in {"0", "off", "false", "disable", "disabled", "no"}
     if torch.backends.cudnn.enabled:
         print(" :: Enabled cuDNN")
     else:
         print(" :: Disabled cuDNN")
+    torch.backends.cudnn.benchmark = False
     if hasattr(torch.backends.cuda, "enable_flash_sdp"):
         torch.backends.cuda.enable_flash_sdp(True)
         print(" :: Disabled CUDA flash attention")