mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-12 07:10:52 +08:00
Update zluda.py
This commit is contained in:
parent
ad2fa1a675
commit
eee802e685
@ -70,6 +70,18 @@ if is_zluda:
|
||||
torch.jit.script = z_jit
|
||||
# ------------------- End Audio Patch -------------------
|
||||
|
||||
# ------------------- Top-K Fallback Patch -------------------
|
||||
if is_zluda:
|
||||
_topk = torch.topk
|
||||
|
||||
def safe_topk(input: torch.Tensor, *args, **kwargs):
|
||||
device = input.device
|
||||
values, indices = _topk(input.cpu(), *args, **kwargs)
|
||||
return torch.return_types.topk((values.to(device), indices.to(device),))
|
||||
|
||||
torch.topk = safe_topk
|
||||
# ------------------- End Top-K Patch -------------------
|
||||
|
||||
# ------------------- ONNX Runtime Patch -------------------
|
||||
try:
|
||||
import onnxruntime as ort
|
||||
@ -86,7 +98,18 @@ try:
|
||||
|
||||
# Patch ONLY the _pybind_state version (used during session creation)
|
||||
ort.capi._pybind_state.get_available_providers = filtered_providers
|
||||
# DO NOT override ort.get_available_providers — leave that public interface alone
|
||||
|
||||
# Wrap InferenceSession to force CPU provider when CUDA is explicitly requested
|
||||
OriginalSession = ort.InferenceSession
|
||||
|
||||
class SafeInferenceSession(OriginalSession):
|
||||
def __init__(self, *args, providers=None, **kwargs):
|
||||
if providers and "CUDAExecutionProvider" in providers:
|
||||
print(" :: Forcing ONNX to use CPUExecutionProvider instead of CUDA.")
|
||||
providers = ["CPUExecutionProvider"]
|
||||
super().__init__(*args, providers=providers, **kwargs)
|
||||
|
||||
ort.InferenceSession = SafeInferenceSession
|
||||
except ImportError:
|
||||
print(" :: ONNX Runtime not installed — skipping patch.")
|
||||
except Exception as e:
|
||||
@ -112,4 +135,4 @@ if is_zluda:
|
||||
else:
|
||||
print(f" :: CUDA device detected: {zluda_device_name or 'None'}")
|
||||
print("***--------------------------------------------------------***\n")
|
||||
# ------------------- End Zluda patching -------------------
|
||||
# ------------------- End Zluda detection -------------------
|
||||
|
||||
Loading…
Reference in New Issue
Block a user