diff --git a/comfy/model_management.py b/comfy/model_management.py index e5554e225..50ef91528 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -188,13 +188,27 @@ def get_torch_device(): else: return torch.device(torch.cuda.current_device()) +cgroup_path_memory_max = None +if os.path.isfile('/sys/fs/cgroup/memory.max'): + cgroup_path_memory_max = '/sys/fs/cgroup/memory.max' +elif os.path.isfile('/sys/fs/cgroup/memory/memory.limit_in_bytes'): + cgroup_path_memory_max = '/sys/fs/cgroup/memory/memory.limit_in_bytes' + def get_total_memory(dev=None, torch_total_too=False): global directml_enabled if dev is None: dev = get_torch_device() if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): - mem_total = psutil.virtual_memory().total + mem_total = -1 + if cgroup_path_memory_max is not None: + with open(cgroup_path_memory_max, 'r') as f: + raw = f.read() + if raw.isdigit(): + # NOTE: maybe max or empty + mem_total = int(raw) + if mem_total < 0: + mem_total = psutil.virtual_memory().total mem_total_torch = mem_total else: if directml_enabled: @@ -1259,13 +1273,28 @@ def force_upcast_attention_dtype(): else: return None +cgroup_path_memory_used = None +if os.path.isfile('/sys/fs/cgroup/memory.current'): + cgroup_path_memory_used = '/sys/fs/cgroup/memory.current' +elif os.path.isfile('/sys/fs/cgroup/memory/memory.usage_in_bytes'): + cgroup_path_memory_used = '/sys/fs/cgroup/memory/memory.usage_in_bytes' + def get_free_memory(dev=None, torch_free_too=False): global directml_enabled if dev is None: dev = get_torch_device() if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): - mem_free_total = psutil.virtual_memory().available + mem_used = -1 + if cgroup_path_memory_used is not None: + with open(cgroup_path_memory_used, 'r') as f: + raw = f.read() + if raw.isdigit(): + mem_used = int(f.read()) + if mem_used > 0: + mem_free_total = get_total_memory(dev) - mem_used + else: + mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total else: if directml_enabled: