mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 14:50:20 +08:00
Add disk weight materialization logging
This commit is contained in:
parent
010d5445fe
commit
397c216b14
@ -19,6 +19,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import logging
|
||||||
import weakref
|
import weakref
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
@ -159,6 +160,7 @@ class DiskWeightCache:
|
|||||||
|
|
||||||
REGISTRY = DiskWeightRegistry()
|
REGISTRY = DiskWeightRegistry()
|
||||||
CACHE = DiskWeightCache(0)
|
CACHE = DiskWeightCache(0)
|
||||||
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
|
def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
|
||||||
@ -285,6 +287,59 @@ def _rebuild_materialization_state(module: torch.nn.Module, refs: Dict[str, Disk
|
|||||||
_update_disk_state_attrs(module, state)
|
_update_disk_state_attrs(module, state)
|
||||||
|
|
||||||
|
|
||||||
|
def _summarize_module_bytes(module: torch.nn.Module, refs: Dict[str, DiskTensorRef]):
    """Tally the byte footprint of *module*'s disk-backed tensors by device.

    Walks every entry in *refs* that is currently attached to *module* as a
    parameter or buffer and sums its size, bucketed by where the tensor lives.

    Returns:
        A ``(total_bytes, cpu_bytes, gpu_bytes, meta_bytes)`` tuple.  Anything
        that is neither ``meta`` nor ``cpu`` is counted as GPU.
    """
    per_device = {"cpu": 0, "gpu": 0, "meta": 0}
    grand_total = 0

    for name, ref in refs.items():
        # Look the tensor up the same way nn.Module stores it: parameters
        # take precedence over buffers for a given name.
        if name in module._parameters:
            tensor = module._parameters[name]
        elif name in module._buffers:
            tensor = module._buffers[name]
        else:
            tensor = None
        if tensor is None:
            continue

        # Prefer the size recorded in the disk-ref metadata; fall back to
        # measuring the live tensor when the metadata has no size.
        nbytes = _meta_nbytes(ref.meta)
        if nbytes is None:
            nbytes = _tensor_nbytes(tensor)
        grand_total += nbytes

        device_type = tensor.device.type
        if device_type == "meta":
            per_device["meta"] += nbytes
        elif device_type == "cpu":
            per_device["cpu"] += nbytes
        else:
            per_device["gpu"] += nbytes

    return grand_total, per_device["cpu"], per_device["gpu"], per_device["meta"]
|
||||||
|
|
||||||
|
|
||||||
|
def _log_materialization(
    module: torch.nn.Module,
    target_device: torch.device,
    free_mem: int,
    refs: Dict[str, DiskTensorRef],
    state: DiskMaterializationState,
    context: str,
):
    """Emit a single INFO line summarizing a disk-weight materialization event.

    Args:
        module: The module whose weights were (partially) materialized.
        target_device: Device the load was aimed at.
        free_mem: Free memory (bytes) observed before the load began.
        refs: Disk refs describing the module's weights.
        state: Materialization bookkeeping (loaded/deferred byte counts).
        context: Short label prefixed to the log line (e.g. "Disk weight loaded").
    """
    total_bytes, cpu_bytes, gpu_bytes, meta_bytes = _summarize_module_bytes(module, refs)
    # Any bytes still on the meta device mean the load did not complete fully.
    partial = meta_bytes > 0

    def as_mb(num_bytes):
        # All sizes are reported in MiB for readability.
        return num_bytes / (1024 * 1024)

    LOGGER.info(
        "%s: module=%s dest=%s load=%0.2fMB free=%0.2fMB partial=%s "
        "loaded=%0.2fMB meta=%0.2fMB cpu=%0.2fMB gpu=%0.2fMB full_load=%s",
        context,
        module.__class__.__name__,
        target_device,
        as_mb(total_bytes),
        as_mb(free_mem),
        partial,
        as_mb(state.loaded_bytes),
        as_mb(state.deferred_bytes),
        as_mb(cpu_bytes),
        as_mb(gpu_bytes),
        not partial,
    )
|
||||||
|
|
||||||
|
|
||||||
def _device_free_memory(device: torch.device) -> int:
    """Return the free memory on *device* in bytes, as an int.

    Delegates to ``model_management.get_free_memory``; the import is local
    to avoid an import cycle at module load time.
    """
    from . import model_management

    free_bytes = model_management.get_free_memory(device)
    return int(free_bytes)
|
||||||
@ -564,7 +619,8 @@ def ensure_module_materialized(
|
|||||||
return
|
return
|
||||||
state = _get_materialization_state(module)
|
state = _get_materialization_state(module)
|
||||||
_rebuild_materialization_state(module, refs, state)
|
_rebuild_materialization_state(module, refs, state)
|
||||||
remaining_budget = _device_free_memory(target_device)
|
free_mem_start = _device_free_memory(target_device)
|
||||||
|
remaining_budget = free_mem_start
|
||||||
for name in sorted(refs.keys()):
|
for name in sorted(refs.keys()):
|
||||||
disk_ref = refs[name]
|
disk_ref = refs[name]
|
||||||
if name in module._parameters:
|
if name in module._parameters:
|
||||||
@ -611,6 +667,7 @@ def ensure_module_materialized(
|
|||||||
CACHE.record(module, name, tensor, is_buffer=is_buffer)
|
CACHE.record(module, name, tensor, is_buffer=is_buffer)
|
||||||
remaining_budget = max(0, remaining_budget - required_bytes)
|
remaining_budget = max(0, remaining_budget - required_bytes)
|
||||||
_rebuild_materialization_state(module, refs, state)
|
_rebuild_materialization_state(module, refs, state)
|
||||||
|
_log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
|
||||||
|
|
||||||
|
|
||||||
def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
|
def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
|
||||||
@ -716,6 +773,7 @@ def load_module_tensor(
|
|||||||
required_bytes = _meta_nbytes(disk_ref.meta)
|
required_bytes = _meta_nbytes(disk_ref.meta)
|
||||||
if required_bytes is None:
|
if required_bytes is None:
|
||||||
return current
|
return current
|
||||||
|
free_mem_start = _device_free_memory(device)
|
||||||
free_mem = _maybe_free_ram_budget(device, required_bytes)
|
free_mem = _maybe_free_ram_budget(device, required_bytes)
|
||||||
load_device = device
|
load_device = device
|
||||||
if free_mem < required_bytes and allow_alternate:
|
if free_mem < required_bytes and allow_alternate:
|
||||||
@ -730,6 +788,7 @@ def load_module_tensor(
|
|||||||
state.deferred_keys.add(name)
|
state.deferred_keys.add(name)
|
||||||
state.deferred_bytes += required_bytes
|
state.deferred_bytes += required_bytes
|
||||||
_update_disk_state_attrs(module, state)
|
_update_disk_state_attrs(module, state)
|
||||||
|
_log_materialization(module, device, free_mem_start, refs, _get_materialization_state(module), "Disk weight deferred")
|
||||||
return current
|
return current
|
||||||
else:
|
else:
|
||||||
state = _get_materialization_state(module)
|
state = _get_materialization_state(module)
|
||||||
@ -737,6 +796,7 @@ def load_module_tensor(
|
|||||||
state.deferred_keys.add(name)
|
state.deferred_keys.add(name)
|
||||||
state.deferred_bytes += required_bytes
|
state.deferred_bytes += required_bytes
|
||||||
_update_disk_state_attrs(module, state)
|
_update_disk_state_attrs(module, state)
|
||||||
|
_log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
|
||||||
return current
|
return current
|
||||||
elif free_mem < required_bytes:
|
elif free_mem < required_bytes:
|
||||||
state = _get_materialization_state(module)
|
state = _get_materialization_state(module)
|
||||||
@ -744,6 +804,7 @@ def load_module_tensor(
|
|||||||
state.deferred_keys.add(name)
|
state.deferred_keys.add(name)
|
||||||
state.deferred_bytes += required_bytes
|
state.deferred_bytes += required_bytes
|
||||||
_update_disk_state_attrs(module, state)
|
_update_disk_state_attrs(module, state)
|
||||||
|
_log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
|
||||||
return current
|
return current
|
||||||
|
|
||||||
tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU)
|
tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU)
|
||||||
@ -755,7 +816,9 @@ def load_module_tensor(
|
|||||||
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
|
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
|
||||||
if tensor.device.type == "cpu" and record_cache:
|
if tensor.device.type == "cpu" and record_cache:
|
||||||
CACHE.record(module, name, tensor, is_buffer=is_buffer)
|
CACHE.record(module, name, tensor, is_buffer=is_buffer)
|
||||||
_rebuild_materialization_state(module, refs, _get_materialization_state(module))
|
state = _get_materialization_state(module)
|
||||||
|
_rebuild_materialization_state(module, refs, state)
|
||||||
|
_log_materialization(module, load_device, free_mem_start, refs, state, "Disk weight loaded")
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
@ -790,7 +853,8 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
|
|||||||
key = f"{lazy_state.prefix}{name}"
|
key = f"{lazy_state.prefix}{name}"
|
||||||
if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
|
if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
|
||||||
existing[key] = buf
|
existing[key] = buf
|
||||||
remaining_budget = _device_free_memory(target_device)
|
free_mem_start = _device_free_memory(target_device)
|
||||||
|
remaining_budget = free_mem_start
|
||||||
allowed = set(existing.keys())
|
allowed = set(existing.keys())
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key in allowed:
|
if key in allowed:
|
||||||
@ -840,6 +904,7 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
|
|||||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
|
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
|
||||||
_rebuild_materialization_state(module, refs, state)
|
_rebuild_materialization_state(module, refs, state)
|
||||||
lazy_state.loaded = len(deferred_state_dict_keys) == 0
|
lazy_state.loaded = len(deferred_state_dict_keys) == 0
|
||||||
|
_log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed")
|
||||||
for name, param in module.named_parameters(recurse=False):
|
for name, param in module.named_parameters(recurse=False):
|
||||||
if param.device.type == "cpu":
|
if param.device.type == "cpu":
|
||||||
CACHE.record(module, name, param, is_buffer=False)
|
CACHE.record(module, name, param, is_buffer=False)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user