Update zluda.py

This commit is contained in:
patientx 2025-06-02 02:55:19 +03:00 committed by GitHub
parent 06090270ad
commit 9aeff135b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -192,26 +192,26 @@ class DeviceProperties:
return DeviceProperties.PROPERTIES_OVERRIDE[name] return DeviceProperties.PROPERTIES_OVERRIDE[name]
return getattr(self.internal, name) return getattr(self.internal, name)
# ------------------- Audio Ops Patch ------------------- # # ------------------- Audio Ops Patch -------------------
if is_zluda: # if is_zluda:
_torch_stft = torch.stft # _torch_stft = torch.stft
_torch_istft = torch.istft # _torch_istft = torch.istft
def z_stft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs): # def z_stft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
return _torch_stft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device) # return _torch_stft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)
def z_istft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs): # def z_istft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs):
return _torch_istft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device) # return _torch_istft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device)
def z_jit(f, *_, **__): # def z_jit(f, *_, **__):
f.graph = torch._C.Graph() # f.graph = torch._C.Graph()
return f # return f
torch._dynamo.config.suppress_errors = True # torch._dynamo.config.suppress_errors = True
torch.stft = z_stft # torch.stft = z_stft
torch.istft = z_istft # torch.istft = z_istft
torch.jit.script = z_jit # torch.jit.script = z_jit
# ------------------- End Audio Patch ------------------- # # ------------------- End Audio Patch -------------------
# ------------------- Top-K Fallback Patch ------------------- # ------------------- Top-K Fallback Patch -------------------
if is_zluda: if is_zluda: