diff --git a/comfy/model_management.py b/comfy/model_management.py index 803aa8a2a..810cfbff8 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging import sys +import warnings from enum import Enum from threading import RLock from typing import Literal @@ -101,7 +102,11 @@ def get_torch_device(): if is_intel_xpu(): return torch.device("xpu", torch.xpu.current_device()) else: - return torch.device(torch.cuda.current_device()) + try: + return torch.device(torch.cuda.current_device()) + except: + warnings.warn("torch.cuda.current_device() did not return a device, returning a CPU torch device") + return torch.device("cpu") def get_total_memory(dev=None, torch_total_too=False):