feat(memory): support container run in k8s

The relevant logic from github.com/google/cadvisor and
github.com/opencontainers/cgroups.
This commit is contained in:
j2gg0s 2025-06-13 12:19:50 +08:00
parent c6529c0d77
commit 250ae5db22

View File

@ -171,13 +171,27 @@ def get_torch_device():
else:
return torch.device(torch.cuda.current_device())
cgroup_path_memory_max = None
if os.path.isfile('/sys/fs/cgroup/memory.max'):
cgroup_path_memory_max = '/sys/fs/cgroup/memory.max'
elif os.path.isfile('/sys/fs/cgroup/memory/memory.limit_in_bytes'):
cgroup_path_memory_max = '/sys/fs/cgroup/memory/memory.limit_in_bytes'
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_total = psutil.virtual_memory().total
mem_total = -1
if cgroup_path_memory_max is not None:
with open(cgroup_path_memory_max, 'r') as f:
raw = f.read()
if raw.isdigit():
# NOTE: maybe max or empty
mem_total = int(raw)
if mem_total < 0:
mem_total = psutil.virtual_memory().total
mem_total_torch = mem_total
else:
if directml_enabled:
@ -1076,13 +1090,28 @@ def force_upcast_attention_dtype():
else:
return None
cgroup_path_memory_used = None
if os.path.isfile('/sys/fs/cgroup/memory.current'):
cgroup_path_memory_used = '/sys/fs/cgroup/memory.current'
elif os.path.isfile('/sys/fs/cgroup/memory/memory.usage_in_bytes'):
cgroup_path_memory_used = '/sys/fs/cgroup/memory/memory.usage_in_bytes'
def get_free_memory(dev=None, torch_free_too=False):
global directml_enabled
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_free_total = psutil.virtual_memory().available
mem_used = -1
if cgroup_path_memory_used is not None:
with open(cgroup_path_memory_used, 'r') as f:
raw = f.read()
if raw.isdigit():
mem_used = int(f.read())
if mem_used > 0:
mem_free_total = get_total_memory(dev) - mem_used
else:
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
else:
if directml_enabled: