From c3fc894ce2e04de431d5e1904a67cecdf7eea711 Mon Sep 17 00:00:00 2001
From: patientx
Date: Tue, 14 Jan 2025 19:54:44 +0300
Subject: [PATCH] Add files via upload

---
 comfy/zluda.py | 33 +++++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)
 create mode 100644 comfy/zluda.py

diff --git a/comfy/zluda.py b/comfy/zluda.py
new file mode 100644
index 000000000..4f52ccf39
--- /dev/null
+++ b/comfy/zluda.py
@@ -0,0 +1,33 @@
+try:
+    torch_device_name = get_torch_device_name(get_torch_device())
+
+    if "[ZLUDA]" in torch_device_name:
+        _torch_stft = torch.stft
+        _torch_istft = torch.istft
+
+        def z_stft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
+            return _torch_stft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)
+
+        def z_istft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
+            return _torch_istft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)
+
+        def z_jit(f, *_, **__):
+            f.graph = torch._C.Graph()
+            return f
+
+        # hijacks
+        torch.stft = z_stft
+        torch.istft = z_istft
+        torch.jit.script = z_jit
+        print(" ")
+        print("***----------------------ZLUDA--------------------------***")
+        print(" :: ZLUDA detected, disabling non-supported functions.")
+        torch.backends.cudnn.enabled = False
+        print(" :: (cuDNN, flash_sdp, mem_efficient_sdp disabled) ")
+        torch.backends.cuda.enable_flash_sdp(False)
+        torch.backends.cuda.enable_math_sdp(True)
+        torch.backends.cuda.enable_mem_efficient_sdp(False)
+
+        print("***-----------------------------------------------------***")
+        print(" :: Device:", torch_device_name)
+        print(" ")
\ No newline at end of file
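
Note on the hunk above: as committed, the new file uses torch, get_torch_device, and get_torch_device_name without importing them, and it opens a try: block that has no matching except, so it will not run on its own unless those names are provided and the block is closed elsewhere. The sketch below shows one self-contained way the same monkey-patching approach could be written; it assumes the device helpers come from comfy.model_management (where upstream ComfyUI defines them) and is an illustration of the technique, not the file as merged.

import torch

# Assumption: device helpers are taken from ComfyUI's model_management module.
from comfy.model_management import get_torch_device, get_torch_device_name

try:
    torch_device_name = get_torch_device_name(get_torch_device())

    if "[ZLUDA]" in torch_device_name:
        # Keep references to the originals so the wrappers can delegate to them.
        _torch_stft = torch.stft
        _torch_istft = torch.istft

        # Route stft/istft through the CPU (FFT support under ZLUDA is incomplete)
        # and move the result back to the tensor's original device. The wrappers
        # assume the signal is the first positional argument and the window is
        # passed by keyword, matching how the hijacked code calls them.
        def z_stft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
            return _torch_stft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)

        def z_istft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
            return _torch_istft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)

        # Neutralize torch.jit.script: return the function unscripted, with an
        # empty graph attached so callers that inspect .graph do not fail.
        def z_jit(f, *_, **__):
            f.graph = torch._C.Graph()
            return f

        # Install the hijacks.
        torch.stft = z_stft
        torch.istft = z_istft
        torch.jit.script = z_jit

        # cuDNN and the flash / memory-efficient SDP kernels are among the
        # functions the patch flags as unsupported under ZLUDA; keep only the
        # plain math SDP implementation enabled.
        torch.backends.cudnn.enabled = False
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(False)
        torch.backends.cuda.enable_math_sdp(True)

        print(" :: ZLUDA detected, disabling non-supported functions.")
        print(" :: Device:", torch_device_name)
except Exception as exc:
    # The committed hunk opens try: without an except; a catch-all here keeps
    # startup alive if device detection or the hijacks fail.
    print(" :: ZLUDA setup skipped:", exc)

The CPU round-trip for stft/istft and the fallback to the math SDP backend trade speed for compatibility: per the startup message in the patch, cuDNN and the flash / memory-efficient attention kernels are the paths that do not work under ZLUDA, while the plain math attention path still runs.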