From 397c216b14d9919c32ad3482f903b28fba01e7d6 Mon Sep 17 00:00:00 2001
From: ifilipis <40601736+ifilipis@users.noreply.github.com>
Date: Thu, 8 Jan 2026 19:22:44 +0200
Subject: [PATCH] Add disk weight materialization logging

Add two private helpers to comfy/disk_weights.py:

- _summarize_module_bytes() walks the module's refs and totals the byte
  size of each matching parameter/buffer, split by the device type the
  tensor currently reports ("meta", "cpu", or anything else, which is
  counted as GPU).
- _log_materialization() emits one INFO record with those totals plus the
  loaded_bytes/deferred_bytes counters of the materialization state.

Free memory is sampled before loading starts and the record is emitted
at the end of ensure_module_materialized(), on every deferred and loaded
exit of load_module_tensor(), and after the state-dict load in
_materialize_module_from_state_dict().
---
Notes (not part of the commit message):

* This copy of the patch was recovered from a whitespace-collapsed paste.
  Line breaks were restored so that every hunk matches its @@ line counts
  and the diffstat below. Leading indentation is inferred from Python
  block structure and may not match the target file byte-for-byte. If a
  hunk is rejected, retry with `git apply --ignore-whitespace` and check
  the indentation of the added lines by hand.
* Review: the "meta=" field of the log record is filled from
  state.deferred_bytes, while the meta_bytes total computed by
  _summarize_module_bytes() is only used to derive "partial"/"full_load".
  Confirm this is intended.
* Review: _log_materialization() runs _summarize_module_bytes() on every
  call, including once per tensor from load_module_tensor(), even when
  INFO is disabled. Consider guarding with LOGGER.isEnabledFor() or
  logging at DEBUG.

 comfy/disk_weights.py | 71 +++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 68 insertions(+), 3 deletions(-)

diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py
index 588e55518..0bf044df7 100644
--- a/comfy/disk_weights.py
+++ b/comfy/disk_weights.py
@@ -19,6 +19,7 @@
 from __future__ import annotations
 
 import collections
+import logging
 import weakref
 from dataclasses import dataclass, field
 from types import SimpleNamespace
@@ -159,6 +160,7 @@ class DiskWeightCache:
 
 REGISTRY = DiskWeightRegistry()
 CACHE = DiskWeightCache(0)
+LOGGER = logging.getLogger(__name__)
 
 
 def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
@@ -285,6 +287,59 @@ def _rebuild_materialization_state(module: torch.nn.Module, refs: Dict[str, Disk
     _update_disk_state_attrs(module, state)
 
 
+def _summarize_module_bytes(module: torch.nn.Module, refs: Dict[str, DiskTensorRef]):
+    cpu_bytes = 0
+    gpu_bytes = 0
+    meta_bytes = 0
+    total_bytes = 0
+    for name, ref in refs.items():
+        tensor = None
+        if name in module._parameters:
+            tensor = module._parameters[name]
+        elif name in module._buffers:
+            tensor = module._buffers[name]
+        if tensor is None:
+            continue
+        nbytes = _meta_nbytes(ref.meta)
+        if nbytes is None:
+            nbytes = _tensor_nbytes(tensor)
+        total_bytes += nbytes
+        if tensor.device.type == "meta":
+            meta_bytes += nbytes
+        elif tensor.device.type == "cpu":
+            cpu_bytes += nbytes
+        else:
+            gpu_bytes += nbytes
+    return total_bytes, cpu_bytes, gpu_bytes, meta_bytes
+
+
+def _log_materialization(
+    module: torch.nn.Module,
+    target_device: torch.device,
+    free_mem: int,
+    refs: Dict[str, DiskTensorRef],
+    state: DiskMaterializationState,
+    context: str,
+):
+    total_bytes, cpu_bytes, gpu_bytes, meta_bytes = _summarize_module_bytes(module, refs)
+    partial = meta_bytes > 0
+    LOGGER.info(
+        "%s: module=%s dest=%s load=%0.2fMB free=%0.2fMB partial=%s "
+        "loaded=%0.2fMB meta=%0.2fMB cpu=%0.2fMB gpu=%0.2fMB full_load=%s",
+        context,
+        module.__class__.__name__,
+        target_device,
+        total_bytes / (1024 * 1024),
+        free_mem / (1024 * 1024),
+        partial,
+        state.loaded_bytes / (1024 * 1024),
+        state.deferred_bytes / (1024 * 1024),
+        cpu_bytes / (1024 * 1024),
+        gpu_bytes / (1024 * 1024),
+        not partial,
+    )
+
+
 def _device_free_memory(device: torch.device) -> int:
     from . import model_management
     return int(model_management.get_free_memory(device))
@@ -564,7 +619,8 @@ def ensure_module_materialized(
         return
     state = _get_materialization_state(module)
     _rebuild_materialization_state(module, refs, state)
-    remaining_budget = _device_free_memory(target_device)
+    free_mem_start = _device_free_memory(target_device)
+    remaining_budget = free_mem_start
     for name in sorted(refs.keys()):
         disk_ref = refs[name]
         if name in module._parameters:
@@ -611,6 +667,7 @@ def ensure_module_materialized(
             CACHE.record(module, name, tensor, is_buffer=is_buffer)
         remaining_budget = max(0, remaining_budget - required_bytes)
     _rebuild_materialization_state(module, refs, state)
+    _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
 
 
 def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
@@ -716,6 +773,7 @@ def load_module_tensor(
     required_bytes = _meta_nbytes(disk_ref.meta)
     if required_bytes is None:
         return current
+    free_mem_start = _device_free_memory(device)
     free_mem = _maybe_free_ram_budget(device, required_bytes)
     load_device = device
     if free_mem < required_bytes and allow_alternate:
@@ -730,6 +788,7 @@ def load_module_tensor(
                     state.deferred_keys.add(name)
                     state.deferred_bytes += required_bytes
                 _update_disk_state_attrs(module, state)
+                _log_materialization(module, device, free_mem_start, refs, _get_materialization_state(module), "Disk weight deferred")
                 return current
         else:
             state = _get_materialization_state(module)
@@ -737,6 +796,7 @@ def load_module_tensor(
                 state.deferred_keys.add(name)
                 state.deferred_bytes += required_bytes
             _update_disk_state_attrs(module, state)
+            _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
             return current
     elif free_mem < required_bytes:
         state = _get_materialization_state(module)
@@ -744,6 +804,7 @@ def load_module_tensor(
             state.deferred_keys.add(name)
             state.deferred_bytes += required_bytes
         _update_disk_state_attrs(module, state)
+        _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
         return current
 
     tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU)
@@ -755,7 +816,9 @@ def load_module_tensor(
         module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
     if tensor.device.type == "cpu" and record_cache:
         CACHE.record(module, name, tensor, is_buffer=is_buffer)
-    _rebuild_materialization_state(module, refs, _get_materialization_state(module))
+    state = _get_materialization_state(module)
+    _rebuild_materialization_state(module, refs, state)
+    _log_materialization(module, load_device, free_mem_start, refs, state, "Disk weight loaded")
     return tensor
 
 
@@ -790,7 +853,8 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
         key = f"{lazy_state.prefix}{name}"
         if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
             existing[key] = buf
-    remaining_budget = _device_free_memory(target_device)
+    free_mem_start = _device_free_memory(target_device)
+    remaining_budget = free_mem_start
     allowed = set(existing.keys())
     for key in keys:
         if key in allowed:
@@ -840,6 +904,7 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
         raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
     _rebuild_materialization_state(module, refs, state)
     lazy_state.loaded = len(deferred_state_dict_keys) == 0
+    _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed")
     for name, param in module.named_parameters(recurse=False):
         if param.device.type == "cpu":
             CACHE.record(module, name, param, is_buffer=False)