Enable lazy disk-tier loading for streaming weights

ifilipis 2026-01-08 18:40:29 +02:00
parent f925f8fa77
commit 5f2188e31b
8 changed files with 320 additions and 31 deletions

View File

@@ -27,31 +27,31 @@
 - `comfy/model_management.py` handles VRAM/RAM offload via `free_memory` and keeps tracking of loaded/offloaded memory (needs integration for RAM disk tier).【F:comfy/model_management.py†L584-L612】
 - `comfy/model_patcher.py` implements module-by-module offload/low-vram weight casting (`comfy_cast_weights`) and partial unload/load (needs to integrate disk tier for RAM eviction).【F:comfy/model_patcher.py†L663-L955】
-## Strategy summary (no coding performed yet)
+## Strategy summary (implemented)
 ### Streaming safetensors mapping (no full dict materialization)
-- Introduce a new module `comfy/safetensors_stream.py` with:
-  - `TensorMeta` and `SafeTensorIndex` (metadata-only parsing with `fastsafetensors.SafeTensorsMetadata`).
-  - `StreamStateDict` as a mapping backed by `SafeTensorIndex`, exposing metadata-only `keys()`/`__iter__` and loading tensors on demand.
-  - Lightweight mapping views: `PrefixViewStateDict`, `FilterViewStateDict`, `RenameViewStateDict` for lazy prefix/filter/rename without eager loading.
+- [x] Introduce a new module `comfy/safetensors_stream.py` with:
+  - [x] `TensorMeta` and `SafeTensorIndex` (metadata-only parsing with `fastsafetensors.SafeTensorsMetadata`).
+  - [x] `StreamStateDict` as a mapping backed by `SafeTensorIndex`, exposing metadata-only `keys()`/`__iter__` and loading tensors on demand.
+  - [x] Lightweight mapping views: `PrefixViewStateDict`, `FilterViewStateDict`, `RenameViewStateDict` for lazy prefix/filter/rename without eager loading.
 ### Range reads and tiering
-- Disk→RAM: use `fastsafetensors.cpp.nogds_file_reader` for range reads and wrap with DLPack.
-- Disk→GPU (GDS): use `gds_file_reader` + `gds_file_handle` to read the aligned range directly into GPU memory. If GDS is requested but not supported (e.g., `is_gds_supported==0` or libcufile missing), raise a hard error with instructions to disable GDS.
-- Disk→RAM→GPU: read only the tensor range into (optionally pinned) CPU memory, copy to GPU, then release CPU buffer unless RAM cache policy keeps it.
+- [x] Disk→RAM: use `fastsafetensors.cpp.nogds_file_reader` for range reads and wrap with DLPack.
+- [x] Disk→GPU (GDS): use `gds_file_reader` + `gds_file_handle` to read the aligned range directly into GPU memory. If GDS is requested but not supported (e.g., `is_gds_supported==0` or libcufile missing), raise a hard error with instructions to disable GDS.
+- [x] Disk→RAM→GPU: read only the tensor range into (optionally pinned) CPU memory, copy to GPU, then release CPU buffer unless RAM cache policy keeps it.
 ### Disk tier integration
-- Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`.
-- Add an LRU cache for RAM-resident weights loaded from disk with configurable max bytes. Eviction replaces RAM tensors with meta tensors and keeps `DiskRef` for reload.
-- Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`.
+- [x] Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`.
+- [x] Add an LRU cache for RAM-resident weights loaded from disk with configurable max bytes. Eviction replaces RAM tensors with meta tensors and keeps `DiskRef` for reload.
+- [x] Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`.
 ### Pipeline refactors
-- Update `load_torch_file` to return `StreamStateDict` for `.safetensors`/`.sft` and return metadata without loading.
-- Update helpers (`calculate_parameters`, `weight_dtype`, `state_dict_prefix_replace`) to be metadata-aware and lazy.
-- Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead.
-- Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads).
-- Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader.
+- [x] Update `load_torch_file` to return `StreamStateDict` for `.safetensors`/`.sft` and return metadata without loading.
+- [x] Update helpers (`calculate_parameters`, `weight_dtype`, `state_dict_prefix_replace`) to be metadata-aware and lazy.
+- [x] Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead.
+- [x] Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads).
+- [x] Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader.
 ### Tests and docs
-- Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and GDS failure path.
-- Document new flags for RAM cache size and GPUDirect enablement and how to disable GDS when unsupported.
+- [x] Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and GDS failure path.
+- [x] Document new flags for RAM cache size and GPUDirect enablement and how to disable GDS when unsupported.
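
To make the "metadata-only parsing, load on demand" idea above concrete, here is a minimal self-contained sketch of a lazy safetensors mapping. It is illustrative only: the names (`LazySafetensors`, `_DTYPES`) are invented, it reads ranges with plain file seeks rather than the `fastsafetensors` readers this commit actually uses, and it covers only a few dtypes.

```python
import json
import struct
from collections.abc import Mapping

import torch

# Abbreviated dtype table; the real format defines more codes (BF16, F64, ...).
_DTYPES = {"F32": torch.float32, "F16": torch.float16,
           "I64": torch.int64, "I32": torch.int32, "U8": torch.uint8}

class LazySafetensors(Mapping):
    """Metadata-first view of a .safetensors file: keys()/len() touch only the
    JSON header; tensor bytes are read per key, on demand."""

    def __init__(self, path):
        self.path = path
        with open(path, "rb") as f:
            header_len = struct.unpack("<Q", f.read(8))[0]  # 8-byte little-endian header size
            header = json.loads(f.read(header_len))
        header.pop("__metadata__", None)
        self._meta = header                 # name -> {"dtype", "shape", "data_offsets"}
        self._data_start = 8 + header_len   # tensor payload begins here

    def __iter__(self):                     # metadata only, no tensor reads
        return iter(self._meta)

    def __len__(self):
        return len(self._meta)

    def __getitem__(self, key):             # single-tensor range read
        info = self._meta[key]
        begin, end = info["data_offsets"]
        with open(self.path, "rb") as f:
            f.seek(self._data_start + begin)
            payload = bytearray(f.read(end - begin))
        t = torch.frombuffer(payload, dtype=_DTYPES[info["dtype"]])
        return t.reshape(info["shape"])
```

The commit's `StreamStateDict` plays the same role but delegates the byte ranges to `fastsafetensors` (with optional GDS) and layers the prefix/filter/rename views on top, so downstream code can keep treating checkpoints as ordinary dict-like state dicts.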

View File

@@ -25,6 +25,7 @@ import logging
 import comfy.utils
 import comfy.model_management
 import comfy.model_detection
+import comfy.disk_weights
 import comfy.model_patcher
 import comfy.ops
 import comfy.latent_formats
@@ -385,7 +386,7 @@ class ControlLora(ControlNet):
         controlnet_config["operations"] = control_lora_ops
         controlnet_config["dtype"] = dtype
         self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
-        self.control_model.to(comfy.model_management.get_torch_device())
+        comfy.disk_weights.module_to(self.control_model, comfy.model_management.get_torch_device())

         diffusion_model = model.diffusion_model
         sd = diffusion_model.state_dict()
@@ -816,8 +817,8 @@ class T2IAdapter(ControlBase):
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
             self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

         if self.control_input is None:
-            self.t2i_model.to(x_noisy.dtype)
-            self.t2i_model.to(self.device)
+            comfy.disk_weights.module_to(self.t2i_model, dtype=x_noisy.dtype)
+            comfy.disk_weights.module_to(self.t2i_model, self.device)
             self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
             self.t2i_model.cpu()
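
The substitution of `comfy.disk_weights.module_to(...)` for plain `.to(...)` throughout this file (and below in `model_patcher.py` and `sd.py`) exists because disk-resident weights are represented as meta tensors. The toy snippet below, plain PyTorch rather than ComfyUI code, shows why `.to()` alone cannot bring such weights onto a device; the exact error text may vary by PyTorch version.

```python
import torch

layer = torch.nn.Linear(4, 4, device="meta")   # parameters have shape/dtype but no storage
try:
    layer.to("cpu")                            # there is no data for .to() to copy
except NotImplementedError as err:
    print("plain .to() cannot materialize meta weights:", err)

# torch's built-in escape hatch allocates uninitialized memory, which is useless
# for pretrained weights:
layer = layer.to_empty(device="cpu")

# module_to() instead first loads the registered disk-backed tensors onto the
# target device (see comfy/disk_weights.py below) and only then calls .to().
```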

View File

@@ -21,14 +21,18 @@ from __future__ import annotations
 import collections
 import weakref
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Dict, MutableMapping, Optional

 import torch

+from . import safetensors_stream
+
 ALLOW_GDS = False
 PIN_IF_CPU = False
 DISK_WEIGHTS_ENABLED = False
+BASE_LOAD_FROM_STATE_DICT = torch.nn.Module._load_from_state_dict
+LAZY_MODULE_STATE = weakref.WeakKeyDictionary()

 @dataclass
@@ -123,6 +127,15 @@ class DiskWeightCache:
                 _evict_module_weight(module, entry.name, entry.is_buffer)
         return freed

+    def remove_module(self, module: torch.nn.Module):
+        to_remove = []
+        for key, entry in self._entries.items():
+            if entry.module_ref() is module:
+                to_remove.append(key)
+        for key in to_remove:
+            entry = self._entries.pop(key)
+            self.current_bytes -= entry.size_bytes
+
     def _drop_module_entries(self, module_ref: weakref.ReferenceType):
         to_remove = []
         for key, entry in self._entries.items():
@@ -183,7 +196,61 @@ def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = "
             CACHE.record(module, name, buf, is_buffer=True)

+
+@dataclass
+class LazyModuleState:
+    state_dict: MutableMapping
+    prefix: str
+    loaded: bool = False
+
+
+def _has_custom_load(module: torch.nn.Module) -> bool:
+    return module.__class__._load_from_state_dict is not BASE_LOAD_FROM_STATE_DICT
+
+
+def register_lazy_modules(model: torch.nn.Module, state_dict):
+    if not hasattr(state_dict, "keys"):
+        return
+    for name, module in model.named_modules():
+        if not _has_custom_load(module):
+            continue
+        prefix = f"{name}." if name else ""
+        if prefix:
+            has_key = False
+            for param_name in module._parameters.keys():
+                if f"{prefix}{param_name}" in state_dict:
+                    has_key = True
+                    break
+            if not has_key:
+                for buf_name in module._buffers.keys():
+                    if f"{prefix}{buf_name}" in state_dict:
+                        has_key = True
+                        break
+            if not has_key:
+                continue
+        view = safetensors_stream.FilterViewStateDict(
+            state_dict, lambda k, p=prefix: k.startswith(p), mutate_base=False
+        )
+        LAZY_MODULE_STATE[module] = LazyModuleState(state_dict=view, prefix=prefix)
+
+
 def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
+    lazy_state = LAZY_MODULE_STATE.get(module)
+    if lazy_state is not None:
+        CACHE.remove_module(module)
+        refs = REGISTRY.get(module)
+        if refs:
+            for ref_name, disk_ref in refs.items():
+                shape = getattr(disk_ref.meta, "shape", None)
+                dtype = getattr(disk_ref.meta, "dtype", None)
+                if shape is None or dtype is None:
+                    continue
+                meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
+                if disk_ref.is_buffer:
+                    module._buffers[ref_name] = meta_tensor
+                else:
+                    module._parameters[ref_name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+        lazy_state.loaded = False
+        return
     ref = REGISTRY.get(module)
     if not ref or name not in ref:
         return
@@ -222,6 +289,10 @@ def _find_tensor_device(args, kwargs) -> Optional[torch.device]:

 def ensure_module_materialized(module: torch.nn.Module, target_device: torch.device):
+    lazy_state = LAZY_MODULE_STATE.get(module)
+    if lazy_state is not None and not lazy_state.loaded:
+        _materialize_module_from_state_dict(module, lazy_state, target_device)
+        return
     refs = REGISTRY.get(module)
     if not refs:
         return
@@ -236,11 +307,14 @@ def ensure_module_materialized(module: torch.nn.Module, target_device: torch.device):
             continue
         if current is None:
             continue
-        if current.device.type != "meta":
+        if current.device.type == "meta":
+            tensor = disk_ref.load(target_device, ALLOW_GDS, PIN_IF_CPU)
+        elif current.device != target_device:
+            tensor = current.to(device=target_device)
+        else:
             if current.device.type == "cpu":
                 CACHE.touch(module, name)
             continue
-        tensor = disk_ref.load(target_device, ALLOW_GDS, PIN_IF_CPU)
         if is_buffer:
             module._buffers[name] = tensor
         else:
@@ -273,3 +347,138 @@ def evict_ram_cache(bytes_to_free: int):
     if bytes_to_free <= 0:
         return 0
     return CACHE.evict_bytes(bytes_to_free)
+
+
+def materialize_module_tree(module: torch.nn.Module, target_device: torch.device):
+    if not disk_weights_enabled():
+        return
+    for submodule in module.modules():
+        ensure_module_materialized(submodule, target_device)
+
+
+def _extract_to_device(args, kwargs) -> Optional[torch.device]:
+    if "device" in kwargs and kwargs["device"] is not None:
+        return torch.device(kwargs["device"])
+    for arg in args:
+        if isinstance(arg, torch.device):
+            return arg
+        if isinstance(arg, str):
+            return torch.device(arg)
+    return None
+
+
+def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
+    for param in module.parameters(recurse=True):
+        if param is not None and param.device.type != "meta":
+            return param.device
+    for buf in module.buffers(recurse=True):
+        if buf is not None and buf.device.type != "meta":
+            return buf.device
+    return None
+
+
+def module_to(module: torch.nn.Module, *args, **kwargs):
+    if disk_weights_enabled():
+        target_device = _extract_to_device(args, kwargs)
+        if target_device is None:
+            target_device = _find_existing_device(module) or torch.device("cpu")
+        materialize_module_tree(module, target_device)
+    return module.to(*args, **kwargs)
+
+
+def _replace_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool, requires_grad: bool):
+    parts = name.split(".")
+    module = model
+    for part in parts[:-1]:
+        module = getattr(module, part)
+    attr = parts[-1]
+    if is_buffer:
+        module._buffers[attr] = tensor
+    else:
+        module._parameters[attr] = torch.nn.Parameter(tensor, requires_grad=requires_grad)
+
+
+def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: LazyModuleState, target_device: torch.device):
+    missing_keys = []
+    unexpected_keys = []
+    error_msgs = []
+    metadata = getattr(lazy_state.state_dict, "_metadata", None)
+    local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
+    state_dict = safetensors_stream.DeviceViewStateDict(
+        lazy_state.state_dict,
+        device=target_device,
+        allow_gds=ALLOW_GDS,
+        pin_if_cpu=PIN_IF_CPU,
+        mutate_base=False,
+    )
+    factory_device = None
+    if hasattr(module, "factory_kwargs") and "device" in module.factory_kwargs:
+        factory_device = module.factory_kwargs["device"]
+        module.factory_kwargs["device"] = target_device
+    try:
+        module._load_from_state_dict(
+            state_dict,
+            lazy_state.prefix,
+            local_metadata,
+            False,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+        incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
+        for hook in module._load_state_dict_post_hooks.values():
+            out = hook(module, incompatible)
+            if out is not None:
+                raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
+    finally:
+        if factory_device is not None:
+            module.factory_kwargs["device"] = factory_device
+    if len(error_msgs) > 0:
+        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
+    lazy_state.loaded = True
+    for name, param in module.named_parameters(recurse=False):
+        if param.device.type == "cpu":
+            CACHE.record(module, name, param, is_buffer=False)
+    for name, buf in module.named_buffers(recurse=False):
+        if buf is not None and buf.device.type == "cpu":
+            CACHE.record(module, name, buf, is_buffer=True)
+
+
+def lazy_load_state_dict(model: torch.nn.Module, state_dict, strict: bool = False):
+    model_keys = set()
+    for name, _ in model.named_parameters(recurse=True):
+        model_keys.add(name)
+    for name, _ in model.named_buffers(recurse=True):
+        model_keys.add(name)
+    state_keys = set(state_dict.keys())
+    missing_keys = [k for k in model_keys if k not in state_keys]
+    unexpected_keys = [k for k in state_keys if k not in model_keys]
+    if strict:
+        error_msgs = []
+        if len(unexpected_keys) > 0:
+            error_msgs.append('Unexpected key(s) in state_dict: {}.'.format(', '.join(f'"{k}"' for k in unexpected_keys)))
+        if len(missing_keys) > 0:
+            error_msgs.append('Missing key(s) in state_dict: {}.'.format(', '.join(f'"{k}"' for k in missing_keys)))
+        if error_msgs:
+            raise RuntimeError("Error(s) in loading state_dict:\n\t{}".format("\n\t".join(error_msgs)))
+    for name, param in model.named_parameters(recurse=True):
+        if name not in state_keys:
+            continue
+        meta = state_dict.meta(name)
+        meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
+        _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
+    for name, buf in model.named_buffers(recurse=True):
+        if buf is None or name not in state_keys:
+            continue
+        meta = state_dict.meta(name)
+        meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
+        _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
+    register_module_weights(model, state_dict)
+    register_lazy_modules(model, state_dict)
+    attach_disk_weight_hooks(model)
+    return missing_keys, unexpected_keys
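
Putting the pieces of `comfy/disk_weights.py` together, the sketch below walks through the intended lazy flow end to end. It only uses calls that appear elsewhere in this diff (`configure`, `load_torch_file`, `load_state_dict`, `module_to`, `evict_ram_cache`); the checkpoint path and `build_model()` are placeholders, and default flag values may differ from what is shown.

```python
import torch
import comfy.utils
import comfy.disk_weights

# Opt in to the disk tier: 4 GiB RAM cache, no GPUDirect Storage, no pinned staging.
comfy.disk_weights.configure(4 * 1024**3, allow_gds=False, pin_if_cpu=False, enabled=True)

model = build_model()  # placeholder for whatever constructs the nn.Module

# StreamStateDict: only the safetensors header is parsed, no tensor bytes are read.
sd = comfy.utils.load_torch_file("checkpoint.safetensors", safe_load=True)

# Parameters are swapped for meta tensors backed by DiskRef entries; still no reads.
missing, unexpected = comfy.utils.load_state_dict(model, sd, strict=False)

# Weights stream in when a module is moved or first runs forward:
comfy.disk_weights.module_to(model, torch.device("cpu"))   # materialize + move
# (modules that are never moved explicitly are covered by the forward_pre_hooks
# attached during load_state_dict)

# Under memory pressure, RAM-resident copies collapse back to meta tensors and
# will be re-read from disk on the next use.
comfy.disk_weights.evict_ram_cache(512 * 1024**2)
```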

View File

@@ -785,7 +785,7 @@ class ModelPatcher:
                 m.comfy_patched_weights = True

         for x in load_completely:
-            x[2].to(device_to)
+            comfy.disk_weights.module_to(x[2], device_to)

         for x in offloaded:
             n = x[1]
@@ -800,7 +800,7 @@ class ModelPatcher:
             logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
             self.model.model_lowvram = False
             if full_load:
-                self.model.to(device_to)
+                comfy.disk_weights.module_to(self.model, device_to)
                 mem_counter = self.model_size()

         self.model.lowvram_patch_counter += patch_counter
@@ -857,7 +857,7 @@ class ModelPatcher:
         self.backup.clear()

         if device_to is not None:
-            self.model.to(device_to)
+            comfy.disk_weights.module_to(self.model, device_to)
             self.model.device = device_to
         self.model.model_loaded_weight_memory = 0
         self.model.model_offload_buffer_memory = 0

View File

@@ -442,6 +442,12 @@ class StreamStateDict(collections.abc.MutableMapping):
             raise KeyError(key)
         if device is None:
             device = self._device
+        if device.type == "meta":
+            meta = self._index.meta(key)
+            target_dtype = dtype or meta.dtype
+            if dtype is not None and dtype != meta.dtype:
+                _validate_dtype_conversion(meta.dtype, dtype)
+            return torch.empty(meta.shape, dtype=target_dtype, device="meta")
         if allow_gds is None:
             allow_gds = self._allow_gds
         meta = self._index.meta(key)
@@ -559,6 +565,37 @@ class _BaseViewStateDict(MutableMapping):
             t = t.to(dtype=dtype)
         return t

+
+class DeviceViewStateDict(_BaseViewStateDict):
+    def __init__(
+        self,
+        base: MutableMapping,
+        device: torch.device,
+        allow_gds: Optional[bool] = None,
+        pin_if_cpu: bool = False,
+        mutate_base: bool = False,
+    ):
+        super().__init__(base, mutate_base=mutate_base)
+        self._device = device
+        self._allow_gds = allow_gds
+        self._pin_if_cpu = pin_if_cpu
+
+    def get_tensor(
+        self,
+        key: str,
+        *,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        allow_gds: Optional[bool] = None,
+        pin_if_cpu: bool = False,
+    ) -> torch.Tensor:
+        device = self._device if device is None else device
+        allow_gds = self._allow_gds if allow_gds is None else allow_gds
+        pin_if_cpu = self._pin_if_cpu if not pin_if_cpu else pin_if_cpu
+        return super().get_tensor(
+            key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
+        )
+
     def meta(self, key: str):
         if key in self._overrides:
             t = self._overrides[key]
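
`DeviceViewStateDict` exists so that `_materialize_module_from_state_dict` (in `comfy/disk_weights.py` above) can hand a plain mapping to `Module._load_from_state_dict` while still controlling where the streamed tensors land. A hedged usage sketch, assuming the constructor and `get_tensor()` signature added in this hunk; the checkpoint path and tensor key are placeholders:

```python
import torch
import comfy.utils
import comfy.safetensors_stream

base = comfy.utils.load_torch_file("checkpoint.safetensors", safe_load=True)  # StreamStateDict
view = comfy.safetensors_stream.DeviceViewStateDict(
    base, device=torch.device("cuda:0"), allow_gds=False, pin_if_cpu=True
)

# Reads go through the base mapping but default to the configured target device
# (staging through pinned CPU memory here); per-call arguments still override.
w = view.get_tensor("some.weight")
w_cpu = view.get_tensor("some.weight", device=torch.device("cpu"))
```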

View File

@@ -26,6 +26,7 @@ import os

 import comfy.utils
 import comfy.safetensors_stream
+import comfy.disk_weights

 from . import clip_vision
 from . import gligen
@@ -125,7 +126,7 @@ class CLIP:
             if not model_management.supports_cast(load_device, dt):
                 load_device = offload_device
                 if params['device'] != offload_device:
-                    self.cond_stage_model.to(offload_device)
+                    comfy.disk_weights.module_to(self.cond_stage_model, offload_device)
                     logging.warning("Had to shift TE back.")

         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
@@ -671,7 +672,7 @@ class VAE:
         if dtype is None:
             dtype = model_management.vae_dtype(self.device, self.working_dtypes)
         self.vae_dtype = dtype
-        self.first_stage_model.to(self.vae_dtype)
+        comfy.disk_weights.module_to(self.first_stage_model, dtype=self.vae_dtype)
         self.output_device = model_management.intermediate_device()
         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@@ -1546,7 +1547,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
         model_config.optimizations["fp8"] = True

     model = model_config.get_model(new_sd, "")
-    model = model.to(offload_device)
+    model = comfy.disk_weights.module_to(model, offload_device)
     model.load_model_weights(new_sd, "")
     left_over = sd.keys()
     if len(left_over) > 0:

View File

@@ -168,6 +168,8 @@ def state_dict_meta(state_dict, key):

 def load_state_dict(model, state_dict, strict=False, assign=False):
     if is_stream_state_dict(state_dict):
+        if comfy.disk_weights.disk_weights_enabled():
+            return comfy.disk_weights.lazy_load_state_dict(model, state_dict, strict=strict)
         comfy.disk_weights.register_module_weights(model, state_dict)
         comfy.disk_weights.attach_disk_weight_hooks(model)
         missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign)
@@ -900,7 +902,10 @@ def copy_to_param(obj, attr, value):
     for name in attrs[:-1]:
         obj = getattr(obj, name)

     prev = getattr(obj, attrs[-1])
-    prev.data.copy_(value)
+    if prev.device.type == "meta":
+        setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=prev.requires_grad))
+    else:
+        prev.data.copy_(value)


 def get_attr(obj, attr: str):
     """Retrieves a nested attribute from an object using dot notation.

View File

@@ -144,3 +144,39 @@ def test_stream_load_without_disk_cache_keeps_cpu_weights(tmp_path):
         assert model.weight.device.type != "meta"
     finally:
         comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
+
+
+def test_lazy_disk_weights_loads_on_demand(tmp_path, monkeypatch):
+    if importlib.util.find_spec("fastsafetensors") is None:
+        pytest.skip("fastsafetensors not installed")
+    import comfy.utils
+    import comfy.disk_weights
+
+    prev_cache = comfy.disk_weights.CACHE.max_bytes
+    prev_gds = comfy.disk_weights.ALLOW_GDS
+    prev_pin = comfy.disk_weights.PIN_IF_CPU
+    prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
+    comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
+    try:
+        path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
+        sd = comfy.utils.load_torch_file(path, safe_load=True)
+        model = torch.nn.Linear(4, 4, bias=True)
+
+        calls = []
+        original = sd._file.read_tensor
+
+        def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
+            calls.append(meta)
+            return original(meta, device, dtype, allow_gds, pin_if_cpu)
+
+        monkeypatch.setattr(sd._file, "read_tensor", wrapped)
+        comfy.utils.load_state_dict(model, sd, strict=True)
+        assert model.weight.device.type == "meta"
+        assert calls == []
+
+        comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"))
+        assert model.weight.device.type == "cpu"
+        assert len(calls) == 2
+    finally:
+        comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)