fix(memory, docker): support for reading cgroup data

Signed-off-by: bigcat88 <bigcat88@icloud.com>
This commit is contained in:
Alexander Piskun 2025-06-13 15:12:52 +03:00 committed by bigcat88
parent c69af655aa
commit 673c63572d
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721

View File

@ -25,6 +25,9 @@ import sys
import platform
import weakref
import gc
from pathlib import Path
from functools import lru_cache
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@ -177,7 +180,7 @@ def get_total_memory(dev=None, torch_total_too=False):
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_total = psutil.virtual_memory().total
mem_total = _cgroup_limit_bytes() or psutil.virtual_memory().total
mem_total_torch = mem_total
else:
if directml_enabled:
@ -218,6 +221,35 @@ def mac_version():
except:
return None
_CG = Path("/sys/fs/cgroup")
if (_CG / "memory.max").exists(): # cgroup v2
_LIMIT_F = _CG / "memory.max"
_USED_F = _CG / "memory.current"
else: # cgroup v1
_LIMIT_F = _CG / "memory/memory.limit_in_bytes"
_USED_F = _CG / "memory/memory.usage_in_bytes"
@lru_cache(maxsize=None) # the hard limit never changes
def _cgroup_limit_bytes():
return _read_int(_LIMIT_F)
def _cgroup_used_bytes():
return _read_int(_USED_F)
def _read_int(p: Path):
try:
v = int(p.read_text().strip())
if v == 0 or v >= (1 << 60):
return None # 'max' in v2 shows up as 2**63-1 or 0, treat both as unlimited
return v
except (FileNotFoundError, PermissionError, ValueError):
return None
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
@ -1081,7 +1113,15 @@ def get_free_memory(dev=None, torch_free_too=False):
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
if hasattr(dev, 'type') and dev.type == 'cpu':
limit = _cgroup_limit_bytes()
used = _cgroup_used_bytes() if limit is not None else None
if limit is not None and used is not None:
mem_free_total = max(limit - used, 0)
else:
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
elif hasattr(dev, 'type') and dev.type == 'mps':
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
else: