diff --git a/comfy/model_management.py b/comfy/model_management.py index aeddbaefe..f97baaeb4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,8 @@ import importlib import platform import weakref import gc +from pathlib import Path + class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -193,7 +195,7 @@ def get_total_memory(dev=None, torch_total_too=False): dev = get_torch_device() if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): - mem_total = psutil.virtual_memory().total + mem_total = _CGROUP_LIMIT_BYTES or psutil.virtual_memory().total mem_total_torch = mem_total else: if directml_enabled: @@ -235,8 +237,28 @@ def mac_version(): except: return None + +_CG = Path("/sys/fs/cgroup") +if (_CG / "memory.max").exists(): # cgroup v2 + _LIMIT_F = _CG / "memory.max" + _USED_F = _CG / "memory.current" +else: # cgroup v1 + _LIMIT_F = _CG / "memory/memory.limit_in_bytes" + _USED_F = _CG / "memory/memory.usage_in_bytes" + +def _get_cgroup_value(p: Path): + try: + v = int(p.read_text().strip()) + if v == 0 or v >= (1 << 60): + return None # 'max' in v2 shows up as 2**63-1 or 0, treat both as unlimited + return v + except (FileNotFoundError, PermissionError, ValueError): + return None + +_CGROUP_LIMIT_BYTES = _get_cgroup_value(_LIMIT_F) + total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) -total_ram = psutil.virtual_memory().total / (1024 * 1024) +total_ram = (_CGROUP_LIMIT_BYTES or psutil.virtual_memory().total) / (1024 * 1024) logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) try: @@ -1261,7 +1283,14 @@ def get_free_memory(dev=None, torch_free_too=False): if dev is None: dev = get_torch_device() - if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): + if hasattr(dev, 'type') and dev.type == 'cpu': + used = _get_cgroup_value(_USED_F) if _CGROUP_LIMIT_BYTES is not None else None + if _CGROUP_LIMIT_BYTES is not None and used is not None: + mem_free_total = max(_CGROUP_LIMIT_BYTES - used, 0) + else: + mem_free_total = psutil.virtual_memory().available + mem_free_torch = mem_free_total + elif hasattr(dev, 'type') and dev.type == 'mps': mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total else: