Update zluda.py

This commit is contained in:
patientx 2025-04-16 00:53:18 +03:00 committed by GitHub
parent ad2fa1a675
commit eee802e685
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -70,6 +70,18 @@ if is_zluda:
torch.jit.script = z_jit
# ------------------- End Audio Patch -------------------
# ------------------- Top-K Fallback Patch -------------------
if is_zluda:
_topk = torch.topk
def safe_topk(input: torch.Tensor, *args, **kwargs):
device = input.device
values, indices = _topk(input.cpu(), *args, **kwargs)
return torch.return_types.topk((values.to(device), indices.to(device),))
torch.topk = safe_topk
# ------------------- End Top-K Patch -------------------
# ------------------- ONNX Runtime Patch -------------------
try:
import onnxruntime as ort
@ -86,7 +98,18 @@ try:
# Patch ONLY the _pybind_state version (used during session creation)
ort.capi._pybind_state.get_available_providers = filtered_providers
# DO NOT override ort.get_available_providers — leave that public interface alone
# Wrap InferenceSession to force CPU provider when CUDA is explicitly requested
OriginalSession = ort.InferenceSession
class SafeInferenceSession(OriginalSession):
def __init__(self, *args, providers=None, **kwargs):
if providers and "CUDAExecutionProvider" in providers:
print(" :: Forcing ONNX to use CPUExecutionProvider instead of CUDA.")
providers = ["CPUExecutionProvider"]
super().__init__(*args, providers=providers, **kwargs)
ort.InferenceSession = SafeInferenceSession
except ImportError:
print(" :: ONNX Runtime not installed — skipping patch.")
except Exception as e:
@ -112,4 +135,4 @@ if is_zluda:
else:
print(f" :: CUDA device detected: {zluda_device_name or 'None'}")
print("***--------------------------------------------------------***\n")
# ------------------- End Zluda patching -------------------
# ------------------- End Zluda detection -------------------