Enable lazy disk-tier loading for streaming weights

ifilipis 2026-01-08 18:40:29 +02:00
parent f925f8fa77
commit 5f2188e31b
8 changed files with 320 additions and 31 deletions

View File

@@ -27,31 +27,31 @@
- `comfy/model_management.py` handles VRAM/RAM offload via `free_memory` and tracks loaded/offloaded memory (needs integration for the RAM→disk tier).【F:comfy/model_management.py†L584-L612】
- `comfy/model_patcher.py` implements module-by-module offload/low-vram weight casting (`comfy_cast_weights`) and partial unload/load (needs to integrate disk tier for RAM eviction).【F:comfy/model_patcher.py†L663-L955】
## Strategy summary (no coding performed yet)
## Strategy summary (implemented)
### Streaming safetensors mapping (no full dict materialization)
- Introduce a new module `comfy/safetensors_stream.py` with:
- `TensorMeta` and `SafeTensorIndex` (metadata-only parsing with `fastsafetensors.SafeTensorsMetadata`).
- `StreamStateDict` as a mapping backed by `SafeTensorIndex`, exposing metadata-only `keys()`/`__iter__` and loading tensors on demand.
- Lightweight mapping views: `PrefixViewStateDict`, `FilterViewStateDict`, `RenameViewStateDict` for lazy prefix/filter/rename without eager loading.
- [x] Introduce a new module `comfy/safetensors_stream.py` with:
- [x] `TensorMeta` and `SafeTensorIndex` (metadata-only parsing with `fastsafetensors.SafeTensorsMetadata`).
- [x] `StreamStateDict` as a mapping backed by `SafeTensorIndex`, exposing metadata-only `keys()`/`__iter__` and loading tensors on demand.
- [x] Lightweight mapping views: `PrefixViewStateDict`, `FilterViewStateDict`, `RenameViewStateDict` for lazy prefix/filter/rename without eager loading.
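A minimal sketch of the idea (illustrative only, not the code in `comfy/safetensors_stream.py`, which builds on `fastsafetensors`): parse the safetensors header once, serve `keys()`/metadata from it, and read a tensor's bytes only on item access. The `LazyStateDict` name and the hand-rolled header parsing are assumptions for the example.

```python
import json
import struct
from collections.abc import Mapping

import torch

_DTYPES = {"F32": torch.float32, "F16": torch.float16, "I64": torch.int64}  # deliberately partial

class LazyStateDict(Mapping):
    """Metadata-only mapping over a .safetensors file; tensor bytes load on access."""

    def __init__(self, path):
        self._path = path
        with open(path, "rb") as f:
            header_len = struct.unpack("<Q", f.read(8))[0]   # 8-byte little-endian header size
            self._meta = json.loads(f.read(header_len))      # name -> {dtype, shape, data_offsets}
        self._meta.pop("__metadata__", None)
        self._data_start = 8 + header_len

    def __iter__(self):      # keys come from the header, nothing else is read
        return iter(self._meta)

    def __len__(self):
        return len(self._meta)

    def __getitem__(self, key):   # only here is the tensor's byte range read
        info = self._meta[key]
        start, end = info["data_offsets"]
        with open(self._path, "rb") as f:
            f.seek(self._data_start + start)
            raw = f.read(end - start)
        return torch.frombuffer(bytearray(raw), dtype=_DTYPES[info["dtype"]]).reshape(info["shape"])
```

The prefix/filter/rename views then wrap such a mapping and rewrite keys lazily in `keys()`/`__getitem__` without ever touching tensor data.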
### Range reads and tiering
- Disk→RAM: use `fastsafetensors.cpp.nogds_file_reader` for range reads and wrap with DLPack.
- Disk→GPU (GDS): use `gds_file_reader` + `gds_file_handle` to read the aligned range directly into GPU memory. If GDS is requested but not supported (e.g., `is_gds_supported==0` or libcufile missing), raise a hard error with instructions to disable GDS.
- Disk→RAM→GPU: read only the tensor range into (optionally pinned) CPU memory, copy to GPU, then release CPU buffer unless RAM cache policy keeps it.
- [x] Disk→RAM: use `fastsafetensors.cpp.nogds_file_reader` for range reads and wrap with DLPack.
- [x] Disk→GPU (GDS): use `gds_file_reader` + `gds_file_handle` to read the aligned range directly into GPU memory. If GDS is requested but not supported (e.g., `is_gds_supported==0` or libcufile missing), raise a hard error with instructions to disable GDS.
- [x] Disk→RAM→GPU: read only the tensor range into (optionally pinned) CPU memory, copy to GPU, then release CPU buffer unless RAM cache policy keeps it.
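A rough illustration of the Disk→RAM(→GPU) path (not the `nogds_file_reader`/DLPack code used here; the `offset`/`nbytes`/`dtype`/`shape` fields on `meta` are assumed for the example):

```python
import torch

def read_tensor_range(path, meta, device, pin_if_cpu=False):
    """Read one tensor's byte range from disk, stage it on CPU, optionally move it to GPU."""
    with open(path, "rb") as f:
        f.seek(meta["offset"])                  # jump straight to this tensor's range
        raw = f.read(meta["nbytes"])            # nothing else in the file is read
    cpu = torch.frombuffer(bytearray(raw), dtype=meta["dtype"]).reshape(meta["shape"])
    if device.type == "cpu":
        return cpu
    if pin_if_cpu:
        cpu = cpu.pin_memory()                  # pinned staging allows an async H2D copy
    return cpu.to(device, non_blocking=True)    # the staging buffer is dropped unless cached
```

The GDS path skips the CPU staging buffer by reading the aligned range directly into GPU memory; when that capability is missing the loader errors out rather than silently falling back.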
### Disk tier integration
- Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`.
- Add an LRU cache for RAM-resident weights loaded from disk with configurable max bytes. Eviction replaces RAM tensors with meta tensors and keeps `DiskRef` for reload.
- Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`.
- [x] Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`.
- [x] Add an LRU cache for RAM-resident weights loaded from disk with configurable max bytes. Eviction replaces RAM tensors with meta tensors and keeps `DiskRef` for reload.
- [x] Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`.
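The mechanism in miniature (a hedged sketch, not the `DiskRef`/LRU implementation in `comfy/disk_weights.py`; `DISK_REFS`, `park_on_disk`, and `load_fn` are illustrative stand-ins):

```python
import torch

DISK_REFS = {}  # (id(module), name) -> (loader, requires_grad); stands in for the DiskRef registry

def park_on_disk(module, name, load_fn):
    """Replace a parameter with a meta tensor and remember how to reload it."""
    param = module._parameters[name]
    DISK_REFS[(id(module), name)] = (load_fn, param.requires_grad)
    module._parameters[name] = torch.nn.Parameter(
        torch.empty_like(param, device="meta"), requires_grad=param.requires_grad)

def materialize_pre_hook(module, args):
    """forward_pre_hook: swap real tensors back in before compute."""
    for name, param in module._parameters.items():
        if param is not None and param.device.type == "meta":
            load_fn, requires_grad = DISK_REFS[(id(module), name)]
            module._parameters[name] = torch.nn.Parameter(load_fn(), requires_grad=requires_grad)

linear = torch.nn.Linear(4, 4)
saved = linear.weight.detach().clone()            # stands in for a disk read
park_on_disk(linear, "weight", lambda: saved)
linear.register_forward_pre_hook(materialize_pre_hook)
out = linear(torch.randn(1, 4))                   # the weight is materialized on first use
```

Eviction is the reverse move: the LRU cache drops the RAM copy and re-parks the parameter as a meta tensor while keeping its `DiskRef` for reload.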
### Pipeline refactors
- Update `load_torch_file` to return `StreamStateDict` for `.safetensors`/`.sft` and return metadata without loading.
- Update helpers (`calculate_parameters`, `weight_dtype`, `state_dict_prefix_replace`) to be metadata-aware and lazy.
- Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead.
- Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads).
- Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader.
- [x] Update `load_torch_file` to return `StreamStateDict` for `.safetensors`/`.sft` and return metadata without loading.
- [x] Update helpers (`calculate_parameters`, `weight_dtype`, `state_dict_prefix_replace`) to be metadata-aware and lazy.
- [x] Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead.
- [x] Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads).
- [x] Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader.
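These helpers only need per-key shape/dtype, which the streaming mapping serves from the header. Assuming the `.meta(key)` accessor exposed by the streaming state dicts, a metadata-only parameter count looks roughly like:

```python
import math

def calculate_parameters(state_dict, prefix=""):
    """Sum parameter counts from header metadata; no tensor bytes are read."""
    return sum(math.prod(state_dict.meta(k).shape)
               for k in state_dict.keys() if k.startswith(prefix))
```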
### Tests and docs
- Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and GDS failure path.
- Document new flags for RAM cache size and GPUDirect enablement and how to disable GDS when unsupported.
- [x] Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and GDS failure path.
- [x] Document new flags for RAM cache size and GPUDirect enablement and how to disable GDS when unsupported.
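Usage sketch of the runtime flags, based on the `configure` call exercised by the tests at the end of this diff (the 2 GiB budget is an arbitrary example, not a default):

```python
import comfy.disk_weights

comfy.disk_weights.configure(
    2 * 1024 ** 3,      # RAM-cache budget in bytes (the tests pass 0 to disable the cache)
    allow_gds=False,    # keep GPUDirect Storage off when libcufile/GDS is unavailable
    pin_if_cpu=True,    # pin staged CPU buffers for faster host-to-device copies
    enabled=True,       # turn the lazy disk tier on
)
```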

View File

@@ -25,6 +25,7 @@ import logging
import comfy.utils
import comfy.model_management
import comfy.model_detection
import comfy.disk_weights
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats
@@ -385,7 +386,7 @@ class ControlLora(ControlNet):
        controlnet_config["operations"] = control_lora_ops
        controlnet_config["dtype"] = dtype
        self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
        self.control_model.to(comfy.model_management.get_torch_device())
        comfy.disk_weights.module_to(self.control_model, comfy.model_management.get_torch_device())

        diffusion_model = model.diffusion_model
        sd = diffusion_model.state_dict()
@@ -816,8 +817,8 @@ class T2IAdapter(ControlBase):
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

        if self.control_input is None:
            self.t2i_model.to(x_noisy.dtype)
            self.t2i_model.to(self.device)
            comfy.disk_weights.module_to(self.t2i_model, dtype=x_noisy.dtype)
            comfy.disk_weights.module_to(self.t2i_model, self.device)
            self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
            self.t2i_model.cpu()

View File

@@ -21,14 +21,18 @@ from __future__ import annotations
import collections
import weakref
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Dict, MutableMapping, Optional
import torch
from . import safetensors_stream
ALLOW_GDS = False
PIN_IF_CPU = False
DISK_WEIGHTS_ENABLED = False
BASE_LOAD_FROM_STATE_DICT = torch.nn.Module._load_from_state_dict
LAZY_MODULE_STATE = weakref.WeakKeyDictionary()
@dataclass
@@ -123,6 +127,15 @@ class DiskWeightCache:
            _evict_module_weight(module, entry.name, entry.is_buffer)
        return freed

    def remove_module(self, module: torch.nn.Module):
        to_remove = []
        for key, entry in self._entries.items():
            if entry.module_ref() is module:
                to_remove.append(key)
        for key in to_remove:
            entry = self._entries.pop(key)
            self.current_bytes -= entry.size_bytes

    def _drop_module_entries(self, module_ref: weakref.ReferenceType):
        to_remove = []
        for key, entry in self._entries.items():
@@ -183,7 +196,61 @@ def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = "
        CACHE.record(module, name, buf, is_buffer=True)

@dataclass
class LazyModuleState:
    state_dict: MutableMapping
    prefix: str
    loaded: bool = False

def _has_custom_load(module: torch.nn.Module) -> bool:
    return module.__class__._load_from_state_dict is not BASE_LOAD_FROM_STATE_DICT

def register_lazy_modules(model: torch.nn.Module, state_dict):
    if not hasattr(state_dict, "keys"):
        return
    for name, module in model.named_modules():
        if not _has_custom_load(module):
            continue
        prefix = f"{name}." if name else ""
        if prefix:
            has_key = False
            for param_name in module._parameters.keys():
                if f"{prefix}{param_name}" in state_dict:
                    has_key = True
                    break
            if not has_key:
                for buf_name in module._buffers.keys():
                    if f"{prefix}{buf_name}" in state_dict:
                        has_key = True
                        break
            if not has_key:
                continue
        view = safetensors_stream.FilterViewStateDict(
            state_dict, lambda k, p=prefix: k.startswith(p), mutate_base=False
        )
        LAZY_MODULE_STATE[module] = LazyModuleState(state_dict=view, prefix=prefix)

def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
    lazy_state = LAZY_MODULE_STATE.get(module)
    if lazy_state is not None:
        CACHE.remove_module(module)
        refs = REGISTRY.get(module)
        if refs:
            for ref_name, disk_ref in refs.items():
                shape = getattr(disk_ref.meta, "shape", None)
                dtype = getattr(disk_ref.meta, "dtype", None)
                if shape is None or dtype is None:
                    continue
                meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
                if disk_ref.is_buffer:
                    module._buffers[ref_name] = meta_tensor
                else:
                    module._parameters[ref_name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
        lazy_state.loaded = False
        return
    ref = REGISTRY.get(module)
    if not ref or name not in ref:
        return
@@ -222,6 +289,10 @@ def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
def ensure_module_materialized(module: torch.nn.Module, target_device: torch.device):
    lazy_state = LAZY_MODULE_STATE.get(module)
    if lazy_state is not None and not lazy_state.loaded:
        _materialize_module_from_state_dict(module, lazy_state, target_device)
        return
    refs = REGISTRY.get(module)
    if not refs:
        return
@@ -236,11 +307,14 @@ def ensure_module_materialized(module: torch.nn.Module, target_device: torch.dev
            continue
        if current is None:
            continue
        if current.device.type != "meta":
        if current.device.type == "meta":
            tensor = disk_ref.load(target_device, ALLOW_GDS, PIN_IF_CPU)
        elif current.device != target_device:
            tensor = current.to(device=target_device)
        else:
            if current.device.type == "cpu":
                CACHE.touch(module, name)
            continue
        tensor = disk_ref.load(target_device, ALLOW_GDS, PIN_IF_CPU)
        if is_buffer:
            module._buffers[name] = tensor
        else:
@@ -273,3 +347,138 @@ def evict_ram_cache(bytes_to_free: int):
    if bytes_to_free <= 0:
        return 0
    return CACHE.evict_bytes(bytes_to_free)

def materialize_module_tree(module: torch.nn.Module, target_device: torch.device):
    if not disk_weights_enabled():
        return
    for submodule in module.modules():
        ensure_module_materialized(submodule, target_device)

def _extract_to_device(args, kwargs) -> Optional[torch.device]:
    if "device" in kwargs and kwargs["device"] is not None:
        return torch.device(kwargs["device"])
    for arg in args:
        if isinstance(arg, torch.device):
            return arg
        if isinstance(arg, str):
            return torch.device(arg)
    return None

def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
    for param in module.parameters(recurse=True):
        if param is not None and param.device.type != "meta":
            return param.device
    for buf in module.buffers(recurse=True):
        if buf is not None and buf.device.type != "meta":
            return buf.device
    return None

def module_to(module: torch.nn.Module, *args, **kwargs):
    if disk_weights_enabled():
        target_device = _extract_to_device(args, kwargs)
        if target_device is None:
            target_device = _find_existing_device(module) or torch.device("cpu")
        materialize_module_tree(module, target_device)
    return module.to(*args, **kwargs)

def _replace_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool, requires_grad: bool):
    parts = name.split(".")
    module = model
    for part in parts[:-1]:
        module = getattr(module, part)
    attr = parts[-1]
    if is_buffer:
        module._buffers[attr] = tensor
    else:
        module._parameters[attr] = torch.nn.Parameter(tensor, requires_grad=requires_grad)

def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: LazyModuleState, target_device: torch.device):
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    metadata = getattr(lazy_state.state_dict, "_metadata", None)
    local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
    state_dict = safetensors_stream.DeviceViewStateDict(
        lazy_state.state_dict,
        device=target_device,
        allow_gds=ALLOW_GDS,
        pin_if_cpu=PIN_IF_CPU,
        mutate_base=False,
    )
    factory_device = None
    if hasattr(module, "factory_kwargs") and "device" in module.factory_kwargs:
        factory_device = module.factory_kwargs["device"]
        module.factory_kwargs["device"] = target_device
    try:
        module._load_from_state_dict(
            state_dict,
            lazy_state.prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
        incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
        for hook in module._load_state_dict_post_hooks.values():
            out = hook(module, incompatible)
            if out is not None:
                raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
    finally:
        if factory_device is not None:
            module.factory_kwargs["device"] = factory_device
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
    lazy_state.loaded = True
    for name, param in module.named_parameters(recurse=False):
        if param.device.type == "cpu":
            CACHE.record(module, name, param, is_buffer=False)
    for name, buf in module.named_buffers(recurse=False):
        if buf is not None and buf.device.type == "cpu":
            CACHE.record(module, name, buf, is_buffer=True)

def lazy_load_state_dict(model: torch.nn.Module, state_dict, strict: bool = False):
    model_keys = set()
    for name, _ in model.named_parameters(recurse=True):
        model_keys.add(name)
    for name, _ in model.named_buffers(recurse=True):
        model_keys.add(name)
    state_keys = set(state_dict.keys())
    missing_keys = [k for k in model_keys if k not in state_keys]
    unexpected_keys = [k for k in state_keys if k not in model_keys]
    if strict:
        error_msgs = []
        if len(unexpected_keys) > 0:
            error_msgs.append('Unexpected key(s) in state_dict: {}.'.format(', '.join(f'"{k}"' for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.append('Missing key(s) in state_dict: {}.'.format(', '.join(f'"{k}"' for k in missing_keys)))
        if error_msgs:
            raise RuntimeError("Error(s) in loading state_dict:\n\t{}".format("\n\t".join(error_msgs)))
    for name, param in model.named_parameters(recurse=True):
        if name not in state_keys:
            continue
        meta = state_dict.meta(name)
        meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
        _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
    for name, buf in model.named_buffers(recurse=True):
        if buf is None or name not in state_keys:
            continue
        meta = state_dict.meta(name)
        meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
        _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
    register_module_weights(model, state_dict)
    register_lazy_modules(model, state_dict)
    attach_disk_weight_hooks(model)
    return missing_keys, unexpected_keys

View File

@@ -785,7 +785,7 @@ class ModelPatcher:
                m.comfy_patched_weights = True

            for x in load_completely:
                x[2].to(device_to)
                comfy.disk_weights.module_to(x[2], device_to)

            for x in offloaded:
                n = x[1]
@@ -800,7 +800,7 @@
                logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
                self.model.model_lowvram = False
                if full_load:
                    self.model.to(device_to)
                    comfy.disk_weights.module_to(self.model, device_to)
                    mem_counter = self.model_size()

            self.model.lowvram_patch_counter += patch_counter
@@ -857,7 +857,7 @@
                self.backup.clear()

            if device_to is not None:
                self.model.to(device_to)
                comfy.disk_weights.module_to(self.model, device_to)
                self.model.device = device_to
            self.model.model_loaded_weight_memory = 0
            self.model.model_offload_buffer_memory = 0

View File

@@ -442,6 +442,12 @@ class StreamStateDict(collections.abc.MutableMapping):
            raise KeyError(key)
        if device is None:
            device = self._device
        if device.type == "meta":
            meta = self._index.meta(key)
            target_dtype = dtype or meta.dtype
            if dtype is not None and dtype != meta.dtype:
                _validate_dtype_conversion(meta.dtype, dtype)
            return torch.empty(meta.shape, dtype=target_dtype, device="meta")
        if allow_gds is None:
            allow_gds = self._allow_gds
        meta = self._index.meta(key)
@@ -559,6 +565,37 @@ class _BaseViewStateDict(MutableMapping):
            t = t.to(dtype=dtype)
        return t

class DeviceViewStateDict(_BaseViewStateDict):
    def __init__(
        self,
        base: MutableMapping,
        device: torch.device,
        allow_gds: Optional[bool] = None,
        pin_if_cpu: bool = False,
        mutate_base: bool = False,
    ):
        super().__init__(base, mutate_base=mutate_base)
        self._device = device
        self._allow_gds = allow_gds
        self._pin_if_cpu = pin_if_cpu

    def get_tensor(
        self,
        key: str,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        allow_gds: Optional[bool] = None,
        pin_if_cpu: bool = False,
    ) -> torch.Tensor:
        device = self._device if device is None else device
        allow_gds = self._allow_gds if allow_gds is None else allow_gds
        pin_if_cpu = self._pin_if_cpu if not pin_if_cpu else pin_if_cpu
        return super().get_tensor(
            key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
        )

    def meta(self, key: str):
        if key in self._overrides:
            t = self._overrides[key]

View File

@@ -26,6 +26,7 @@ import os
import comfy.utils
import comfy.safetensors_stream
import comfy.disk_weights
from . import clip_vision
from . import gligen
@@ -125,7 +126,7 @@ class CLIP:
            if not model_management.supports_cast(load_device, dt):
                load_device = offload_device
                if params['device'] != offload_device:
                    self.cond_stage_model.to(offload_device)
                    comfy.disk_weights.module_to(self.cond_stage_model, offload_device)
                    logging.warning("Had to shift TE back.")

        self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
@@ -671,7 +672,7 @@ class VAE:
        if dtype is None:
            dtype = model_management.vae_dtype(self.device, self.working_dtypes)
        self.vae_dtype = dtype
        self.first_stage_model.to(self.vae_dtype)
        comfy.disk_weights.module_to(self.first_stage_model, dtype=self.vae_dtype)
        self.output_device = model_management.intermediate_device()

        self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@@ -1546,7 +1547,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
        model_config.optimizations["fp8"] = True

    model = model_config.get_model(new_sd, "")
    model = model.to(offload_device)
    model = comfy.disk_weights.module_to(model, offload_device)
    model.load_model_weights(new_sd, "")
    left_over = sd.keys()
    if len(left_over) > 0:

View File

@@ -168,6 +168,8 @@ def state_dict_meta(state_dict, key):
def load_state_dict(model, state_dict, strict=False, assign=False):
    if is_stream_state_dict(state_dict):
        if comfy.disk_weights.disk_weights_enabled():
            return comfy.disk_weights.lazy_load_state_dict(model, state_dict, strict=strict)
        comfy.disk_weights.register_module_weights(model, state_dict)
        comfy.disk_weights.attach_disk_weight_hooks(model)
        missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign)
@@ -900,7 +902,10 @@ def copy_to_param(obj, attr, value):
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    prev.data.copy_(value)
    if prev.device.type == "meta":
        setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=prev.requires_grad))
    else:
        prev.data.copy_(value)

def get_attr(obj, attr: str):
    """Retrieves a nested attribute from an object using dot notation.

View File

@@ -144,3 +144,39 @@ def test_stream_load_without_disk_cache_keeps_cpu_weights(tmp_path):
        assert model.weight.device.type != "meta"
    finally:
        comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)

def test_lazy_disk_weights_loads_on_demand(tmp_path, monkeypatch):
    if importlib.util.find_spec("fastsafetensors") is None:
        pytest.skip("fastsafetensors not installed")
    import comfy.utils
    import comfy.disk_weights

    prev_cache = comfy.disk_weights.CACHE.max_bytes
    prev_gds = comfy.disk_weights.ALLOW_GDS
    prev_pin = comfy.disk_weights.PIN_IF_CPU
    prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
    comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
    try:
        path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
        sd = comfy.utils.load_torch_file(path, safe_load=True)
        model = torch.nn.Linear(4, 4, bias=True)

        calls = []
        original = sd._file.read_tensor

        def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
            calls.append(meta)
            return original(meta, device, dtype, allow_gds, pin_if_cpu)

        monkeypatch.setattr(sd._file, "read_tensor", wrapped)

        comfy.utils.load_state_dict(model, sd, strict=True)
        assert model.weight.device.type == "meta"
        assert calls == []

        comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"))
        assert model.weight.device.type == "cpu"
        assert len(calls) == 2
    finally:
        comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)