mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-13 07:40:50 +08:00
Add disk weight materialization logging
This commit is contained in:
parent
010d5445fe
commit
397c216b14
@ -19,6 +19,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import weakref
|
||||
from dataclasses import dataclass, field
|
||||
from types import SimpleNamespace
|
||||
@ -159,6 +160,7 @@ class DiskWeightCache:
|
||||
|
||||
REGISTRY = DiskWeightRegistry()
|
||||
CACHE = DiskWeightCache(0)
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
|
||||
@ -285,6 +287,59 @@ def _rebuild_materialization_state(module: torch.nn.Module, refs: Dict[str, Disk
|
||||
_update_disk_state_attrs(module, state)
|
||||
|
||||
|
||||
def _summarize_module_bytes(module: torch.nn.Module, refs: Dict[str, DiskTensorRef]):
|
||||
cpu_bytes = 0
|
||||
gpu_bytes = 0
|
||||
meta_bytes = 0
|
||||
total_bytes = 0
|
||||
for name, ref in refs.items():
|
||||
tensor = None
|
||||
if name in module._parameters:
|
||||
tensor = module._parameters[name]
|
||||
elif name in module._buffers:
|
||||
tensor = module._buffers[name]
|
||||
if tensor is None:
|
||||
continue
|
||||
nbytes = _meta_nbytes(ref.meta)
|
||||
if nbytes is None:
|
||||
nbytes = _tensor_nbytes(tensor)
|
||||
total_bytes += nbytes
|
||||
if tensor.device.type == "meta":
|
||||
meta_bytes += nbytes
|
||||
elif tensor.device.type == "cpu":
|
||||
cpu_bytes += nbytes
|
||||
else:
|
||||
gpu_bytes += nbytes
|
||||
return total_bytes, cpu_bytes, gpu_bytes, meta_bytes
|
||||
|
||||
|
||||
def _log_materialization(
|
||||
module: torch.nn.Module,
|
||||
target_device: torch.device,
|
||||
free_mem: int,
|
||||
refs: Dict[str, DiskTensorRef],
|
||||
state: DiskMaterializationState,
|
||||
context: str,
|
||||
):
|
||||
total_bytes, cpu_bytes, gpu_bytes, meta_bytes = _summarize_module_bytes(module, refs)
|
||||
partial = meta_bytes > 0
|
||||
LOGGER.info(
|
||||
"%s: module=%s dest=%s load=%0.2fMB free=%0.2fMB partial=%s "
|
||||
"loaded=%0.2fMB meta=%0.2fMB cpu=%0.2fMB gpu=%0.2fMB full_load=%s",
|
||||
context,
|
||||
module.__class__.__name__,
|
||||
target_device,
|
||||
total_bytes / (1024 * 1024),
|
||||
free_mem / (1024 * 1024),
|
||||
partial,
|
||||
state.loaded_bytes / (1024 * 1024),
|
||||
state.deferred_bytes / (1024 * 1024),
|
||||
cpu_bytes / (1024 * 1024),
|
||||
gpu_bytes / (1024 * 1024),
|
||||
not partial,
|
||||
)
|
||||
|
||||
|
||||
def _device_free_memory(device: torch.device) -> int:
|
||||
from . import model_management
|
||||
return int(model_management.get_free_memory(device))
|
||||
@ -564,7 +619,8 @@ def ensure_module_materialized(
|
||||
return
|
||||
state = _get_materialization_state(module)
|
||||
_rebuild_materialization_state(module, refs, state)
|
||||
remaining_budget = _device_free_memory(target_device)
|
||||
free_mem_start = _device_free_memory(target_device)
|
||||
remaining_budget = free_mem_start
|
||||
for name in sorted(refs.keys()):
|
||||
disk_ref = refs[name]
|
||||
if name in module._parameters:
|
||||
@ -611,6 +667,7 @@ def ensure_module_materialized(
|
||||
CACHE.record(module, name, tensor, is_buffer=is_buffer)
|
||||
remaining_budget = max(0, remaining_budget - required_bytes)
|
||||
_rebuild_materialization_state(module, refs, state)
|
||||
_log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
|
||||
|
||||
|
||||
def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
|
||||
@ -716,6 +773,7 @@ def load_module_tensor(
|
||||
required_bytes = _meta_nbytes(disk_ref.meta)
|
||||
if required_bytes is None:
|
||||
return current
|
||||
free_mem_start = _device_free_memory(device)
|
||||
free_mem = _maybe_free_ram_budget(device, required_bytes)
|
||||
load_device = device
|
||||
if free_mem < required_bytes and allow_alternate:
|
||||
@ -730,6 +788,7 @@ def load_module_tensor(
|
||||
state.deferred_keys.add(name)
|
||||
state.deferred_bytes += required_bytes
|
||||
_update_disk_state_attrs(module, state)
|
||||
_log_materialization(module, device, free_mem_start, refs, _get_materialization_state(module), "Disk weight deferred")
|
||||
return current
|
||||
else:
|
||||
state = _get_materialization_state(module)
|
||||
@ -737,6 +796,7 @@ def load_module_tensor(
|
||||
state.deferred_keys.add(name)
|
||||
state.deferred_bytes += required_bytes
|
||||
_update_disk_state_attrs(module, state)
|
||||
_log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
|
||||
return current
|
||||
elif free_mem < required_bytes:
|
||||
state = _get_materialization_state(module)
|
||||
@ -744,6 +804,7 @@ def load_module_tensor(
|
||||
state.deferred_keys.add(name)
|
||||
state.deferred_bytes += required_bytes
|
||||
_update_disk_state_attrs(module, state)
|
||||
_log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
|
||||
return current
|
||||
|
||||
tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU)
|
||||
@ -755,7 +816,9 @@ def load_module_tensor(
|
||||
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
|
||||
if tensor.device.type == "cpu" and record_cache:
|
||||
CACHE.record(module, name, tensor, is_buffer=is_buffer)
|
||||
_rebuild_materialization_state(module, refs, _get_materialization_state(module))
|
||||
state = _get_materialization_state(module)
|
||||
_rebuild_materialization_state(module, refs, state)
|
||||
_log_materialization(module, load_device, free_mem_start, refs, state, "Disk weight loaded")
|
||||
return tensor
|
||||
|
||||
|
||||
@ -790,7 +853,8 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
|
||||
key = f"{lazy_state.prefix}{name}"
|
||||
if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
|
||||
existing[key] = buf
|
||||
remaining_budget = _device_free_memory(target_device)
|
||||
free_mem_start = _device_free_memory(target_device)
|
||||
remaining_budget = free_mem_start
|
||||
allowed = set(existing.keys())
|
||||
for key in keys:
|
||||
if key in allowed:
|
||||
@ -840,6 +904,7 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
_rebuild_materialization_state(module, refs, state)
|
||||
lazy_state.loaded = len(deferred_state_dict_keys) == 0
|
||||
_log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed")
|
||||
for name, param in module.named_parameters(recurse=False):
|
||||
if param.device.type == "cpu":
|
||||
CACHE.record(module, name, param, is_buffer=False)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user