This commit is contained in:
Alexander Piskun 2025-12-03 11:10:50 -05:00 committed by GitHub
commit 247accfcae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -26,6 +26,8 @@ import importlib
import platform
import weakref
import gc
from pathlib import Path
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@ -193,7 +195,7 @@ def get_total_memory(dev=None, torch_total_too=False):
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_total = psutil.virtual_memory().total
mem_total = _CGROUP_LIMIT_BYTES or psutil.virtual_memory().total
mem_total_torch = mem_total
else:
if directml_enabled:
@ -235,8 +237,28 @@ def mac_version():
except:
return None
_CG = Path("/sys/fs/cgroup")
if (_CG / "memory.max").exists(): # cgroup v2
_LIMIT_F = _CG / "memory.max"
_USED_F = _CG / "memory.current"
else: # cgroup v1
_LIMIT_F = _CG / "memory/memory.limit_in_bytes"
_USED_F = _CG / "memory/memory.usage_in_bytes"
def _get_cgroup_value(p: Path):
try:
v = int(p.read_text().strip())
if v == 0 or v >= (1 << 60):
return None # 'max' in v2 shows up as 2**63-1 or 0, treat both as unlimited
return v
except (FileNotFoundError, PermissionError, ValueError):
return None
_CGROUP_LIMIT_BYTES = _get_cgroup_value(_LIMIT_F)
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
total_ram = (_CGROUP_LIMIT_BYTES or psutil.virtual_memory().total) / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
try:
@ -1261,7 +1283,14 @@ def get_free_memory(dev=None, torch_free_too=False):
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
if hasattr(dev, 'type') and dev.type == 'cpu':
used = _get_cgroup_value(_USED_F) if _CGROUP_LIMIT_BYTES is not None else None
if _CGROUP_LIMIT_BYTES is not None and used is not None:
mem_free_total = max(_CGROUP_LIMIT_BYTES - used, 0)
else:
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
elif hasattr(dev, 'type') and dev.type == 'mps':
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
else: