Mirror of https://github.com/comfyanonymous/ComfyUI.git
Enable lazy disk-tier loading for streaming weights
commit 5f2188e31b
parent f925f8fa77
DESIGN.md
@@ -27,31 +27,31 @@
 - `comfy/model_management.py` handles VRAM/RAM offload via `free_memory` and keeps track of loaded/offloaded memory (needs integration for the RAM→disk tier).【F:comfy/model_management.py†L584-L612】
 - `comfy/model_patcher.py` implements module-by-module offload/low-vram weight casting (`comfy_cast_weights`) and partial unload/load (needs to integrate the disk tier for RAM eviction).【F:comfy/model_patcher.py†L663-L955】
 
-## Strategy summary (no coding performed yet)
+## Strategy summary (implemented)
 
 ### Streaming safetensors mapping (no full dict materialization)
 
-- Introduce a new module `comfy/safetensors_stream.py` with:
-- `TensorMeta` and `SafeTensorIndex` (metadata-only parsing with `fastsafetensors.SafeTensorsMetadata`).
-- `StreamStateDict` as a mapping backed by `SafeTensorIndex`, exposing metadata-only `keys()`/`__iter__` and loading tensors on demand.
-- Lightweight mapping views: `PrefixViewStateDict`, `FilterViewStateDict`, `RenameViewStateDict` for lazy prefix/filter/rename without eager loading.
+- [x] Introduce a new module `comfy/safetensors_stream.py` with:
+- [x] `TensorMeta` and `SafeTensorIndex` (metadata-only parsing with `fastsafetensors.SafeTensorsMetadata`).
+- [x] `StreamStateDict` as a mapping backed by `SafeTensorIndex`, exposing metadata-only `keys()`/`__iter__` and loading tensors on demand.
+- [x] Lightweight mapping views: `PrefixViewStateDict`, `FilterViewStateDict`, `RenameViewStateDict` for lazy prefix/filter/rename without eager loading.
 
 ### Range reads and tiering
 
-- Disk→RAM: use `fastsafetensors.cpp.nogds_file_reader` for range reads and wrap with DLPack.
-- Disk→GPU (GDS): use `gds_file_reader` + `gds_file_handle` to read the aligned range directly into GPU memory. If GDS is requested but not supported (e.g., `is_gds_supported==0` or libcufile missing), raise a hard error with instructions to disable GDS.
-- Disk→RAM→GPU: read only the tensor range into (optionally pinned) CPU memory, copy to GPU, then release the CPU buffer unless the RAM cache policy keeps it.
+- [x] Disk→RAM: use `fastsafetensors.cpp.nogds_file_reader` for range reads and wrap with DLPack.
+- [x] Disk→GPU (GDS): use `gds_file_reader` + `gds_file_handle` to read the aligned range directly into GPU memory. If GDS is requested but not supported (e.g., `is_gds_supported==0` or libcufile missing), raise a hard error with instructions to disable GDS.
+- [x] Disk→RAM→GPU: read only the tensor range into (optionally pinned) CPU memory, copy to GPU, then release the CPU buffer unless the RAM cache policy keeps it.
 
 ### Disk tier integration
 
-- Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`.
-- Add an LRU cache for RAM-resident weights loaded from disk with configurable max bytes. Eviction replaces RAM tensors with meta tensors and keeps `DiskRef` for reload.
-- Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`.
+- [x] Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`.
+- [x] Add an LRU cache for RAM-resident weights loaded from disk with configurable max bytes. Eviction replaces RAM tensors with meta tensors and keeps `DiskRef` for reload.
+- [x] Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`.
 
 ### Pipeline refactors
 
-- Update `load_torch_file` to return `StreamStateDict` for `.safetensors`/`.sft` and return metadata without loading.
-- Update helpers (`calculate_parameters`, `weight_dtype`, `state_dict_prefix_replace`) to be metadata-aware and lazy.
-- Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead.
-- Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads).
-- Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader.
+- [x] Update `load_torch_file` to return `StreamStateDict` for `.safetensors`/`.sft` and return metadata without loading.
+- [x] Update helpers (`calculate_parameters`, `weight_dtype`, `state_dict_prefix_replace`) to be metadata-aware and lazy.
+- [x] Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead.
+- [x] Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads).
+- [x] Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader.
 
 ### Tests and docs
 
-- Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and the GDS failure path.
-- Document the new flags for RAM cache size and GPUDirect enablement and how to disable GDS when unsupported.
+- [x] Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and the GDS failure path.
+- [x] Document the new flags for RAM cache size and GPUDirect enablement and how to disable GDS when unsupported.
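The design above hinges on one idea: `keys()`, shapes, and dtypes come from the safetensors header alone, and tensor bytes are read only when a key is actually fetched. A minimal self-contained sketch of that mapping pattern (stdlib parsing with seek-based range reads; the real `comfy/safetensors_stream.py` builds on `fastsafetensors.SafeTensorsMetadata`, and `MiniStreamStateDict` is an illustrative name, not the module's class):

```python
import json
import struct
from collections.abc import Mapping

import torch

# Abbreviated dtype table; the real format defines more entries.
_DTYPES = {"F32": torch.float32, "F16": torch.float16, "BF16": torch.bfloat16}


class MiniStreamStateDict(Mapping):
    """Metadata-only view of a .safetensors file; bytes are read per key."""

    def __init__(self, path):
        self._path = path
        with open(path, "rb") as f:
            header_len = struct.unpack("<Q", f.read(8))[0]
            header = json.loads(f.read(header_len))
        header.pop("__metadata__", None)
        self._meta = header              # name -> {dtype, shape, data_offsets}
        self._data_start = 8 + header_len

    def __iter__(self):                  # never touches tensor bytes
        return iter(self._meta)

    def __len__(self):
        return len(self._meta)

    def shape(self, key):                # metadata-only shape query
        return tuple(self._meta[key]["shape"])

    def __getitem__(self, key):          # range-read exactly one tensor
        info = self._meta[key]
        begin, end = info["data_offsets"]
        with open(self._path, "rb") as f:
            f.seek(self._data_start + begin)
            buf = bytearray(f.read(end - begin))
        return torch.frombuffer(buf, dtype=_DTYPES[info["dtype"]]).reshape(info["shape"])
```

The `PrefixViewStateDict`/`FilterViewStateDict`/`RenameViewStateDict` views compose over such a mapping by rewriting keys, so prefix strips and renames stay O(1) and never touch tensor data.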
comfy/controlnet.py
@@ -25,6 +25,7 @@ import logging
 import comfy.utils
 import comfy.model_management
 import comfy.model_detection
+import comfy.disk_weights
 import comfy.model_patcher
 import comfy.ops
 import comfy.latent_formats
@@ -385,7 +386,7 @@ class ControlLora(ControlNet):
         controlnet_config["operations"] = control_lora_ops
         controlnet_config["dtype"] = dtype
         self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
-        self.control_model.to(comfy.model_management.get_torch_device())
+        comfy.disk_weights.module_to(self.control_model, comfy.model_management.get_torch_device())
         diffusion_model = model.diffusion_model
         sd = diffusion_model.state_dict()
 
@@ -816,8 +817,8 @@ class T2IAdapter(ControlBase):
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
             self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
         if self.control_input is None:
-            self.t2i_model.to(x_noisy.dtype)
-            self.t2i_model.to(self.device)
+            comfy.disk_weights.module_to(self.t2i_model, dtype=x_noisy.dtype)
+            comfy.disk_weights.module_to(self.t2i_model, self.device)
            self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
             self.t2i_model.cpu()
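Why every `Module.to()` call site changes to `comfy.disk_weights.module_to()`: once disk-resident weights are represented as meta tensors, a plain `.to()` has no data to copy and PyTorch refuses the move. A quick illustration of the failure mode (plain PyTorch, no ComfyUI code):

```python
import torch

lin = torch.nn.Linear(4, 4, device="meta")   # parameters have no storage
try:
    lin.to("cpu")                            # plain Module.to cannot copy out of meta
except NotImplementedError as e:
    print(e)                                 # "Cannot copy out of meta tensor; no data!"

# Module.to_empty() would allocate storage but leave values uninitialized, which
# is why the disk tier needs module_to(): materialize real bytes from the
# DiskRef registry first, then fall through to the regular .to().
lin = lin.to_empty(device="cpu")
```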
comfy/disk_weights.py
@@ -21,14 +21,18 @@ from __future__ import annotations
 import collections
 import weakref
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Dict, MutableMapping, Optional
 
 import torch
 
+from . import safetensors_stream
+
 ALLOW_GDS = False
 PIN_IF_CPU = False
 DISK_WEIGHTS_ENABLED = False
+BASE_LOAD_FROM_STATE_DICT = torch.nn.Module._load_from_state_dict
+LAZY_MODULE_STATE = weakref.WeakKeyDictionary()
 
 
 @dataclass
@@ -123,6 +127,15 @@ class DiskWeightCache:
             _evict_module_weight(module, entry.name, entry.is_buffer)
         return freed
 
+    def remove_module(self, module: torch.nn.Module):
+        to_remove = []
+        for key, entry in self._entries.items():
+            if entry.module_ref() is module:
+                to_remove.append(key)
+        for key in to_remove:
+            entry = self._entries.pop(key)
+            self.current_bytes -= entry.size_bytes
+
     def _drop_module_entries(self, module_ref: weakref.ReferenceType):
         to_remove = []
         for key, entry in self._entries.items():
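`remove_module` drops a module's entries from the byte-bounded LRU when the module switches to whole-module lazy loading. The cache policy itself, as described in DESIGN.md, is simply an ordered map capped by total bytes; a compact standalone sketch (illustrative names, not the `DiskWeightCache` internals):

```python
from collections import OrderedDict


class ByteLRU:
    """LRU keyed by (module_id, param_name); evicts the oldest entries
    until the total tracked bytes fit under max_bytes."""

    def __init__(self, max_bytes: int):
        self.max_bytes = max_bytes
        self.current_bytes = 0
        self._entries = OrderedDict()       # key -> size_bytes

    def record(self, key, size_bytes: int, evict_cb):
        self._entries[key] = size_bytes
        self._entries.move_to_end(key)      # most recently used
        self.current_bytes += size_bytes
        while self.current_bytes > self.max_bytes and self._entries:
            old_key, old_size = self._entries.popitem(last=False)
            self.current_bytes -= old_size
            evict_cb(old_key)               # e.g. swap the tensor back to meta

    def touch(self, key):
        if key in self._entries:
            self._entries.move_to_end(key)
```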
@@ -183,7 +196,61 @@ def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = "
         CACHE.record(module, name, buf, is_buffer=True)
 
 
+@dataclass
+class LazyModuleState:
+    state_dict: MutableMapping
+    prefix: str
+    loaded: bool = False
+
+
+def _has_custom_load(module: torch.nn.Module) -> bool:
+    return module.__class__._load_from_state_dict is not BASE_LOAD_FROM_STATE_DICT
+
+
+def register_lazy_modules(model: torch.nn.Module, state_dict):
+    if not hasattr(state_dict, "keys"):
+        return
+    for name, module in model.named_modules():
+        if not _has_custom_load(module):
+            continue
+        prefix = f"{name}." if name else ""
+        if prefix:
+            has_key = False
+            for param_name in module._parameters.keys():
+                if f"{prefix}{param_name}" in state_dict:
+                    has_key = True
+                    break
+            if not has_key:
+                for buf_name in module._buffers.keys():
+                    if f"{prefix}{buf_name}" in state_dict:
+                        has_key = True
+                        break
+            if not has_key:
+                continue
+        view = safetensors_stream.FilterViewStateDict(
+            state_dict, lambda k, p=prefix: k.startswith(p), mutate_base=False
+        )
+        LAZY_MODULE_STATE[module] = LazyModuleState(state_dict=view, prefix=prefix)
+
+
 def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
+    lazy_state = LAZY_MODULE_STATE.get(module)
+    if lazy_state is not None:
+        CACHE.remove_module(module)
+        refs = REGISTRY.get(module)
+        if refs:
+            for ref_name, disk_ref in refs.items():
+                shape = getattr(disk_ref.meta, "shape", None)
+                dtype = getattr(disk_ref.meta, "dtype", None)
+                if shape is None or dtype is None:
+                    continue
+                meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
+                if disk_ref.is_buffer:
+                    module._buffers[ref_name] = meta_tensor
+                else:
+                    module._parameters[ref_name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+        lazy_state.loaded = False
+        return
     ref = REGISTRY.get(module)
     if not ref or name not in ref:
         return
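The `_has_custom_load` test is an identity comparison on the class attribute: any subclass overriding `_load_from_state_dict` (quantized or fused layers, for example) is routed to whole-module lazy loading, since per-tensor meta swaps would bypass its custom loading logic. A quick demonstration of the detection trick:

```python
import torch

BASE = torch.nn.Module._load_from_state_dict


class PlainLinear(torch.nn.Linear):
    pass


class CustomLoad(torch.nn.Linear):
    # e.g. a layer that dequantizes weights while loading
    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


def has_custom_load(module):
    return module.__class__._load_from_state_dict is not BASE


print(has_custom_load(PlainLinear(2, 2)))   # False: inherits the base impl
print(has_custom_load(CustomLoad(2, 2)))    # True: override detected
```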
@@ -222,6 +289,10 @@ def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
 
 
 def ensure_module_materialized(module: torch.nn.Module, target_device: torch.device):
+    lazy_state = LAZY_MODULE_STATE.get(module)
+    if lazy_state is not None and not lazy_state.loaded:
+        _materialize_module_from_state_dict(module, lazy_state, target_device)
+        return
     refs = REGISTRY.get(module)
     if not refs:
         return
@@ -236,11 +307,14 @@ def ensure_module_materialized(module: torch.nn.Module, target_device: torch.dev
             continue
         if current is None:
             continue
-        if current.device.type != "meta":
+        if current.device.type == "meta":
+            tensor = disk_ref.load(target_device, ALLOW_GDS, PIN_IF_CPU)
+        elif current.device != target_device:
+            tensor = current.to(device=target_device)
+        else:
             if current.device.type == "cpu":
                 CACHE.touch(module, name)
             continue
-        tensor = disk_ref.load(target_device, ALLOW_GDS, PIN_IF_CPU)
         if is_buffer:
             module._buffers[name] = tensor
         else:
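`ensure_module_materialized` is what the general `forward_pre_hook` from DESIGN.md calls right before compute. `attach_disk_weight_hooks` itself is not shown in this diff; a minimal sketch of what such hook wiring plausibly looks like (hypothetical helper name, standard PyTorch hook API):

```python
import torch


def attach_materialize_hooks(model, ensure_module_materialized):
    """Run the materializer right before each submodule's forward."""

    def _pre_hook(module, args):
        # Pick the compute device from the incoming activations when possible.
        device = next((a.device for a in args if isinstance(a, torch.Tensor)),
                      torch.device("cpu"))
        ensure_module_materialized(module, device)
        return None  # leave the inputs untouched

    for submodule in model.modules():
        submodule.register_forward_pre_hook(_pre_hook)
```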
@@ -273,3 +347,138 @@ def evict_ram_cache(bytes_to_free: int):
     if bytes_to_free <= 0:
         return 0
     return CACHE.evict_bytes(bytes_to_free)
+
+
+def materialize_module_tree(module: torch.nn.Module, target_device: torch.device):
+    if not disk_weights_enabled():
+        return
+    for submodule in module.modules():
+        ensure_module_materialized(submodule, target_device)
+
+
+def _extract_to_device(args, kwargs) -> Optional[torch.device]:
+    if "device" in kwargs and kwargs["device"] is not None:
+        return torch.device(kwargs["device"])
+    for arg in args:
+        if isinstance(arg, torch.device):
+            return arg
+        if isinstance(arg, str):
+            return torch.device(arg)
+    return None
+
+
+def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
+    for param in module.parameters(recurse=True):
+        if param is not None and param.device.type != "meta":
+            return param.device
+    for buf in module.buffers(recurse=True):
+        if buf is not None and buf.device.type != "meta":
+            return buf.device
+    return None
+
+
+def module_to(module: torch.nn.Module, *args, **kwargs):
+    if disk_weights_enabled():
+        target_device = _extract_to_device(args, kwargs)
+        if target_device is None:
+            target_device = _find_existing_device(module) or torch.device("cpu")
+        materialize_module_tree(module, target_device)
+    return module.to(*args, **kwargs)
+
+
+def _replace_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool, requires_grad: bool):
+    parts = name.split(".")
+    module = model
+    for part in parts[:-1]:
+        module = getattr(module, part)
+    attr = parts[-1]
+    if is_buffer:
+        module._buffers[attr] = tensor
+    else:
+        module._parameters[attr] = torch.nn.Parameter(tensor, requires_grad=requires_grad)
+
+
+def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: LazyModuleState, target_device: torch.device):
+    missing_keys = []
+    unexpected_keys = []
+    error_msgs = []
+    metadata = getattr(lazy_state.state_dict, "_metadata", None)
+    local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
+    state_dict = safetensors_stream.DeviceViewStateDict(
+        lazy_state.state_dict,
+        device=target_device,
+        allow_gds=ALLOW_GDS,
+        pin_if_cpu=PIN_IF_CPU,
+        mutate_base=False,
+    )
+    factory_device = None
+    if hasattr(module, "factory_kwargs") and "device" in module.factory_kwargs:
+        factory_device = module.factory_kwargs["device"]
+        module.factory_kwargs["device"] = target_device
+    try:
+        module._load_from_state_dict(
+            state_dict,
+            lazy_state.prefix,
+            local_metadata,
+            False,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+        incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
+        for hook in module._load_state_dict_post_hooks.values():
+            out = hook(module, incompatible)
+            if out is not None:
+                raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
+    finally:
+        if factory_device is not None:
+            module.factory_kwargs["device"] = factory_device
+    if len(error_msgs) > 0:
+        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
+    lazy_state.loaded = True
+    for name, param in module.named_parameters(recurse=False):
+        if param.device.type == "cpu":
+            CACHE.record(module, name, param, is_buffer=False)
+    for name, buf in module.named_buffers(recurse=False):
+        if buf is not None and buf.device.type == "cpu":
+            CACHE.record(module, name, buf, is_buffer=True)
+
+
+def lazy_load_state_dict(model: torch.nn.Module, state_dict, strict: bool = False):
+    model_keys = set()
+    for name, _ in model.named_parameters(recurse=True):
+        model_keys.add(name)
+    for name, _ in model.named_buffers(recurse=True):
+        model_keys.add(name)
+
+    state_keys = set(state_dict.keys())
+    missing_keys = [k for k in model_keys if k not in state_keys]
+    unexpected_keys = [k for k in state_keys if k not in model_keys]
+
+    if strict:
+        error_msgs = []
+        if len(unexpected_keys) > 0:
+            error_msgs.append('Unexpected key(s) in state_dict: {}.'.format(', '.join(f'"{k}"' for k in unexpected_keys)))
+        if len(missing_keys) > 0:
+            error_msgs.append('Missing key(s) in state_dict: {}.'.format(', '.join(f'"{k}"' for k in missing_keys)))
+        if error_msgs:
+            raise RuntimeError("Error(s) in loading state_dict:\n\t{}".format("\n\t".join(error_msgs)))
+
+    for name, param in model.named_parameters(recurse=True):
+        if name not in state_keys:
+            continue
+        meta = state_dict.meta(name)
+        meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
+        _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
+
+    for name, buf in model.named_buffers(recurse=True):
+        if buf is None or name not in state_keys:
+            continue
+        meta = state_dict.meta(name)
+        meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
+        _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
+
+    register_module_weights(model, state_dict)
+    register_lazy_modules(model, state_dict)
+    attach_disk_weight_hooks(model)
+    return missing_keys, unexpected_keys
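Taken together, the intended call pattern is: parse metadata, swap parameters to meta tensors, and let materialization happen on the first placement or forward. A hedged end-to-end sketch (`configure`'s signature is taken from the tests below; `build_model()` is a placeholder for any `nn.Module` constructor):

```python
import torch
import comfy.utils
import comfy.disk_weights

# 2 GiB RAM cache for disk-backed weights, no GPUDirect, no pinned staging.
comfy.disk_weights.configure(2 * 1024**3, allow_gds=False, pin_if_cpu=False, enabled=True)

sd = comfy.utils.load_torch_file("model.safetensors", safe_load=True)  # StreamStateDict; no tensor reads
model = build_model()                                                  # placeholder: your nn.Module

# Parameters become meta tensors + DiskRefs; still nothing read from disk.
missing, unexpected = comfy.disk_weights.lazy_load_state_dict(model, sd)

# First real placement triggers range reads for exactly these weights.
comfy.disk_weights.module_to(model, torch.device("cpu"))
```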
comfy/model_patcher.py
@@ -785,7 +785,7 @@ class ModelPatcher:
                     m.comfy_patched_weights = True
 
             for x in load_completely:
-                x[2].to(device_to)
+                comfy.disk_weights.module_to(x[2], device_to)
 
             for x in offloaded:
                 n = x[1]
@@ -800,7 +800,7 @@ class ModelPatcher:
             logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
             self.model.model_lowvram = False
             if full_load:
-                self.model.to(device_to)
+                comfy.disk_weights.module_to(self.model, device_to)
                 mem_counter = self.model_size()
 
         self.model.lowvram_patch_counter += patch_counter
@@ -857,7 +857,7 @@ class ModelPatcher:
         self.backup.clear()
 
         if device_to is not None:
-            self.model.to(device_to)
+            comfy.disk_weights.module_to(self.model, device_to)
             self.model.device = device_to
         self.model.model_loaded_weight_memory = 0
         self.model.model_offload_buffer_memory = 0
comfy/safetensors_stream.py
@@ -442,6 +442,12 @@ class StreamStateDict(collections.abc.MutableMapping):
             raise KeyError(key)
         if device is None:
             device = self._device
+        if device.type == "meta":
+            meta = self._index.meta(key)
+            target_dtype = dtype or meta.dtype
+            if dtype is not None and dtype != meta.dtype:
+                _validate_dtype_conversion(meta.dtype, dtype)
+            return torch.empty(meta.shape, dtype=target_dtype, device="meta")
         if allow_gds is None:
             allow_gds = self._allow_gds
         meta = self._index.meta(key)
@@ -559,6 +565,37 @@ class _BaseViewStateDict(MutableMapping):
             t = t.to(dtype=dtype)
         return t
 
 
+class DeviceViewStateDict(_BaseViewStateDict):
+    def __init__(
+        self,
+        base: MutableMapping,
+        device: torch.device,
+        allow_gds: Optional[bool] = None,
+        pin_if_cpu: bool = False,
+        mutate_base: bool = False,
+    ):
+        super().__init__(base, mutate_base=mutate_base)
+        self._device = device
+        self._allow_gds = allow_gds
+        self._pin_if_cpu = pin_if_cpu
+
+    def get_tensor(
+        self,
+        key: str,
+        *,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        allow_gds: Optional[bool] = None,
+        pin_if_cpu: bool = False,
+    ) -> torch.Tensor:
+        device = self._device if device is None else device
+        allow_gds = self._allow_gds if allow_gds is None else allow_gds
+        pin_if_cpu = self._pin_if_cpu if not pin_if_cpu else pin_if_cpu
+        return super().get_tensor(
+            key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
+        )
 
     def meta(self, key: str):
         if key in self._overrides:
             t = self._overrides[key]
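`DeviceViewStateDict` pins default `device`/`allow_gds`/`pin_if_cpu` values onto an underlying mapping, so `_materialize_module_from_state_dict` can hand a module's `_load_from_state_dict` an ordinary-looking state dict whose fetches land on the target device. Illustrative use (the tensor key is made up; `sd` is assumed to be a `StreamStateDict` from `comfy.utils.load_torch_file`):

```python
# Sketch: route every fetch from a streaming state dict to a chosen device.
import torch
from comfy.safetensors_stream import DeviceViewStateDict

view = DeviceViewStateDict(sd, device=torch.device("cuda:0"),
                           allow_gds=False, pin_if_cpu=True, mutate_base=False)

w = view.get_tensor("model.diffusion_model.input_blocks.0.0.weight")  # hypothetical key
assert w.device.type == "cuda"   # range-read from disk straight to the view's device
```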
comfy/sd.py
@@ -26,6 +26,7 @@ import os
 
 import comfy.utils
 import comfy.safetensors_stream
+import comfy.disk_weights
 
 from . import clip_vision
 from . import gligen
@@ -125,7 +126,7 @@ class CLIP:
             if not model_management.supports_cast(load_device, dt):
                 load_device = offload_device
                 if params['device'] != offload_device:
-                    self.cond_stage_model.to(offload_device)
+                    comfy.disk_weights.module_to(self.cond_stage_model, offload_device)
                     logging.warning("Had to shift TE back.")
 
         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
@@ -671,7 +672,7 @@ class VAE:
         if dtype is None:
             dtype = model_management.vae_dtype(self.device, self.working_dtypes)
         self.vae_dtype = dtype
-        self.first_stage_model.to(self.vae_dtype)
+        comfy.disk_weights.module_to(self.first_stage_model, dtype=self.vae_dtype)
         self.output_device = model_management.intermediate_device()
 
         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@@ -1546,7 +1547,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
         model_config.optimizations["fp8"] = True
 
     model = model_config.get_model(new_sd, "")
-    model = model.to(offload_device)
+    model = comfy.disk_weights.module_to(model, offload_device)
     model.load_model_weights(new_sd, "")
     left_over = sd.keys()
    if len(left_over) > 0:
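One subtlety in the VAE hunk: the call passes only `dtype=`. `_extract_to_device` finds no device among the arguments, so `module_to` falls back to the module's existing non-meta device (or CPU), materializes there, and only then delegates the dtype cast to `Module.to`. Roughly:

```python
import torch
import comfy.disk_weights

m = torch.nn.Linear(8, 8)                      # stand-in for first_stage_model
comfy.disk_weights.module_to(m, dtype=torch.bfloat16)
# equivalent to: materialize on m's current device, then m.to(dtype=torch.bfloat16)
assert m.weight.dtype == torch.bfloat16
```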
comfy/utils.py
@@ -168,6 +168,8 @@ def state_dict_meta(state_dict, key):
 
 
 def load_state_dict(model, state_dict, strict=False, assign=False):
     if is_stream_state_dict(state_dict):
+        if comfy.disk_weights.disk_weights_enabled():
+            return comfy.disk_weights.lazy_load_state_dict(model, state_dict, strict=strict)
         comfy.disk_weights.register_module_weights(model, state_dict)
         comfy.disk_weights.attach_disk_weight_hooks(model)
         missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign)
@@ -900,7 +902,10 @@ def copy_to_param(obj, attr, value):
     for name in attrs[:-1]:
         obj = getattr(obj, name)
     prev = getattr(obj, attrs[-1])
-    prev.data.copy_(value)
+    if prev.device.type == "meta":
+        setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=prev.requires_grad))
+    else:
+        prev.data.copy_(value)
 
 def get_attr(obj, attr: str):
     """Retrieves a nested attribute from an object using dot notation.
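The `copy_to_param` change covers weight patching against a still-evicted (meta) parameter: there is no storage to `copy_` into, so the `Parameter` object is replaced outright. A standalone, single-level demonstration of both paths (the real helper walks dotted attribute paths):

```python
import torch


def copy_to_param(module, attr, value):
    prev = getattr(module, attr)
    if prev.device.type == "meta":
        # No storage to copy into; swap in a fresh Parameter instead.
        setattr(module, attr, torch.nn.Parameter(value, requires_grad=prev.requires_grad))
    else:
        prev.data.copy_(value)


lin = torch.nn.Linear(2, 2, device="meta")
copy_to_param(lin, "weight", torch.ones(2, 2))   # replacement path
assert lin.weight.device.type == "cpu"

copy_to_param(lin, "weight", torch.zeros(2, 2))  # in-place path
assert lin.weight.sum().item() == 0.0
```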
@@ -144,3 +144,39 @@ def test_stream_load_without_disk_cache_keeps_cpu_weights(tmp_path):
         assert model.weight.device.type != "meta"
     finally:
         comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
+
+
+def test_lazy_disk_weights_loads_on_demand(tmp_path, monkeypatch):
+    if importlib.util.find_spec("fastsafetensors") is None:
+        pytest.skip("fastsafetensors not installed")
+    import comfy.utils
+    import comfy.disk_weights
+
+    prev_cache = comfy.disk_weights.CACHE.max_bytes
+    prev_gds = comfy.disk_weights.ALLOW_GDS
+    prev_pin = comfy.disk_weights.PIN_IF_CPU
+    prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
+    comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
+
+    try:
+        path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
+        sd = comfy.utils.load_torch_file(path, safe_load=True)
+        model = torch.nn.Linear(4, 4, bias=True)
+        calls = []
+
+        original = sd._file.read_tensor
+
+        def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
+            calls.append(meta)
+            return original(meta, device, dtype, allow_gds, pin_if_cpu)
+
+        monkeypatch.setattr(sd._file, "read_tensor", wrapped)
+        comfy.utils.load_state_dict(model, sd, strict=True)
+        assert model.weight.device.type == "meta"
+        assert calls == []
+
+        comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"))
+        assert model.weight.device.type == "cpu"
+        assert len(calls) == 2
+    finally:
+        comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)