Add disk weight materialization logging

ifilipis 2026-01-08 19:22:44 +02:00
parent 010d5445fe
commit 397c216b14


@@ -19,6 +19,7 @@
from __future__ import annotations
import collections
import logging
import weakref
from dataclasses import dataclass, field
from types import SimpleNamespace
@@ -159,6 +160,7 @@ class DiskWeightCache:
REGISTRY = DiskWeightRegistry()
CACHE = DiskWeightCache(0)
LOGGER = logging.getLogger(__name__)
def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
@@ -285,6 +287,59 @@ def _rebuild_materialization_state(module: torch.nn.Module, refs: Dict[str, Dis
    _update_disk_state_attrs(module, state)
def _summarize_module_bytes(module: torch.nn.Module, refs: Dict[str, DiskTensorRef]):
    cpu_bytes = 0
    gpu_bytes = 0
    meta_bytes = 0
    total_bytes = 0
    for name, ref in refs.items():
        tensor = None
        if name in module._parameters:
            tensor = module._parameters[name]
        elif name in module._buffers:
            tensor = module._buffers[name]
        if tensor is None:
            continue
        nbytes = _meta_nbytes(ref.meta)
        if nbytes is None:
            nbytes = _tensor_nbytes(tensor)
        total_bytes += nbytes
        if tensor.device.type == "meta":
            meta_bytes += nbytes
        elif tensor.device.type == "cpu":
            cpu_bytes += nbytes
        else:
            gpu_bytes += nbytes
    return total_bytes, cpu_bytes, gpu_bytes, meta_bytes
def _log_materialization(
    module: torch.nn.Module,
    target_device: torch.device,
    free_mem: int,
    refs: Dict[str, DiskTensorRef],
    state: DiskMaterializationState,
    context: str,
):
    total_bytes, cpu_bytes, gpu_bytes, meta_bytes = _summarize_module_bytes(module, refs)
    partial = meta_bytes > 0
    LOGGER.info(
        "%s: module=%s dest=%s load=%0.2fMB free=%0.2fMB partial=%s "
        "loaded=%0.2fMB meta=%0.2fMB cpu=%0.2fMB gpu=%0.2fMB full_load=%s",
        context,
        module.__class__.__name__,
        target_device,
        total_bytes / (1024 * 1024),
        free_mem / (1024 * 1024),
        partial,
        state.loaded_bytes / (1024 * 1024),
        state.deferred_bytes / (1024 * 1024),
        cpu_bytes / (1024 * 1024),
        gpu_bytes / (1024 * 1024),
        not partial,
    )
def _device_free_memory(device: torch.device) -> int:
    from . import model_management
    return int(model_management.get_free_memory(device))
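For reference, _log_materialization emits a single INFO line per event; an illustrative example of the resulting output (class name and numbers are invented for this sketch, not captured from a real run):

    Disk weight materialized: module=Linear dest=cuda:0 load=256.00MB free=8192.00MB partial=False loaded=256.00MB meta=0.00MB cpu=0.00MB gpu=256.00MB full_load=True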
@@ -564,7 +619,8 @@ def ensure_module_materialized(
        return
    state = _get_materialization_state(module)
    _rebuild_materialization_state(module, refs, state)
    free_mem_start = _device_free_memory(target_device)
    remaining_budget = free_mem_start
    for name in sorted(refs.keys()):
        disk_ref = refs[name]
        if name in module._parameters:
@@ -611,6 +667,7 @@ def ensure_module_materialized(
            CACHE.record(module, name, tensor, is_buffer=is_buffer)
        remaining_budget = max(0, remaining_budget - required_bytes)
    _rebuild_materialization_state(module, refs, state)
    _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
@@ -716,6 +773,7 @@ def load_module_tensor(
    required_bytes = _meta_nbytes(disk_ref.meta)
    if required_bytes is None:
        return current
    free_mem_start = _device_free_memory(device)
    free_mem = _maybe_free_ram_budget(device, required_bytes)
    load_device = device
    if free_mem < required_bytes and allow_alternate:
@@ -730,6 +788,7 @@ def load_module_tensor(
            state.deferred_keys.add(name)
            state.deferred_bytes += required_bytes
            _update_disk_state_attrs(module, state)
            _log_materialization(module, device, free_mem_start, refs, _get_materialization_state(module), "Disk weight deferred")
            return current
        else:
            state = _get_materialization_state(module)
@@ -737,6 +796,7 @@ def load_module_tensor(
            state.deferred_keys.add(name)
            state.deferred_bytes += required_bytes
            _update_disk_state_attrs(module, state)
            _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
            return current
    elif free_mem < required_bytes:
        state = _get_materialization_state(module)
@@ -744,6 +804,7 @@ def load_module_tensor(
        state.deferred_keys.add(name)
        state.deferred_bytes += required_bytes
        _update_disk_state_attrs(module, state)
        _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
        return current
    tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU)
@@ -755,7 +816,9 @@ def load_module_tensor(
        module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
    if tensor.device.type == "cpu" and record_cache:
        CACHE.record(module, name, tensor, is_buffer=is_buffer)
    _rebuild_materialization_state(module, refs, _get_materialization_state(module))
    state = _get_materialization_state(module)
    _rebuild_materialization_state(module, refs, state)
    _log_materialization(module, load_device, free_mem_start, refs, state, "Disk weight loaded")
    return tensor
@@ -790,7 +853,8 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
        key = f"{lazy_state.prefix}{name}"
        if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
            existing[key] = buf
    remaining_budget = _device_free_memory(target_device)
    free_mem_start = _device_free_memory(target_device)
    remaining_budget = free_mem_start
    allowed = set(existing.keys())
    for key in keys:
        if key in allowed:
@@ -840,6 +904,7 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
    _rebuild_materialization_state(module, refs, state)
    lazy_state.loaded = len(deferred_state_dict_keys) == 0
    _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed")
    for name, param in module.named_parameters(recurse=False):
        if param.device.type == "cpu":
            CACHE.record(module, name, param, is_buffer=False)
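
The new messages go through the module-level LOGGER at INFO level, so they stay silent unless the host application raises its logging verbosity. A minimal sketch of turning them on, using only the standard library (the logger name follows the module's __name__, so the exact dotted path depends on where this file lives in the package):

    import logging

    # Surface the "Disk weight materialized/deferred/loaded/streamed" messages.
    # Configuring the root logger is the bluntest option; a narrower setup would
    # fetch the module's own logger via logging.getLogger(<dotted name>) and set
    # its level to INFO instead.
    logging.basicConfig(level=logging.INFO)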