Patch torch device code here

This commit is contained in:
doctorpangloss 2024-05-17 07:19:15 -07:00
parent 4eb66f8a0a
commit cb45b86b63

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import logging import logging
import sys import sys
import warnings
from enum import Enum from enum import Enum
from threading import RLock from threading import RLock
from typing import Literal from typing import Literal
@ -101,7 +102,11 @@ def get_torch_device():
if is_intel_xpu(): if is_intel_xpu():
return torch.device("xpu", torch.xpu.current_device()) return torch.device("xpu", torch.xpu.current_device())
else: else:
return torch.device(torch.cuda.current_device()) try:
return torch.device(torch.cuda.current_device())
except:
warnings.warn("torch.cuda.current_device() did not return a device, returning a CPU torch device")
return torch.device("cpu")
def get_total_memory(dev=None, torch_total_too=False): def get_total_memory(dev=None, torch_total_too=False):