Update zluda.py

This commit is contained in:
patientx 2025-06-02 02:55:19 +03:00 committed by GitHub
parent 06090270ad
commit 9aeff135b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -192,26 +192,26 @@ class DeviceProperties:
return DeviceProperties.PROPERTIES_OVERRIDE[name]
return getattr(self.internal, name)
# ------------------- Audio Ops Patch -------------------
if is_zluda:
_torch_stft = torch.stft
_torch_istft = torch.istft
# # ------------------- Audio Ops Patch -------------------
# if is_zluda:
# _torch_stft = torch.stft
# _torch_istft = torch.istft
def z_stft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
return _torch_stft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)
# def z_stft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
# return _torch_stft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)
def z_istft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
return _torch_istft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)
# def z_istft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
# return _torch_istft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)
def z_jit(f, *_, **__):
f.graph = torch._C.Graph()
return f
# def z_jit(f, *_, **__):
# f.graph = torch._C.Graph()
# return f
torch._dynamo.config.suppress_errors = True
torch.stft = z_stft
torch.istft = z_istft
torch.jit.script = z_jit
# ------------------- End Audio Patch -------------------
# torch._dynamo.config.suppress_errors = True
# torch.stft = z_stft
# torch.istft = z_istft
# torch.jit.script = z_jit
# # ------------------- End Audio Patch -------------------
# ------------------- Top-K Fallback Patch -------------------
if is_zluda: