From 250ae5db2200b6102d33028aecdf4b76b5bd383d Mon Sep 17 00:00:00 2001
From: j2gg0s
Date: Fri, 13 Jun 2025 12:19:50 +0800
Subject: [PATCH] feat(memory): support containers running in k8s

When a cgroup memory limit is set (cgroup v2 memory.max or cgroup v1
memory.limit_in_bytes), use it for the total/free memory calculations
instead of the host values reported by psutil. The relevant logic is
adapted from github.com/google/cadvisor and
github.com/opencontainers/cgroups.
---
 comfy/model_management.py | 33 +++++++++++++++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 054291432..be68d5682 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -171,13 +171,27 @@ def get_torch_device():
         else:
             return torch.device(torch.cuda.current_device())
 
+cgroup_path_memory_max = None
+if os.path.isfile('/sys/fs/cgroup/memory.max'):
+    cgroup_path_memory_max = '/sys/fs/cgroup/memory.max'
+elif os.path.isfile('/sys/fs/cgroup/memory/memory.limit_in_bytes'):
+    cgroup_path_memory_max = '/sys/fs/cgroup/memory/memory.limit_in_bytes'
+
 def get_total_memory(dev=None, torch_total_too=False):
     global directml_enabled
     if dev is None:
         dev = get_torch_device()
 
     if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
-        mem_total = psutil.virtual_memory().total
+        mem_total = -1
+        if cgroup_path_memory_max is not None:
+            with open(cgroup_path_memory_max, 'r') as f:
+                raw = f.read().strip()
+                # NOTE: the value may be 'max' or empty, hence the isdigit() check
+                if raw.isdigit():
+                    mem_total = int(raw)
+        if mem_total < 0:
+            mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
     else:
         if directml_enabled:
@@ -1076,13 +1090,28 @@ def force_upcast_attention_dtype():
     else:
         return None
 
+cgroup_path_memory_used = None
+if os.path.isfile('/sys/fs/cgroup/memory.current'):
+    cgroup_path_memory_used = '/sys/fs/cgroup/memory.current'
+elif os.path.isfile('/sys/fs/cgroup/memory/memory.usage_in_bytes'):
+    cgroup_path_memory_used = '/sys/fs/cgroup/memory/memory.usage_in_bytes'
+
 def get_free_memory(dev=None, torch_free_too=False):
     global directml_enabled
     if dev is None:
         dev = get_torch_device()
 
     if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
-        mem_free_total = psutil.virtual_memory().available
+        mem_used = -1
+        if cgroup_path_memory_used is not None:
+            with open(cgroup_path_memory_used, 'r') as f:
+                raw = f.read().strip()
+                if raw.isdigit():
+                    mem_used = int(raw)
+        if mem_used > 0:
+            mem_free_total = get_total_memory(dev) - mem_used
+        else:
+            mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
         if directml_enabled:
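
For reviewers: below is a standalone sketch of the detection logic,
runnable outside ComfyUI. The helper name container_memory_limit_bytes
and the __main__ block are illustrative only, not part of the patch; it
probes the same cgroup v2 memory.max and cgroup v1 memory.limit_in_bytes
paths and falls back to psutil.virtual_memory() when no usable limit is
found.

import os
import psutil


def container_memory_limit_bytes():
    """Return the cgroup memory limit in bytes, or None if no usable limit."""
    # cgroup v2 first, then cgroup v1 -- the same paths the patch probes.
    for path in ('/sys/fs/cgroup/memory.max',
                 '/sys/fs/cgroup/memory/memory.limit_in_bytes'):
        if os.path.isfile(path):
            with open(path, 'r') as f:
                raw = f.read().strip()
            # cgroup v2 writes 'max' when no limit is set; an empty or
            # non-numeric value fails isdigit() and falls through to None.
            if raw.isdigit():
                return int(raw)
    return None


if __name__ == '__main__':
    host_total = psutil.virtual_memory().total
    limit = container_memory_limit_bytes()
    # cgroup v1 typically reports a very large sentinel when unlimited, so
    # only trust a limit that is smaller than the host's physical memory.
    effective = limit if limit is not None and limit < host_total else host_total
    print('cgroup limit :', limit)
    print('host total   :', host_total)
    print('effective    :', effective)

One caveat that may be worth handling in the patch as well: if I read the
cgroup v1 behavior correctly, memory.limit_in_bytes contains a huge
sentinel value when no limit is set, which still passes isdigit();
comparing against the host total, as in the sketch above, is one way to
guard against treating that sentinel as a real limit.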