Add disk weight materialization logging

This commit is contained in:
ifilipis 2026-01-08 19:22:44 +02:00
parent 010d5445fe
commit 397c216b14

View File

@ -19,6 +19,7 @@
from __future__ import annotations from __future__ import annotations
import collections import collections
import logging
import weakref import weakref
from dataclasses import dataclass, field from dataclasses import dataclass, field
from types import SimpleNamespace from types import SimpleNamespace
@ -159,6 +160,7 @@ class DiskWeightCache:
REGISTRY = DiskWeightRegistry() REGISTRY = DiskWeightRegistry()
CACHE = DiskWeightCache(0) CACHE = DiskWeightCache(0)
LOGGER = logging.getLogger(__name__)
def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True): def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
@ -285,6 +287,59 @@ def _rebuild_materialization_state(module: torch.nn.Module, refs: Dict[str, Disk
_update_disk_state_attrs(module, state) _update_disk_state_attrs(module, state)
def _summarize_module_bytes(module: torch.nn.Module, refs: Dict[str, DiskTensorRef]):
    """Tally the byte footprint of the module's disk-backed tensors per device.

    For every ref that maps to a live parameter or buffer on *module*,
    sizes are taken from the ref's metadata when available, otherwise
    measured from the tensor itself, and bucketed by device type.

    Returns a ``(total, cpu, gpu, meta)`` tuple of byte counts.
    """
    buckets = {"cpu": 0, "gpu": 0, "meta": 0}
    overall = 0
    params = module._parameters
    buffers = module._buffers
    for name, ref in refs.items():
        # Parameters shadow buffers of the same name, matching module lookup order.
        if name in params:
            tensor = params[name]
        elif name in buffers:
            tensor = buffers[name]
        else:
            tensor = None
        if tensor is None:
            continue
        # Prefer the recorded size; fall back to measuring the live tensor.
        nbytes = _meta_nbytes(ref.meta)
        if nbytes is None:
            nbytes = _tensor_nbytes(tensor)
        overall += nbytes
        device_type = tensor.device.type
        if device_type == "meta":
            buckets["meta"] += nbytes
        elif device_type == "cpu":
            buckets["cpu"] += nbytes
        else:
            # Anything that is neither meta nor cpu counts as accelerator memory.
            buckets["gpu"] += nbytes
    return overall, buckets["cpu"], buckets["gpu"], buckets["meta"]
def _log_materialization(
    module: torch.nn.Module,
    target_device: torch.device,
    free_mem: int,
    refs: Dict[str, DiskTensorRef],
    state: DiskMaterializationState,
    context: str,
):
    """Emit a single INFO line summarizing a disk-weight materialization event.

    *context* labels the event (e.g. loaded / deferred / streamed); sizes are
    reported in MB. ``partial`` is true while any tracked tensor still lives
    on the meta device; ``full_load`` is its complement.
    """
    mb = 1024 * 1024
    total_bytes, cpu_bytes, gpu_bytes, meta_bytes = _summarize_module_bytes(module, refs)
    is_partial = meta_bytes > 0
    # NOTE(review): the "meta=" field reports state.deferred_bytes, not the
    # meta_bytes computed above — presumably deferred weights are exactly the
    # ones still on the meta device; confirm against the state bookkeeping.
    LOGGER.info(
        "%s: module=%s dest=%s load=%0.2fMB free=%0.2fMB partial=%s "
        "loaded=%0.2fMB meta=%0.2fMB cpu=%0.2fMB gpu=%0.2fMB full_load=%s",
        context,
        type(module).__name__,
        target_device,
        total_bytes / mb,
        free_mem / mb,
        is_partial,
        state.loaded_bytes / mb,
        state.deferred_bytes / mb,
        cpu_bytes / mb,
        gpu_bytes / mb,
        not is_partial,
    )
def _device_free_memory(device: torch.device) -> int: def _device_free_memory(device: torch.device) -> int:
from . import model_management from . import model_management
return int(model_management.get_free_memory(device)) return int(model_management.get_free_memory(device))
@ -564,7 +619,8 @@ def ensure_module_materialized(
return return
state = _get_materialization_state(module) state = _get_materialization_state(module)
_rebuild_materialization_state(module, refs, state) _rebuild_materialization_state(module, refs, state)
remaining_budget = _device_free_memory(target_device) free_mem_start = _device_free_memory(target_device)
remaining_budget = free_mem_start
for name in sorted(refs.keys()): for name in sorted(refs.keys()):
disk_ref = refs[name] disk_ref = refs[name]
if name in module._parameters: if name in module._parameters:
@ -611,6 +667,7 @@ def ensure_module_materialized(
CACHE.record(module, name, tensor, is_buffer=is_buffer) CACHE.record(module, name, tensor, is_buffer=is_buffer)
remaining_budget = max(0, remaining_budget - required_bytes) remaining_budget = max(0, remaining_budget - required_bytes)
_rebuild_materialization_state(module, refs, state) _rebuild_materialization_state(module, refs, state)
_log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs): def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
@ -716,6 +773,7 @@ def load_module_tensor(
required_bytes = _meta_nbytes(disk_ref.meta) required_bytes = _meta_nbytes(disk_ref.meta)
if required_bytes is None: if required_bytes is None:
return current return current
free_mem_start = _device_free_memory(device)
free_mem = _maybe_free_ram_budget(device, required_bytes) free_mem = _maybe_free_ram_budget(device, required_bytes)
load_device = device load_device = device
if free_mem < required_bytes and allow_alternate: if free_mem < required_bytes and allow_alternate:
@ -730,6 +788,7 @@ def load_module_tensor(
state.deferred_keys.add(name) state.deferred_keys.add(name)
state.deferred_bytes += required_bytes state.deferred_bytes += required_bytes
_update_disk_state_attrs(module, state) _update_disk_state_attrs(module, state)
_log_materialization(module, device, free_mem_start, refs, _get_materialization_state(module), "Disk weight deferred")
return current return current
else: else:
state = _get_materialization_state(module) state = _get_materialization_state(module)
@ -737,6 +796,7 @@ def load_module_tensor(
state.deferred_keys.add(name) state.deferred_keys.add(name)
state.deferred_bytes += required_bytes state.deferred_bytes += required_bytes
_update_disk_state_attrs(module, state) _update_disk_state_attrs(module, state)
_log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
return current return current
elif free_mem < required_bytes: elif free_mem < required_bytes:
state = _get_materialization_state(module) state = _get_materialization_state(module)
@ -744,6 +804,7 @@ def load_module_tensor(
state.deferred_keys.add(name) state.deferred_keys.add(name)
state.deferred_bytes += required_bytes state.deferred_bytes += required_bytes
_update_disk_state_attrs(module, state) _update_disk_state_attrs(module, state)
_log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
return current return current
tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU) tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU)
@ -755,7 +816,9 @@ def load_module_tensor(
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad) module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
if tensor.device.type == "cpu" and record_cache: if tensor.device.type == "cpu" and record_cache:
CACHE.record(module, name, tensor, is_buffer=is_buffer) CACHE.record(module, name, tensor, is_buffer=is_buffer)
_rebuild_materialization_state(module, refs, _get_materialization_state(module)) state = _get_materialization_state(module)
_rebuild_materialization_state(module, refs, state)
_log_materialization(module, load_device, free_mem_start, refs, state, "Disk weight loaded")
return tensor return tensor
@ -790,7 +853,8 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
key = f"{lazy_state.prefix}{name}" key = f"{lazy_state.prefix}{name}"
if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta": if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
existing[key] = buf existing[key] = buf
remaining_budget = _device_free_memory(target_device) free_mem_start = _device_free_memory(target_device)
remaining_budget = free_mem_start
allowed = set(existing.keys()) allowed = set(existing.keys())
for key in keys: for key in keys:
if key in allowed: if key in allowed:
@ -840,6 +904,7 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs))) raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
_rebuild_materialization_state(module, refs, state) _rebuild_materialization_state(module, refs, state)
lazy_state.loaded = len(deferred_state_dict_keys) == 0 lazy_state.loaded = len(deferred_state_dict_keys) == 0
_log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed")
for name, param in module.named_parameters(recurse=False): for name, param in module.named_parameters(recurse=False):
if param.device.type == "cpu": if param.device.type == "cpu":
CACHE.record(module, name, param, is_buffer=False) CACHE.record(module, name, param, is_buffer=False)