diff --git a/comfy/zluda.py b/comfy/zluda.py
index 8c9b35421..f3176791c 100644
--- a/comfy/zluda.py
+++ b/comfy/zluda.py
@@ -70,6 +70,21 @@ if is_zluda:
     torch.jit.script = z_jit
 # ------------------- End Audio Patch -------------------
 
+# ------------------- Top-K Fallback Patch -------------------
+if is_zluda:
+    _topk = torch.topk
+
+    def safe_topk(input: torch.Tensor, *args, **kwargs):
+        """Compute top-k on a CPU copy and move the results back to the input's device."""
+        # NOTE(review): only the torch.topk function is rerouted here; the
+        # Tensor.topk method is left untouched.
+        device = input.device
+        values, indices = _topk(input.cpu(), *args, **kwargs)
+        return torch.return_types.topk((values.to(device), indices.to(device)))
+
+    torch.topk = safe_topk
+# ------------------- End Top-K Patch -------------------
+
 # ------------------- ONNX Runtime Patch -------------------
 try:
     import onnxruntime as ort
@@ -86,7 +101,30 @@ try:
         # Patch ONLY the _pybind_state version (used during session creation)
         ort.capi._pybind_state.get_available_providers = filtered_providers
-        # DO NOT override ort.get_available_providers — leave that public interface alone
+
+        # Wrap InferenceSession to force CPU provider when CUDA is explicitly requested
+        OriginalSession = ort.InferenceSession
+
+        class SafeInferenceSession(OriginalSession):
+            """InferenceSession that swaps an explicit CUDA request for the CPU provider."""
+
+            def __init__(self, *args, providers=None, **kwargs):
+                # Entries are either provider names or (name, options) tuples.
+                wants_cuda = bool(providers) and (
+                    "CUDAExecutionProvider" in providers
+                    or any(
+                        isinstance(entry, tuple) and entry[:1] == ("CUDAExecutionProvider",)
+                        for entry in providers
+                    )
+                )
+                if wants_cuda:
+                    print(" :: Forcing ONNX to use CPUExecutionProvider instead of CUDA.")
+                    providers = ["CPUExecutionProvider"]
+                    # Any provider_options were paired with the replaced list; drop them.
+                    kwargs.pop("provider_options", None)
+                super().__init__(*args, providers=providers, **kwargs)
+
+        ort.InferenceSession = SafeInferenceSession
 
 except ImportError:
     print(" :: ONNX Runtime not installed — skipping patch.")
 except Exception as e:
@@ -112,4 +150,4 @@ if is_zluda:
 else:
     print(f" :: CUDA device detected: {zluda_device_name or 'None'}")
 print("***--------------------------------------------------------***\n")
-# ------------------- End Zluda patching -------------------
+# ------------------- End Zluda detection -------------------