diff --git a/comfy/model_management.py b/comfy/model_management.py
index 054291432..067b61b62 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -25,6 +25,9 @@ import sys
 import platform
 import weakref
 import gc
+from pathlib import Path
+from functools import lru_cache
+
 
 class VRAMState(Enum):
     DISABLED = 0    #No vram present: no need to move models to vram
@@ -177,7 +180,7 @@ def get_total_memory(dev=None, torch_total_too=False):
         dev = get_torch_device()
 
     if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
-        mem_total = psutil.virtual_memory().total
+        mem_total = _cgroup_limit_bytes() or psutil.virtual_memory().total
         mem_total_torch = mem_total
     else:
         if directml_enabled:
@@ -218,6 +221,54 @@ def mac_version():
     except:
         return None
 
+
+_CG = Path("/sys/fs/cgroup")
+if (_CG / "memory.max").exists():  # cgroup v2 (unified hierarchy)
+    _LIMIT_F = _CG / "memory.max"
+    _USED_F = _CG / "memory.current"
+else:  # cgroup v1 (memory controller under memory/)
+    _LIMIT_F = _CG / "memory/memory.limit_in_bytes"
+    _USED_F = _CG / "memory/memory.usage_in_bytes"
+
+
+def _read_int(p: Path):
+    """Read a cgroup memory file and return its value in bytes.
+
+    Returns None when the file is missing or unreadable, when it does not
+    hold an integer (cgroup v2 writes the literal string "max" for "no
+    limit"), or when the value is not positive or is at least 2**60 (cgroup
+    v1 reports "no limit" as a number close to 2**63).
+    """
+    try:
+        v = int(p.read_text().strip())
+    except (OSError, ValueError):
+        # OSError covers FileNotFoundError, PermissionError and any other I/O
+        # failure; this is a best-effort probe, so it must never raise.
+        return None
+    if v <= 0 or v >= (1 << 60):
+        return None
+    return v
+
+
+@lru_cache(maxsize=None)
+def _cgroup_limit_bytes():
+    """Return the cgroup memory limit in bytes, or None when there is none.
+
+    The value is cached: it is read once per process, so a limit changed
+    while the process is running is not picked up. A limit above the
+    physical RAM is clamped to the physical RAM.
+    """
+    limit = _read_int(_LIMIT_F)
+    if limit is None:
+        return None
+    return min(limit, psutil.virtual_memory().total)
+
+
+def _cgroup_used_bytes():
+    """Return the memory currently used by the cgroup in bytes, or None."""
+    return _read_int(_USED_F)
+
+
 total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
 logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
@@ -1081,7 +1132,16 @@ def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and dev.type == 'cpu':
+        mem_free_total = psutil.virtual_memory().available
+        limit = _cgroup_limit_bytes()
+        used = _cgroup_used_bytes() if limit is not None else None
+        if limit is not None and used is not None:
+            # Inside a memory-limited cgroup the headroom is limit - usage,
+            # but never more than what the host itself still has available.
+            mem_free_total = min(mem_free_total, max(limit - used, 0))
+        mem_free_torch = mem_free_total
+    elif hasattr(dev, 'type') and dev.type == 'mps':
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else: