mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-21 03:40:16 +08:00
Update zluda.py
This commit is contained in:
parent
ad2fa1a675
commit
eee802e685
@ -70,6 +70,18 @@ if is_zluda:
|
|||||||
torch.jit.script = z_jit
|
torch.jit.script = z_jit
|
||||||
# ------------------- End Audio Patch -------------------
|
# ------------------- End Audio Patch -------------------
|
||||||
|
|
||||||
|
# ------------------- Top-K Fallback Patch -------------------
|
||||||
|
if is_zluda:
|
||||||
|
_topk = torch.topk
|
||||||
|
|
||||||
|
def safe_topk(input: torch.Tensor, *args, **kwargs):
|
||||||
|
device = input.device
|
||||||
|
values, indices = _topk(input.cpu(), *args, **kwargs)
|
||||||
|
return torch.return_types.topk((values.to(device), indices.to(device),))
|
||||||
|
|
||||||
|
torch.topk = safe_topk
|
||||||
|
# ------------------- End Top-K Patch -------------------
|
||||||
|
|
||||||
# ------------------- ONNX Runtime Patch -------------------
|
# ------------------- ONNX Runtime Patch -------------------
|
||||||
try:
|
try:
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
@ -86,7 +98,18 @@ try:
|
|||||||
|
|
||||||
# Patch ONLY the _pybind_state version (used during session creation)
|
# Patch ONLY the _pybind_state version (used during session creation)
|
||||||
ort.capi._pybind_state.get_available_providers = filtered_providers
|
ort.capi._pybind_state.get_available_providers = filtered_providers
|
||||||
# DO NOT override ort.get_available_providers — leave that public interface alone
|
|
||||||
|
# Wrap InferenceSession to force CPU provider when CUDA is explicitly requested
|
||||||
|
OriginalSession = ort.InferenceSession
|
||||||
|
|
||||||
|
class SafeInferenceSession(OriginalSession):
|
||||||
|
def __init__(self, *args, providers=None, **kwargs):
|
||||||
|
if providers and "CUDAExecutionProvider" in providers:
|
||||||
|
print(" :: Forcing ONNX to use CPUExecutionProvider instead of CUDA.")
|
||||||
|
providers = ["CPUExecutionProvider"]
|
||||||
|
super().__init__(*args, providers=providers, **kwargs)
|
||||||
|
|
||||||
|
ort.InferenceSession = SafeInferenceSession
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print(" :: ONNX Runtime not installed — skipping patch.")
|
print(" :: ONNX Runtime not installed — skipping patch.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -112,4 +135,4 @@ if is_zluda:
|
|||||||
else:
|
else:
|
||||||
print(f" :: CUDA device detected: {zluda_device_name or 'None'}")
|
print(f" :: CUDA device detected: {zluda_device_name or 'None'}")
|
||||||
print("***--------------------------------------------------------***\n")
|
print("***--------------------------------------------------------***\n")
|
||||||
# ------------------- End Zluda patching -------------------
|
# ------------------- End Zluda detection -------------------
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user