mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 21:30:15 +08:00
Patch torch device code here
This commit is contained in:
parent
4eb66f8a0a
commit
cb45b86b63
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
import warnings
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from threading import RLock
|
from threading import RLock
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
@ -101,7 +102,11 @@ def get_torch_device():
|
|||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
return torch.device("xpu", torch.xpu.current_device())
|
return torch.device("xpu", torch.xpu.current_device())
|
||||||
else:
|
else:
|
||||||
return torch.device(torch.cuda.current_device())
|
try:
|
||||||
|
return torch.device(torch.cuda.current_device())
|
||||||
|
except:
|
||||||
|
warnings.warn("torch.cuda.current_device() did not return a device, returning a CPU torch device")
|
||||||
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
def get_total_memory(dev=None, torch_total_too=False):
|
def get_total_memory(dev=None, torch_total_too=False):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user