This commit is contained in:
j2gg0s 2025-12-29 11:55:21 +08:00 committed by GitHub
commit 7f03c2b750
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -188,12 +188,26 @@ def get_torch_device():
else:
return torch.device(torch.cuda.current_device())
cgroup_path_memory_max = None
if os.path.isfile('/sys/fs/cgroup/memory.max'):
cgroup_path_memory_max = '/sys/fs/cgroup/memory.max'
elif os.path.isfile('/sys/fs/cgroup/memory/memory.limit_in_bytes'):
cgroup_path_memory_max = '/sys/fs/cgroup/memory/memory.limit_in_bytes'
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_total = -1
if cgroup_path_memory_max is not None:
with open(cgroup_path_memory_max, 'r') as f:
raw = f.read()
if raw.isdigit():
# NOTE: maybe max or empty
mem_total = int(raw)
if mem_total < 0:
mem_total = psutil.virtual_memory().total
mem_total_torch = mem_total
else:
@ -1259,12 +1273,27 @@ def force_upcast_attention_dtype():
else:
return None
cgroup_path_memory_used = None
if os.path.isfile('/sys/fs/cgroup/memory.current'):
cgroup_path_memory_used = '/sys/fs/cgroup/memory.current'
elif os.path.isfile('/sys/fs/cgroup/memory/memory.usage_in_bytes'):
cgroup_path_memory_used = '/sys/fs/cgroup/memory/memory.usage_in_bytes'
def get_free_memory(dev=None, torch_free_too=False):
global directml_enabled
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_used = -1
if cgroup_path_memory_used is not None:
with open(cgroup_path_memory_used, 'r') as f:
raw = f.read()
if raw.isdigit():
mem_used = int(f.read())
if mem_used > 0:
mem_free_total = get_total_memory(dev) - mem_used
else:
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
else: