Patch torch device code here

This commit is contained in:
doctorpangloss 2024-05-17 07:19:15 -07:00
parent 4eb66f8a0a
commit cb45b86b63

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import logging
import sys
import warnings
from enum import Enum
from threading import RLock
from typing import Literal
@ -101,7 +102,11 @@ def get_torch_device():
if is_intel_xpu():
return torch.device("xpu", torch.xpu.current_device())
else:
return torch.device(torch.cuda.current_device())
try:
return torch.device(torch.cuda.current_device())
except:
warnings.warn("torch.cuda.current_device() did not return a device, returning a CPU torch device")
return torch.device("cpu")
def get_total_memory(dev=None, torch_total_too=False):