diff --git a/DESIGN.md b/DESIGN.md new file mode 100644 index 000000000..19a1bfc5d --- /dev/null +++ b/DESIGN.md @@ -0,0 +1,57 @@ +# Disk tier safetensors streaming design audit (ComfyUI) + +## Mandatory research audit (verified call sites) + +### ComfyUI load path + eager materialization sites +- `comfy/utils.py:load_torch_file` currently uses `safetensors.safe_open` and iterates all keys to build a full `sd` dict (eager tensor materialization). It also returns metadata only after reading all tensors.【F:comfy/utils.py†L58-L93】 +- `comfy/utils.py:calculate_parameters` and `weight_dtype` iterate `sd.keys()` and then access `sd[k]` to compute `nelement()`/`dtype` (loads tensors).【F:comfy/utils.py†L109-L128】 +- `comfy/utils.py:state_dict_prefix_replace` mutates dicts by `pop`+assignment (materializes if used on a streaming mapping).【F:comfy/utils.py†L135-L144】 +- `comfy/model_base.py:BaseModel.load_model_weights` builds `to_load = {}` by iterating keys and popping tensors, then passes a fully materialized dict to `load_state_dict` (RAM spike).【F:comfy/model_base.py†L301-L318】 +- `comfy/model_detection.py` reads `state_dict[key].shape` in many branches for detection (must be metadata-only). Example: `calculate_transformer_depth` and numerous `detect_unet_config` branches read shapes directly from `state_dict` values.【F:comfy/model_detection.py†L21-L200】 +- `comfy/sd.py` loads checkpoints, then slices, renames, and computes parameters/dtypes by reading tensors (e.g., `calculate_parameters`, `weight_dtype`, `process_*_state_dict`, and special scaled-FP8 conversion that builds new dicts).【F:comfy/sd.py†L1304-L1519】 +- Direct safetensors load outside `load_torch_file`: `comfy/sd1_clip.py:load_embed` and `nodes.py:LoadLatent.load` use `safetensors.torch.load_file`, bypassing the core loader.【F:comfy/sd1_clip.py†L432-L434】【F:nodes.py†L521-L529】 + +### FastSageTensors (fastsafetensors) capability audit +- Header parsing and metadata: + - `fastsafetensors/common.py:SafeTensorsMetadata` parses the header and builds per-tensor `TensorFrame` with `dtype`, `shape`, and `data_offsets` (no tensor allocation).【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L63-L187】 + - `TensorFrame` stores dtype/shape/offsets and supports slicing metadata.【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L238-L338】 +- GDS + no-GDS low-level readers: + - `fastsafetensors/cpp.pyi` exposes `gds_file_reader`, `gds_file_handle`, `nogds_file_reader`, `cpu_malloc`, `gpu_malloc`, and alignment helpers such as `get_alignment_size()`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L1-L43】 + - GDS availability checks are in `fastsafetensors/cpp.pyi`: `is_gds_supported`, `is_cufile_found`, `cufile_version`, and `init_gds`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L36-L43】 +- DLPack wrapping: + - `fastsafetensors/dlpack.py` provides `from_cuda_buffer()` which creates DLPack capsules for both CPU and GPU buffers via a device descriptor and is used for `torch.from_dlpack`.【F:../third_party/fastsafetensors-main/fastsafetensors/dlpack.py†L232-L239】 +- Torch framework interop: + - `fastsafetensors/frameworks/_torch.py:TorchOp` provides `alloc_tensor_memory`/`free_tensor_memory`, dtype mapping, and uses `torch.from_dlpack` for wrapping raw pointers into tensors.【F:../third_party/fastsafetensors-main/fastsafetensors/frameworks/_torch.py†L131-L205】 + +### VRAM/RAM offload logic (for extension) +- `comfy/model_management.py` handles VRAM/RAM offload via `free_memory` and keeps 
tracking of loaded/offloaded memory (needs integration for RAM disk tier).【F:comfy/model_management.py†L584-L612】 +- `comfy/model_patcher.py` implements module-by-module offload/low-vram weight casting (`comfy_cast_weights`) and partial unload/load (needs to integrate disk tier for RAM eviction).【F:comfy/model_patcher.py†L663-L955】 + +## Strategy summary (implemented) + +### Streaming safetensors mapping (no full dict materialization) +- [x] Introduce a new module `comfy/safetensors_stream.py` with: + - [x] `TensorMeta` and `SafeTensorIndex` (metadata-only parsing with `fastsafetensors.SafeTensorsMetadata`). + - [x] `StreamStateDict` as a mapping backed by `SafeTensorIndex`, exposing metadata-only `keys()`/`__iter__` and loading tensors on demand. + - [x] Lightweight mapping views: `PrefixViewStateDict`, `FilterViewStateDict`, `RenameViewStateDict` for lazy prefix/filter/rename without eager loading. + +### Range reads and tiering +- [x] Disk→RAM: use `fastsafetensors.cpp.nogds_file_reader` for range reads and wrap with DLPack. +- [x] Disk→GPU (GDS): use `gds_file_reader` + `gds_file_handle` to read the aligned range directly into GPU memory. If GDS is requested but not supported (e.g., `is_gds_supported==0` or libcufile missing), raise a hard error with instructions to disable GDS. +- [x] Disk→RAM→GPU: read only the tensor range into (optionally pinned) CPU memory, copy to GPU, then release CPU buffer unless RAM cache policy keeps it. + +### Disk tier integration +- [x] Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`. +- [x] Add an LRU cache for RAM-resident weights loaded from disk with configurable max bytes. Eviction replaces RAM tensors with meta tensors and keeps `DiskRef` for reload. +- [x] Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`. + +### Pipeline refactors +- [x] Update `load_torch_file` to return `StreamStateDict` for `.safetensors`/`.sft` and return metadata without loading. +- [x] Update helpers (`calculate_parameters`, `weight_dtype`, `state_dict_prefix_replace`) to be metadata-aware and lazy. +- [x] Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead. +- [x] Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads). +- [x] Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader. + +### Tests and docs +- [x] Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and GDS failure path. +- [x] Document new flags for RAM cache size and GPUDirect enablement and how to disable GDS when unsupported. diff --git a/README.md b/README.md index 6d09758c0..550482e5c 100644 --- a/README.md +++ b/README.md @@ -349,6 +349,14 @@ For models compatible with Iluvatar Extension for PyTorch. 
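The sketch below (editorial, not part of the diff) illustrates the intended usage of the streaming mapping described in the design audit: `keys()` and `meta()` touch only the parsed header, and a single `get_tensor()` call reads just that tensor's byte range. It assumes `load_torch_file` returns a `StreamStateDict` exposing the `meta()`/`get_tensor()` methods that `comfy/disk_weights.py` relies on; exact signatures may differ.

```python
import torch
import comfy.utils

# Metadata-only: no tensor bytes are read while iterating.
sd = comfy.utils.load_torch_file("model.safetensors")
total_params = sum(sd.meta(k).numel for k in sd.keys())

# On-demand range read of a single tensor (disk -> RAM in this example).
first_key = next(iter(sd.keys()))
weight = sd.get_tensor(first_key, device=torch.device("cpu"))
print(total_params, first_key, tuple(weight.shape))
```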
Here's a step-by-step | `--enable-manager` | Enable ComfyUI-Manager | | `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) | | `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) | +| `--weights-ram-cache-gb` | Enable a disk tier for model weights and keep up to N GB in RAM. Set to `0` to disable RAM caching while still allowing disk streaming. | +| `--weights-gds` | Enable GPUDirect Storage (GDS) for disk→GPU weight loads. Requires libcufile and GDS support. | + +### Disk tier for model weights + +When `--weights-ram-cache-gb` is set, ComfyUI streams safetensors weights from disk and keeps a bounded RAM cache. If the cache limit is exceeded, weights are evicted back to disk and reloaded on demand. + +If `--weights-gds` is enabled, ComfyUI attempts disk→GPU reads via GPUDirect Storage. If GDS is not available (missing libcufile or unsupported platform), the load will fail with a clear error. Disable GDS by omitting `--weights-gds` to use disk→RAM→GPU staging instead. # Running diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 46ef21c95..bf3abf834 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -29,7 +29,7 @@ class AudioEncoderModel(): self.model_sample_rate = 16000 def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return comfy.utils.load_state_dict(self.model, sd, strict=False) def get_sd(self): return self.model.state_dict() diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 1716c3de7..5024caf8d 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -114,6 +114,9 @@ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU cachi cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB") +parser.add_argument("--weights-ram-cache-gb", type=float, default=None, help="Enable a disk tier for model weights by keeping up to N GB in RAM. Set to 0 to disable RAM caching while keeping disk tier enabled.") +parser.add_argument("--weights-gds", action="store_true", help="Enable GPUDirect Storage (GDS) for disk->GPU weight loads. Requires libcufile and GDS support.") + attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . 
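Editorial sketch (not in the diff): one plausible way to translate the new flags into `comfy.disk_weights.configure()` at startup, e.g. after launching with `--weights-ram-cache-gb 8 --weights-gds`. Only the flag names and the `configure(cache_bytes, allow_gds, pin_if_cpu, enabled)` signature come from this diff; the GB-to-bytes conversion, the call site, and `pin_if_cpu=False` are assumptions.

```python
import comfy.disk_weights
from comfy.cli_args import args  # argparse namespace populated by comfy/cli_args.py

if args.weights_ram_cache_gb is not None:
    # 0 GB keeps the disk tier enabled but disables the RAM cache, matching the flag's help text.
    cache_bytes = int(args.weights_ram_cache_gb * (1024 ** 3))
    comfy.disk_weights.configure(
        cache_bytes=cache_bytes,
        allow_gds=args.weights_gds,  # GDS loads fail hard later if libcufile/GDS is unavailable
        pin_if_cpu=False,            # assumed default; pinned staging buffers stay opt-in
        enabled=True,
    )
```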
Ignored when xformers is used.") diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index d5fc53497..837eba013 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -48,7 +48,7 @@ class ClipVisionModel(): self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return comfy.utils.load_state_dict(self.model, sd, strict=False) def get_sd(self): return self.model.state_dict() diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0b5e30f52..551f0bb18 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -25,6 +25,7 @@ import logging import comfy.utils import comfy.model_management import comfy.model_detection +import comfy.disk_weights import comfy.model_patcher import comfy.ops import comfy.latent_formats @@ -385,7 +386,7 @@ class ControlLora(ControlNet): controlnet_config["operations"] = control_lora_ops controlnet_config["dtype"] = dtype self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) - self.control_model.to(comfy.model_management.get_torch_device()) + comfy.disk_weights.module_to(self.control_model, comfy.model_management.get_torch_device()) diffusion_model = model.diffusion_model sd = diffusion_model.state_dict() @@ -439,7 +440,7 @@ def controlnet_config(sd, model_options={}): return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device def controlnet_load_state_dict(control_model, sd): - missing, unexpected = control_model.load_state_dict(sd, strict=False) + missing, unexpected = comfy.utils.load_state_dict(control_model, sd, strict=False) if len(missing) > 0: logging.warning("missing controlnet keys: {}".format(missing)) @@ -473,9 +474,9 @@ def load_controlnet_mmdit(sd, model_options={}): class ControlNetSD35(ControlNet): def pre_run(self, model, percent_to_timestep_function): if self.control_model.double_y_emb: - missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False) + missing, unexpected = comfy.utils.load_state_dict(self.control_model.orig_y_embedder, model.diffusion_model.y_embedder.state_dict(), strict=False) else: - missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False) + missing, unexpected = comfy.utils.load_state_dict(self.control_model.x_embedder, model.diffusion_model.x_embedder.state_dict(), strict=False) super().pre_run(model, percent_to_timestep_function) def copy(self): @@ -748,9 +749,9 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): pass w = WeightsLoader() w.control_model = control_model - missing, unexpected = w.load_state_dict(controlnet_data, strict=False) + missing, unexpected = comfy.utils.load_state_dict(w, controlnet_data, strict=False) else: - missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) + missing, unexpected = comfy.utils.load_state_dict(control_model, controlnet_data, strict=False) if len(missing) > 0: logging.warning("missing controlnet keys: {}".format(missing)) @@ -816,8 +817,8 @@ class T2IAdapter(ControlBase): if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) if self.control_input is None: - self.t2i_model.to(x_noisy.dtype) - self.t2i_model.to(self.device) + comfy.disk_weights.module_to(self.t2i_model, dtype=x_noisy.dtype) + 
comfy.disk_weights.module_to(self.t2i_model, self.device) self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype)) self.t2i_model.cpu() @@ -874,7 +875,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options else: return None - missing, unexpected = model_ad.load_state_dict(t2i_data) + missing, unexpected = comfy.utils.load_state_dict(model_ad, t2i_data, strict=True) if len(missing) > 0: logging.warning("t2i missing {}".format(missing)) diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py new file mode 100644 index 000000000..07e9b5f11 --- /dev/null +++ b/comfy/disk_weights.py @@ -0,0 +1,1115 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +from __future__ import annotations + +import collections +import logging +import weakref +from dataclasses import dataclass, field +from types import SimpleNamespace +from typing import Dict, MutableMapping, Optional, Set + +import torch + +from . import safetensors_stream + + +ALLOW_GDS = False +PIN_IF_CPU = False +DISK_WEIGHTS_ENABLED = False +BASE_LOAD_FROM_STATE_DICT = torch.nn.Module._load_from_state_dict +LAZY_MODULE_STATE = weakref.WeakKeyDictionary() +DISK_MATERIALIZATION_STATE = weakref.WeakKeyDictionary() +_MISSING = object() + + +@dataclass +class DiskTensorRef: + state_dict: object + key: str + meta: object + requires_grad: bool + is_buffer: bool + + def load( + self, + device: torch.device, + allow_gds: bool, + pin_if_cpu: bool, + dtype_override: Optional[torch.dtype] = None, + ) -> torch.Tensor: + dtype = dtype_override or getattr(self.meta, "dtype", None) + if hasattr(self.state_dict, "get_tensor"): + return self.state_dict.get_tensor( + self.key, + device=device, + dtype=dtype, + allow_gds=allow_gds, + pin_if_cpu=pin_if_cpu, + ) + tensor = self.state_dict[self.key] + if device is not None and tensor.device != device: + tensor = tensor.to(device=device) + if dtype is not None and tensor.dtype != dtype: + tensor = tensor.to(dtype=dtype) + return tensor + + +class DiskWeightRegistry: + def __init__(self): + self._registry = weakref.WeakKeyDictionary() + + def register(self, module: torch.nn.Module, name: str, ref: DiskTensorRef): + module_refs = self._registry.setdefault(module, {}) + module_refs[name] = ref + + def get(self, module: torch.nn.Module) -> Optional[Dict[str, DiskTensorRef]]: + return self._registry.get(module) + + def has(self, module: torch.nn.Module) -> bool: + return module in self._registry + + +@dataclass +class CacheEntry: + module_ref: weakref.ReferenceType + name: str + size_bytes: int + is_buffer: bool + + +class DiskWeightCache: + def __init__(self, max_bytes: int = 0): + self.max_bytes = max_bytes + self.current_bytes = 0 + self._entries: "collections.OrderedDict[tuple[int, str], CacheEntry]" = collections.OrderedDict() + + def set_limit(self, max_bytes: int): + self.max_bytes = max_bytes + self._evict_if_needed() + + def _entry_key(self, module: 
torch.nn.Module, name: str) -> tuple[int, str]: + return (id(module), name) + + def record(self, module: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool): + if tensor.device.type != "cpu": + return + size_bytes = tensor.numel() * tensor.element_size() + key = self._entry_key(module, name) + if key in self._entries: + entry = self._entries.pop(key) + self.current_bytes -= entry.size_bytes + module_ref = weakref.ref(module, self._drop_module_entries) + self._entries[key] = CacheEntry(module_ref=module_ref, name=name, size_bytes=size_bytes, is_buffer=is_buffer) + self.current_bytes += size_bytes + self._evict_if_needed() + + def touch(self, module: torch.nn.Module, name: str): + key = self._entry_key(module, name) + if key in self._entries: + entry = self._entries.pop(key) + self._entries[key] = entry + + def evict_bytes(self, bytes_to_free: int): + freed = 0 + while self._entries and freed < bytes_to_free: + _, entry = self._entries.popitem(last=False) + freed += entry.size_bytes + self.current_bytes -= entry.size_bytes + module = entry.module_ref() + if module is not None: + _evict_module_weight(module, entry.name, entry.is_buffer) + return freed + + def remove_module(self, module: torch.nn.Module): + to_remove = [] + for key, entry in self._entries.items(): + if entry.module_ref() is module: + to_remove.append(key) + for key in to_remove: + entry = self._entries.pop(key) + self.current_bytes -= entry.size_bytes + + def _drop_module_entries(self, module_ref: weakref.ReferenceType): + to_remove = [] + for key, entry in self._entries.items(): + if entry.module_ref is module_ref: + to_remove.append(key) + for key in to_remove: + entry = self._entries.pop(key) + self.current_bytes -= entry.size_bytes + + def _evict_if_needed(self): + while self._entries and self.current_bytes > self.max_bytes: + _, entry = self._entries.popitem(last=False) + self.current_bytes -= entry.size_bytes + module = entry.module_ref() + if module is not None: + _evict_module_weight(module, entry.name, entry.is_buffer) + + +REGISTRY = DiskWeightRegistry() +CACHE = DiskWeightCache(0) +LOGGER = logging.getLogger(__name__) + + +def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True): + global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED + ALLOW_GDS = allow_gds + PIN_IF_CPU = pin_if_cpu + DISK_WEIGHTS_ENABLED = enabled + CACHE.set_limit(cache_bytes if enabled else 0) + if not enabled: + CACHE._entries.clear() + CACHE.current_bytes = 0 + + +def disk_weights_enabled() -> bool: + return DISK_WEIGHTS_ENABLED + + +def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = ""): + if not disk_weights_enabled(): + return + if not hasattr(state_dict, "meta") or not hasattr(state_dict, "get_tensor"): + return + for module_name, submodule in module.named_modules(): + module_prefix = f"{prefix}{module_name}." 
if module_name else prefix + for name, param in submodule.named_parameters(recurse=False): + key = f"{module_prefix}{name}" if module_prefix else name + if key in state_dict: + meta = state_dict.meta(key) + ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False) + REGISTRY.register(submodule, name, ref) + if param.device.type == "cpu": + CACHE.record(submodule, name, param, is_buffer=False) + for name, buf in submodule.named_buffers(recurse=False): + key = f"{module_prefix}{name}" if module_prefix else name + if key in state_dict and buf is not None: + meta = state_dict.meta(key) + ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True) + REGISTRY.register(submodule, name, ref) + if buf.device.type == "cpu": + CACHE.record(submodule, name, buf, is_buffer=True) + + +@dataclass +class LazyModuleState: + state_dict: MutableMapping + prefix: str + loaded: bool = False + + +@dataclass +class DiskMaterializationState: + loaded_keys: Set[str] = field(default_factory=set) + deferred_keys: Set[str] = field(default_factory=set) + loaded_bytes: int = 0 + deferred_bytes: int = 0 + future_dtypes: Dict[str, torch.dtype] = field(default_factory=dict) + + +def _get_materialization_state(module: torch.nn.Module) -> DiskMaterializationState: + state = DISK_MATERIALIZATION_STATE.get(module) + if state is None: + state = DiskMaterializationState() + DISK_MATERIALIZATION_STATE[module] = state + return state + + +def _set_future_dtype(module: torch.nn.Module, name: str, dtype: Optional[torch.dtype]): + state = _get_materialization_state(module) + if dtype is None: + state.future_dtypes.pop(name, None) + else: + state.future_dtypes[name] = dtype + + +def _get_future_dtype(module: torch.nn.Module, name: str) -> Optional[torch.dtype]: + state = DISK_MATERIALIZATION_STATE.get(module) + if state is None: + return None + return state.future_dtypes.get(name) + + +def _update_disk_state_attrs(module: torch.nn.Module, state: DiskMaterializationState): + module.disk_loaded_weight_memory = state.loaded_bytes + module.disk_offload_buffer_memory = state.deferred_bytes + + +def _tensor_nbytes(tensor: torch.Tensor) -> int: + return tensor.numel() * tensor.element_size() + + +def _meta_nbytes(meta) -> Optional[int]: + return getattr(meta, "nbytes", None) + + +def _meta_tensor(meta, dtype_override: Optional[torch.dtype] = None) -> torch.Tensor: + dtype = dtype_override or getattr(meta, "dtype", None) + shape = getattr(meta, "shape", None) + if dtype is None or shape is None: + raise KeyError("Missing metadata for meta tensor") + return torch.empty(shape, dtype=dtype, device="meta") + + +def _state_dict_meta(state_dict: MutableMapping, key: str): + if hasattr(state_dict, "meta"): + return state_dict.meta(key) + if hasattr(state_dict, "get_tensor"): + t = state_dict.get_tensor(key, device=torch.device("meta")) + else: + t = state_dict[key] + numel = t.numel() + return SimpleNamespace( + dtype=t.dtype, + shape=tuple(t.shape), + numel=numel, + nbytes=numel * t.element_size(), + ) + + +def _rebuild_materialization_state(module: torch.nn.Module, refs: Dict[str, DiskTensorRef], state: DiskMaterializationState): + state.loaded_keys.clear() + state.deferred_keys.clear() + state.loaded_bytes = 0 + state.deferred_bytes = 0 + for name, ref in refs.items(): + if name in module._parameters: + tensor = module._parameters[name] + elif name in module._buffers: + tensor = module._buffers[name] + else: + continue + if tensor is None: + continue 
+ nbytes = _meta_nbytes(ref.meta) or _tensor_nbytes(tensor) + if tensor.device.type == "meta": + state.deferred_keys.add(name) + state.deferred_bytes += nbytes + else: + state.loaded_keys.add(name) + state.loaded_bytes += nbytes + _update_disk_state_attrs(module, state) + + +def _summarize_module_bytes(module: torch.nn.Module, refs: Dict[str, DiskTensorRef]): + cpu_bytes = 0 + gpu_bytes = 0 + meta_bytes = 0 + total_bytes = 0 + for name, ref in refs.items(): + tensor = None + if name in module._parameters: + tensor = module._parameters[name] + elif name in module._buffers: + tensor = module._buffers[name] + if tensor is None: + continue + nbytes = _meta_nbytes(ref.meta) + if nbytes is None: + nbytes = _tensor_nbytes(tensor) + total_bytes += nbytes + if tensor.device.type == "meta": + meta_bytes += nbytes + elif tensor.device.type == "cpu": + cpu_bytes += nbytes + else: + gpu_bytes += nbytes + return total_bytes, cpu_bytes, gpu_bytes, meta_bytes + + +def _log_materialization( + module: torch.nn.Module, + target_device: torch.device, + free_mem: int, + refs: Dict[str, DiskTensorRef], + state: DiskMaterializationState, + context: str, +): + total_bytes, cpu_bytes, gpu_bytes, meta_bytes = _summarize_module_bytes(module, refs) + if total_bytes == 0: + return + partial = meta_bytes > 0 + LOGGER.info( + "%s: module=%s dest=%s load=%0.2fMB free=%0.2fMB partial=%s " + "loaded=%0.2fMB meta=%0.2fMB cpu=%0.2fMB gpu=%0.2fMB full_load=%s", + context, + module.__class__.__name__, + target_device, + total_bytes / (1024 * 1024), + free_mem / (1024 * 1024), + partial, + state.loaded_bytes / (1024 * 1024), + state.deferred_bytes / (1024 * 1024), + cpu_bytes / (1024 * 1024), + gpu_bytes / (1024 * 1024), + not partial, + ) + + +def _device_free_memory(device: torch.device) -> int: + from . import model_management + return int(model_management.get_free_memory(device)) + + +def _evict_ram_for_budget(required_bytes: int) -> int: + if required_bytes <= 0: + return 0 + freed = evict_ram_cache(required_bytes) + if freed < required_bytes: + from . import model_management + freed += model_management.evict_ram_to_disk(required_bytes - freed) + return freed + + +def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int: + free_mem = _device_free_memory(device) + if device.type == "cpu" and free_mem < required_bytes: + _evict_ram_for_budget(required_bytes - free_mem) + free_mem = _device_free_memory(device) + return free_mem + + +def _choose_alternate_device(device: torch.device) -> Optional[torch.device]: + from . 
import model_management + if device.type == "cpu": + alt = model_management.get_torch_device() + if alt.type != "cpu": + return alt + else: + return torch.device("cpu") + return None + + +class _BudgetedStateDict(MutableMapping): + is_stream_state_dict = True + + def __init__( + self, + base: MutableMapping, + allowed_keys: Set[str], + device: torch.device, + allow_gds: Optional[bool] = None, + pin_if_cpu: bool = False, + dtype_override: Optional[torch.dtype] = None, + overrides: Optional[Dict[str, torch.Tensor]] = None, + ): + self._base = base + self._allowed_keys = allowed_keys + self._device = device + self._allow_gds = allow_gds + self._pin_if_cpu = pin_if_cpu + self._dtype_override = dtype_override + self._overrides = overrides or {} + self._deleted: Set[str] = set() + + def _get_meta(self, key: str): + if key in self._overrides: + t = self._overrides[key] + return safetensors_stream.TensorMeta( + dtype=t.dtype, + shape=tuple(t.shape), + numel=t.numel(), + nbytes=_tensor_nbytes(t), + data_offsets=(0, _tensor_nbytes(t)), + filename="", + fst_dtype=None, + strides=tuple(t.stride()), + ) + if hasattr(self._base, "meta"): + return self._base.meta(key) + if hasattr(self._base, "get_tensor"): + t = self._base.get_tensor(key, device=torch.device("meta")) + else: + t = self._base[key] + return safetensors_stream.TensorMeta( + dtype=t.dtype, + shape=tuple(t.shape), + numel=t.numel(), + nbytes=_tensor_nbytes(t), + data_offsets=(0, _tensor_nbytes(t)), + filename="", + fst_dtype=None, + strides=tuple(t.stride()), + ) + + def get_tensor( + self, + key: str, + *, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + allow_gds: Optional[bool] = None, + pin_if_cpu: bool = False, + ) -> torch.Tensor: + requested_dtype = dtype if dtype is not None else self._dtype_override + if key in self._overrides: + t = self._overrides[key] + if device is not None and t.device != device: + t = t.to(device=device) + if requested_dtype is not None and t.dtype != requested_dtype: + t = t.to(dtype=requested_dtype) + return t + if key in self._deleted: + raise KeyError(key) + if key not in self._allowed_keys: + meta = self._get_meta(key) + target_dtype = requested_dtype or meta.dtype + return _meta_tensor(meta, dtype_override=target_dtype) + if hasattr(self._base, "get_tensor"): + return self._base.get_tensor( + key, + device=self._device if device is None else device, + dtype=requested_dtype, + allow_gds=self._allow_gds if allow_gds is None else allow_gds, + pin_if_cpu=self._pin_if_cpu if not pin_if_cpu else pin_if_cpu, + ) + t = self._base[key] + if device is not None and t.device != device: + t = t.to(device=device) + if requested_dtype is not None and t.dtype != requested_dtype: + t = t.to(dtype=requested_dtype) + return t + + def __getitem__(self, key: str) -> torch.Tensor: + return self.get_tensor(key) + + def __setitem__(self, key: str, value: torch.Tensor) -> None: + self._overrides[key] = value + self._deleted.discard(key) + + def __delitem__(self, key: str) -> None: + if key in self._overrides: + del self._overrides[key] + return + if key in self._deleted: + raise KeyError(key) + self._deleted.add(key) + + def __iter__(self): + for k in self._base.keys(): + if k in self._deleted: + continue + yield k + for k in self._overrides.keys(): + if k not in self._deleted: + yield k + + def __len__(self) -> int: + base_keys = list(self._base.keys()) + return len(base_keys) - len(self._deleted) + len(self._overrides) + + def pop(self, key: str, default: object = _MISSING) -> torch.Tensor: + 
if key in self._overrides: + return self._overrides.pop(key) + if key in self._deleted: + if default is _MISSING: + raise KeyError(key) + return default + if key not in self._base: + if default is _MISSING: + raise KeyError(key) + return default + self._deleted.add(key) + return self.get_tensor(key) + + def meta(self, key: str): + return self._get_meta(key) + +def _has_custom_load(module: torch.nn.Module) -> bool: + return module.__class__._load_from_state_dict is not BASE_LOAD_FROM_STATE_DICT + + +def register_lazy_modules(model: torch.nn.Module, state_dict): + if not hasattr(state_dict, "keys"): + return + for name, module in model.named_modules(): + if not _has_custom_load(module): + continue + prefix = f"{name}." if name else "" + if prefix: + has_key = False + for param_name in module._parameters.keys(): + if f"{prefix}{param_name}" in state_dict: + has_key = True + break + if not has_key: + for buf_name in module._buffers.keys(): + if f"{prefix}{buf_name}" in state_dict: + has_key = True + break + if not has_key: + continue + view = safetensors_stream.FilterViewStateDict( + state_dict, lambda k, p=prefix: k.startswith(p), mutate_base=False + ) + LAZY_MODULE_STATE[module] = LazyModuleState(state_dict=view, prefix=prefix) + + +def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool): + lazy_state = LAZY_MODULE_STATE.get(module) + if lazy_state is not None: + CACHE.remove_module(module) + refs = REGISTRY.get(module) + if refs: + state = _get_materialization_state(module) + for ref_name, disk_ref in refs.items(): + shape = getattr(disk_ref.meta, "shape", None) + dtype = _get_future_dtype(module, ref_name) or getattr(disk_ref.meta, "dtype", None) + if shape is None or dtype is None: + continue + meta_tensor = torch.empty(shape, dtype=dtype, device="meta") + if disk_ref.is_buffer: + module._buffers[ref_name] = meta_tensor + else: + module._parameters[ref_name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad) + nbytes = _meta_nbytes(disk_ref.meta) + if nbytes is not None: + state.loaded_keys.discard(ref_name) + if ref_name not in state.deferred_keys: + state.deferred_keys.add(ref_name) + state.deferred_bytes += nbytes + state.loaded_bytes = max(0, state.loaded_bytes - nbytes) + _update_disk_state_attrs(module, state) + lazy_state.loaded = False + return + ref = REGISTRY.get(module) + if not ref or name not in ref: + return + disk_ref = ref[name] + shape = getattr(disk_ref.meta, "shape", None) + dtype = _get_future_dtype(module, name) or getattr(disk_ref.meta, "dtype", None) + if shape is None or dtype is None: + return + meta_tensor = torch.empty(shape, dtype=dtype, device="meta") + if is_buffer: + module._buffers[name] = meta_tensor + else: + module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad) + state = _get_materialization_state(module) + nbytes = _meta_nbytes(disk_ref.meta) + if nbytes is not None: + state.loaded_keys.discard(name) + if name not in state.deferred_keys: + state.deferred_keys.add(name) + state.deferred_bytes += nbytes + state.loaded_bytes = max(0, state.loaded_bytes - nbytes) + _update_disk_state_attrs(module, state) + + +def _find_tensor_device(args, kwargs) -> Optional[torch.device]: + def check(obj): + if torch.is_tensor(obj): + return obj.device + if isinstance(obj, (list, tuple)): + for item in obj: + dev = check(item) + if dev is not None: + return dev + if isinstance(obj, dict): + for item in obj.values(): + dev = check(item) + if dev is not None: + return dev + return None + + 
dev = check(args) + if dev is not None: + return dev + return check(kwargs) + + +def _find_tensor_dtype(args, kwargs) -> Optional[torch.dtype]: + def check(obj): + if torch.is_tensor(obj): + return obj.dtype + if isinstance(obj, (list, tuple)): + for item in obj: + dtype = check(item) + if dtype is not None: + return dtype + if isinstance(obj, dict): + for item in obj.values(): + dtype = check(item) + if dtype is not None: + return dtype + return None + + dtype = check(args) + if dtype is not None: + return dtype + return check(kwargs) + + +def _select_weight_dtype(input_dtype: Optional[torch.dtype], manual_cast_dtype: Optional[torch.dtype]) -> Optional[torch.dtype]: + if manual_cast_dtype is not None: + return manual_cast_dtype + if input_dtype is None: + return None + if torch.is_floating_point(torch.empty((), dtype=input_dtype)): + return input_dtype + return None + + +def ensure_module_materialized( + module: torch.nn.Module, + target_device: torch.device, + fallback_device: Optional[torch.device] = None, + dtype_override: Optional[torch.dtype] = None, +): + lazy_state = LAZY_MODULE_STATE.get(module) + if lazy_state is not None: + _materialize_module_from_state_dict( + module, + lazy_state, + target_device, + dtype_override=dtype_override, + ) + return + refs = REGISTRY.get(module) + if not refs: + return + state = _get_materialization_state(module) + if dtype_override is not None: + for name in refs.keys(): + _set_future_dtype(module, name, dtype_override) + _rebuild_materialization_state(module, refs, state) + free_mem_start = _device_free_memory(target_device) + remaining_budget = free_mem_start + for name in sorted(refs.keys()): + disk_ref = refs[name] + if name in module._parameters: + current = module._parameters[name] + is_buffer = False + elif name in module._buffers: + current = module._buffers[name] + is_buffer = True + else: + continue + if current is None: + continue + target_dtype = dtype_override or _get_future_dtype(module, name) + if current.device.type != "meta" and current.device == target_device and ( + target_dtype is None or current.dtype == target_dtype + ): + if current.device.type == "cpu": + CACHE.touch(module, name) + continue + meta_nbytes = _meta_nbytes(disk_ref.meta) + if meta_nbytes is None: + continue + required_bytes = meta_nbytes + if target_device.type == "cpu": + free_mem = _maybe_free_ram_budget(target_device, required_bytes) + remaining_budget = min(remaining_budget, free_mem) + if required_bytes > remaining_budget: + if fallback_device is not None and fallback_device != target_device: + fallback_free = _maybe_free_ram_budget(fallback_device, required_bytes) + if fallback_free >= required_bytes: + target_for_load = fallback_device + else: + continue + else: + continue + else: + target_for_load = target_device + if current.device.type == "meta": + tensor = disk_ref.load( + target_for_load, + ALLOW_GDS, + PIN_IF_CPU, + dtype_override=target_dtype, + ) + else: + if target_dtype is not None and current.dtype != target_dtype: + tensor = current.to(device=target_for_load, dtype=target_dtype) + else: + tensor = current.to(device=target_for_load) + if is_buffer: + module._buffers[name] = tensor + else: + module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad) + if tensor.device.type == "cpu": + CACHE.record(module, name, tensor, is_buffer=is_buffer) + remaining_budget = max(0, remaining_budget - required_bytes) + _rebuild_materialization_state(module, refs, state) + _log_materialization(module, target_device, 
free_mem_start, refs, state, "Disk weight materialized") + + +def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}): + if not REGISTRY.has(module) and module not in LAZY_MODULE_STATE: + return + input_dtype = _find_tensor_dtype(args, kwargs) + manual_cast_dtype = getattr(module, "manual_cast_dtype", None) + dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype) + if getattr(module, "comfy_cast_weights", False): + target_device = torch.device("cpu") + fallback_device = _find_tensor_device(args, kwargs) + else: + target_device = _find_tensor_device(args, kwargs) or torch.device("cpu") + fallback_device = None + ensure_module_materialized( + module, + target_device, + fallback_device=fallback_device, + dtype_override=dtype_override, + ) + + +def attach_disk_weight_hooks(model: torch.nn.Module): + if not disk_weights_enabled(): + return + for module in model.modules(): + if getattr(module, "_disk_weight_hook_attached", False): + continue + module.register_forward_pre_hook(disk_weight_pre_hook) + module._disk_weight_hook_attached = True + + +def evict_ram_cache(bytes_to_free: int): + if bytes_to_free <= 0: + return 0 + return CACHE.evict_bytes(bytes_to_free) + + +def materialize_module_tree(module: torch.nn.Module, target_device: torch.device): + if not disk_weights_enabled(): + return + for submodule in module.modules(): + ensure_module_materialized(submodule, target_device) + + +def _extract_to_device(args, kwargs) -> Optional[torch.device]: + if "device" in kwargs and kwargs["device"] is not None: + return torch.device(kwargs["device"]) + for arg in args: + if isinstance(arg, torch.device): + return arg + if isinstance(arg, str): + return torch.device(arg) + return None + + +def _extract_to_dtype(args, kwargs) -> Optional[torch.dtype]: + if "dtype" in kwargs and kwargs["dtype"] is not None: + return kwargs["dtype"] + for arg in args: + if isinstance(arg, torch.dtype): + return arg + return None + + +def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]: + for param in module.parameters(recurse=True): + if param is not None and param.device.type != "meta": + return param.device + for buf in module.buffers(recurse=True): + if buf is not None and buf.device.type != "meta": + return buf.device + return None + + +def move_module_tensors(module: torch.nn.Module, device_to: torch.device, dtype_override: Optional[torch.dtype] = None): + def _move(tensor): + if tensor is None: + return None + if tensor.device.type == "meta": + return tensor + if dtype_override is not None and tensor.dtype != dtype_override: + return tensor.to(device=device_to, dtype=dtype_override) + return tensor.to(device=device_to) + + module._apply(_move) + return module + + +def offload_module_weights(module: torch.nn.Module) -> int: + if not disk_weights_enabled(): + return 0 + refs = REGISTRY.get(module) + if not refs: + return 0 + offloaded_bytes = 0 + if module in LAZY_MODULE_STATE: + ref_name = next(iter(refs.keys()), None) + if ref_name is not None: + _evict_module_weight(module, ref_name, False) + for disk_ref in refs.values(): + nbytes = _meta_nbytes(disk_ref.meta) + if nbytes is not None: + offloaded_bytes += nbytes + return offloaded_bytes + for name, disk_ref in refs.items(): + _evict_module_weight(module, name, disk_ref.is_buffer) + nbytes = _meta_nbytes(disk_ref.meta) + if nbytes is not None: + offloaded_bytes += nbytes + return offloaded_bytes + + +def module_to(module: torch.nn.Module, *args, **kwargs): + allow_materialize = kwargs.pop("allow_materialize", True) + 
if disk_weights_enabled(): + target_device = _extract_to_device(args, kwargs) + if target_device is None: + target_device = _find_existing_device(module) or torch.device("cpu") + if target_device.type == "meta": + offload_module_weights(module) + return module + if allow_materialize: + materialize_module_tree(module, target_device) + return module.to(*args, **kwargs) + dtype_override = _extract_to_dtype(args, kwargs) + return move_module_tensors(module, target_device, dtype_override=dtype_override) + return module.to(*args, **kwargs) + + +def load_module_tensor( + module: torch.nn.Module, + name: str, + device: torch.device, + *, + allow_alternate: bool = True, + record_cache: bool = True, + temporary: bool = False, + dtype_override: Optional[torch.dtype] = None, +) -> Optional[torch.Tensor]: + refs = REGISTRY.get(module) + if not refs or name not in refs: + return None + if name in module._parameters: + current = module._parameters[name] + is_buffer = False + elif name in module._buffers: + current = module._buffers[name] + is_buffer = True + else: + return None + if current is None: + return None + target_dtype = dtype_override or _get_future_dtype(module, name) + if dtype_override is not None: + _set_future_dtype(module, name, dtype_override) + if current.device.type != "meta": + if current.device != device or (target_dtype is not None and current.dtype != target_dtype): + if target_dtype is not None and current.dtype != target_dtype: + tensor = current.to(device=device, dtype=target_dtype) + else: + tensor = current.to(device=device) + if not temporary: + if is_buffer: + module._buffers[name] = tensor + else: + module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=refs[name].requires_grad) + _rebuild_materialization_state(module, refs, _get_materialization_state(module)) + return tensor + return current + + disk_ref = refs[name] + required_bytes = _meta_nbytes(disk_ref.meta) + if required_bytes is None: + return current + free_mem_start = _device_free_memory(device) + free_mem = _maybe_free_ram_budget(device, required_bytes) + load_device = device + if free_mem < required_bytes and allow_alternate: + alt = _choose_alternate_device(device) + if alt is not None: + alt_free = _maybe_free_ram_budget(alt, required_bytes) + if alt_free >= required_bytes: + load_device = alt + else: + state = _get_materialization_state(module) + if name not in state.deferred_keys: + state.deferred_keys.add(name) + state.deferred_bytes += required_bytes + _update_disk_state_attrs(module, state) + _log_materialization(module, device, free_mem_start, refs, _get_materialization_state(module), "Disk weight deferred") + return current + else: + state = _get_materialization_state(module) + if name not in state.deferred_keys: + state.deferred_keys.add(name) + state.deferred_bytes += required_bytes + _update_disk_state_attrs(module, state) + _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred") + return current + elif free_mem < required_bytes: + state = _get_materialization_state(module) + if name not in state.deferred_keys: + state.deferred_keys.add(name) + state.deferred_bytes += required_bytes + _update_disk_state_attrs(module, state) + _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred") + return current + + tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype) + if temporary: + return tensor + if is_buffer: + module._buffers[name] = tensor + else: + module._parameters[name] = 
torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad) + if tensor.device.type == "cpu" and record_cache: + CACHE.record(module, name, tensor, is_buffer=is_buffer) + state = _get_materialization_state(module) + _rebuild_materialization_state(module, refs, state) + _log_materialization(module, load_device, free_mem_start, refs, state, "Disk weight loaded") + return tensor + + +def _replace_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool, requires_grad: bool): + parts = name.split(".") + module = model + for part in parts[:-1]: + module = getattr(module, part) + attr = parts[-1] + if is_buffer: + module._buffers[attr] = tensor + else: + module._parameters[attr] = torch.nn.Parameter(tensor, requires_grad=requires_grad) + + +def _materialize_module_from_state_dict( + module: torch.nn.Module, + lazy_state: LazyModuleState, + target_device: torch.device, + dtype_override: Optional[torch.dtype] = None, +): + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + metadata = getattr(lazy_state.state_dict, "_metadata", None) + local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {}) + refs = REGISTRY.get(module) or {} + if dtype_override is not None: + for name in refs.keys(): + _set_future_dtype(module, name, dtype_override) + state = _get_materialization_state(module) + _rebuild_materialization_state(module, refs, state) + keys = sorted(lazy_state.state_dict.keys()) + existing = {} + for name, param in module.named_parameters(recurse=False): + key = f"{lazy_state.prefix}{name}" + if key in lazy_state.state_dict and param is not None and param.device.type != "meta": + existing[key] = param + for name, buf in module.named_buffers(recurse=False): + key = f"{lazy_state.prefix}{name}" + if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta": + existing[key] = buf + free_mem_start = _device_free_memory(target_device) + remaining_budget = free_mem_start + allowed = set(existing.keys()) + for key in keys: + if key in allowed: + continue + meta = _state_dict_meta(lazy_state.state_dict, key) + required = _meta_nbytes(meta) + if required is None: + continue + if target_device.type == "cpu": + free_mem = _maybe_free_ram_budget(target_device, required) + remaining_budget = min(remaining_budget, free_mem) + if required <= remaining_budget: + allowed.add(key) + remaining_budget = max(0, remaining_budget - required) + deferred_state_dict_keys = {key for key in keys if key not in allowed} + state_dict = _BudgetedStateDict( + lazy_state.state_dict, + allowed_keys=allowed, + device=target_device, + allow_gds=ALLOW_GDS, + pin_if_cpu=PIN_IF_CPU, + dtype_override=dtype_override, + overrides=existing, + ) + factory_device = None + if hasattr(module, "factory_kwargs") and "device" in module.factory_kwargs: + factory_device = module.factory_kwargs["device"] + module.factory_kwargs["device"] = target_device + try: + module._load_from_state_dict( + state_dict, + lazy_state.prefix, + local_metadata, + False, + missing_keys, + unexpected_keys, + error_msgs, + ) + incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys) + for hook in module._load_state_dict_post_hooks.values(): + out = hook(module, incompatible) + if out is not None: + raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.") + finally: + if factory_device is not None: + module.factory_kwargs["device"] = factory_device + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict 
for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs))) + _rebuild_materialization_state(module, refs, state) + lazy_state.loaded = len(deferred_state_dict_keys) == 0 + _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed") + for name, param in module.named_parameters(recurse=False): + if param.device.type == "cpu": + CACHE.record(module, name, param, is_buffer=False) + for name, buf in module.named_buffers(recurse=False): + if buf is not None and buf.device.type == "cpu": + CACHE.record(module, name, buf, is_buffer=True) + + +def lazy_load_state_dict(model: torch.nn.Module, state_dict, strict: bool = False): + model_keys = set() + for name, _ in model.named_parameters(recurse=True): + model_keys.add(name) + for name, _ in model.named_buffers(recurse=True): + model_keys.add(name) + + state_keys = set(state_dict.keys()) + missing_keys = [k for k in model_keys if k not in state_keys] + unexpected_keys = [k for k in state_keys if k not in model_keys] + + if strict: + error_msgs = [] + if len(unexpected_keys) > 0: + error_msgs.append('Unexpected key(s) in state_dict: {}.'.format(', '.join(f'"{k}"' for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.append('Missing key(s) in state_dict: {}.'.format(', '.join(f'"{k}"' for k in missing_keys))) + if error_msgs: + raise RuntimeError("Error(s) in loading state_dict:\n\t{}".format("\n\t".join(error_msgs))) + + for name, param in model.named_parameters(recurse=True): + if name not in state_keys: + continue + meta = state_dict.meta(name) + meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta") + _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad) + + for name, buf in model.named_buffers(recurse=True): + if buf is None or name not in state_keys: + continue + meta = state_dict.meta(name) + meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta") + _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False) + + register_module_weights(model, state_dict) + register_lazy_modules(model, state_dict) + attach_disk_weight_hooks(model) + return missing_keys, unexpected_keys diff --git a/comfy/gligen.py b/comfy/gligen.py index 1d7b6c2f4..c2cf7c6db 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -3,6 +3,7 @@ import torch from torch import nn from .ldm.modules.attention import CrossAttention, FeedForward import comfy.ops +import comfy.utils ops = comfy.ops.manual_cast @@ -282,7 +283,7 @@ def load_gligen(sd): gated = GatedSelfAttentionDense( query_dim, key_dim, n_heads, d_head) - gated.load_state_dict(n_sd, strict=False) + comfy.utils.load_state_dict(gated, n_sd, strict=False) output_list.append(gated) if "position_net.null_positive_feature" in sd_k: @@ -293,7 +294,7 @@ def load_gligen(sd): pass w = WeightsLoader() w.position_net = PositionNet(in_dim, out_dim) - w.load_state_dict(sd, strict=False) + comfy.utils.load_state_dict(w, sd, strict=False) gligen = Gligen(output_list, w.position_net, key_dim) return gligen diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 51b6d1da8..19f0dafb4 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -1,4 +1,5 @@ import torch +import comfy.utils import torch.nn as nn import torch.nn.functional as F from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d @@ -112,7 +113,7 @@ class HunyuanVideo15SRModel(): self.patcher = 
comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=True) + return comfy.utils.load_state_dict(self.model, sd, strict=True) def get_sd(self): return self.model.state_dict() diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index a9111d3bd..c61883ba3 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -2,6 +2,7 @@ import json from dataclasses import dataclass import math import torch +import comfy.utils import torchaudio import comfy.model_management @@ -153,8 +154,8 @@ class AudioVAE(torch.nn.Module): self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) self.vocoder = Vocoder(config=component_config.vocoder) - self.autoencoder.load_state_dict(vae_sd, strict=False) - self.vocoder.load_state_dict(vocoder_sd, strict=False) + comfy.utils.load_state_dict(self.autoencoder, vae_sd, strict=False) + comfy.utils.load_state_dict(self.vocoder, vocoder_sd, strict=False) autoencoder_config = self.autoencoder.get_config() self.normalizer = AudioLatentNormalizer( diff --git a/comfy/ldm/mmaudio/vae/vae.py b/comfy/ldm/mmaudio/vae/vae.py index 62f24606c..1009d2d76 100644 --- a/comfy/ldm/mmaudio/vae/vae.py +++ b/comfy/ldm/mmaudio/vae/vae.py @@ -2,6 +2,7 @@ import logging from typing import Optional import torch +import comfy.utils import torch.nn as nn from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D, @@ -152,7 +153,7 @@ class VAE(nn.Module): return dec, posterior def load_weights(self, src_dict) -> None: - self.load_state_dict(src_dict, strict=True) + comfy.utils.load_state_dict(self, src_dict, strict=True) @property def device(self) -> torch.device: @@ -355,4 +356,3 @@ def get_my_vae(name: str, **kwargs) -> VAE: if name == '44k': return VAE_44k(**kwargs) raise ValueError(f'Unknown model: {name}') - diff --git a/comfy/model_base.py b/comfy/model_base.py index 49efd700b..84766591b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -56,6 +56,7 @@ import comfy.conds import comfy.ops from enum import Enum from . import utils +from . 
import safetensors_stream import comfy.latent_formats import comfy.model_sampling import math @@ -299,20 +300,21 @@ class BaseModel(torch.nn.Module): return out def load_model_weights(self, sd, unet_prefix=""): - to_load = {} - keys = list(sd.keys()) - for k in keys: - if k.startswith(unet_prefix): - to_load[k[len(unet_prefix):]] = sd.pop(k) - + replace_prefix = {unet_prefix: ""} if unet_prefix else {} + if replace_prefix: + if utils.is_stream_state_dict(sd): + to_load = utils.state_dict_prefix_replace(sd, replace_prefix, filter_keys=True) + else: + to_load = safetensors_stream.RenameViewStateDict(sd, replace_prefix, filter_keys=True, mutate_base=False) + else: + to_load = sd to_load = self.model_config.process_unet_state_dict(to_load) - m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False) if len(m) > 0: logging.warning("unet missing: {}".format(m)) if len(u) > 0: logging.warning("unet unexpected: {}".format(u)) - del to_load return self def process_latent_in(self, latent): @@ -751,8 +753,8 @@ class StableAudio1(BaseModel): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer) self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512) self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512) - self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights) - self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights) + utils.load_state_dict(self.seconds_start_embedder, seconds_start_embedder_weights, strict=True) + utils.load_state_dict(self.seconds_total_embedder, seconds_total_embedder_weights, strict=True) def extra_conds(self, **kwargs): out = {} diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 0853b3aec..0b3a05659 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -19,6 +19,9 @@ def count_blocks(state_dict_keys, prefix_string): count += 1 return count +def sd_shape(state_dict, key): + return comfy.utils.state_dict_meta(state_dict, key).shape + def calculate_transformer_depth(prefix, state_dict_keys, state_dict): context_dim = None use_linear_in_transformer = False @@ -27,8 +30,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict): transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys))) if len(transformer_keys) > 0: last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}') - context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1] - use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2 + context_dim = sd_shape(state_dict, '{}0.attn2.to_k.weight'.format(transformer_prefix))[1] + use_linear_in_transformer = len(sd_shape(state_dict, '{}1.proj_in.weight'.format(prefix))) == 2 time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict time_stack_cross = '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross @@ -39,27 +42,27 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) 
in state_dict_keys: #mmdit model unet_config = {} - unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1] - patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2] + unet_config["in_channels"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[1] + patch_size = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[2] unet_config["patch_size"] = patch_size final_layer = '{}final_layer.linear.weight'.format(key_prefix) if final_layer in state_dict: - unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size) + unet_config["out_channels"] = sd_shape(state_dict, final_layer)[0] // (patch_size * patch_size) - unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64 + unet_config["depth"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0] // 64 unet_config["input_size"] = None y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix) if y_key in state_dict_keys: - unet_config["adm_in_channels"] = state_dict[y_key].shape[1] + unet_config["adm_in_channels"] = sd_shape(state_dict, y_key)[1] context_key = '{}context_embedder.weight'.format(key_prefix) if context_key in state_dict_keys: - in_features = state_dict[context_key].shape[1] - out_features = state_dict[context_key].shape[0] + in_features = sd_shape(state_dict, context_key)[1] + out_features = sd_shape(state_dict, context_key)[0] unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}} num_patches_key = '{}pos_embed'.format(key_prefix) if num_patches_key in state_dict_keys: - num_patches = state_dict[num_patches_key].shape[1] + num_patches = sd_shape(state_dict, num_patches_key)[1] unet_config["num_patches"] = num_patches unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches)) @@ -83,23 +86,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix) if text_mapper_name in state_dict_keys: unet_config['stable_cascade_stage'] = 'c' - w = state_dict[text_mapper_name] - if w.shape[0] == 1536: #stage c lite + w_shape = sd_shape(state_dict, text_mapper_name) + if w_shape[0] == 1536: #stage c lite unet_config['c_cond'] = 1536 unet_config['c_hidden'] = [1536, 1536] unet_config['nhead'] = [24, 24] unet_config['blocks'] = [[4, 12], [12, 4]] - elif w.shape[0] == 2048: #stage c full + elif w_shape[0] == 2048: #stage c full unet_config['c_cond'] = 2048 elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys: unet_config['stable_cascade_stage'] = 'b' - w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)] - if w.shape[-1] == 640: + w_shape = sd_shape(state_dict, '{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)) + if w_shape[-1] == 640: unet_config['c_hidden'] = [320, 640, 1280, 1280] unet_config['nhead'] = [-1, -1, 20, 20] unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]] unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]] - elif w.shape[-1] == 576: #stage b lite + elif w_shape[-1] == 576: #stage b lite unet_config['c_hidden'] = [320, 576, 1152, 1152] unet_config['nhead'] = [-1, 9, 18, 18] unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]] @@ -113,8 +116,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit unet_config = {} - 
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1] - unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1] + unet_config["max_seq"] = sd_shape(state_dict, '{}positional_encoding'.format(key_prefix))[1] + unet_config["cond_seq_dim"] = sd_shape(state_dict, '{}cond_seq_linear.weight'.format(key_prefix))[1] double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.') single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.') unet_config["n_double_layers"] = double_layers @@ -125,10 +128,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config = {} unet_config["image_model"] = "hydit" unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') - unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] + unet_config["hidden_size"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0] if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2 unet_config["mlp_ratio"] = 4.3637 - if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968: + if sd_shape(state_dict, '{}extra_embedder.0.weight'.format(key_prefix))[1] == 3968: unet_config["size_cond"] = True unet_config["use_style_cond"] = True unet_config["image_model"] = "hydit1" @@ -136,12 +139,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video dit_config = {} - in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)] - out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)] + in_w_shape = sd_shape(state_dict, '{}img_in.proj.weight'.format(key_prefix)) + out_w_shape = sd_shape(state_dict, '{}final_layer.linear.weight'.format(key_prefix)) dit_config["image_model"] = "hunyuan_video" - dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels - dit_config["patch_size"] = list(in_w.shape[2:]) - dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"]) + dit_config["in_channels"] = in_w_shape[1] #SkyReels img2video has 32 input channels + dit_config["patch_size"] = list(in_w_shape[2:]) + dit_config["out_channels"] = out_w_shape[0] // math.prod(dit_config["patch_size"]) if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys): dit_config["vec_in_dim"] = 768 else: @@ -157,10 +160,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): else: dit_config["meanflow"] = False - dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1] - dit_config["hidden_size"] = in_w.shape[0] + dit_config["context_in_dim"] = sd_shape(state_dict, '{}txt_in.input_embedder.weight'.format(key_prefix))[1] + dit_config["hidden_size"] = in_w_shape[0] dit_config["mlp_ratio"] = 4.0 - dit_config["num_heads"] = in_w.shape[0] // 128 + dit_config["num_heads"] = in_w_shape[0] // 128 dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') dit_config["theta"] = 256 @@ -179,7 +182,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): else: dit_config["use_cond_type_embedding"] = False if '{}vision_in.proj.0.weight'.format(key_prefix) in 
state_dict_keys: - dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0] + dit_config["vision_in_dim"] = sd_shape(state_dict, '{}vision_in.proj.0.weight'.format(key_prefix))[0] dit_config["meanflow_sum"] = True else: dit_config["vision_in_dim"] = None @@ -221,19 +224,19 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["patch_size"] = patch_size in_key = "{}img_in.weight".format(key_prefix) if in_key in state_dict_keys: - w = state_dict[in_key] - dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size) - dit_config["hidden_size"] = w.shape[0] + w_shape = sd_shape(state_dict, in_key) + dit_config["in_channels"] = w_shape[1] // (patch_size * patch_size) + dit_config["hidden_size"] = w_shape[0] txt_in_key = "{}txt_in.weight".format(key_prefix) if txt_in_key in state_dict_keys: - w = state_dict[txt_in_key] - dit_config["context_in_dim"] = w.shape[1] - dit_config["hidden_size"] = w.shape[0] + w_shape = sd_shape(state_dict, txt_in_key) + dit_config["context_in_dim"] = w_shape[1] + dit_config["hidden_size"] = w_shape[0] vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix) if vec_in_key in state_dict_keys: - dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1] + dit_config["vec_in_dim"] = sd_shape(state_dict, vec_in_key)[1] else: dit_config["vec_in_dim"] = None @@ -307,7 +310,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config = {} dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv" dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') - shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape + shape = sd_shape(state_dict, '{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)) dit_config["attention_head_dim"] = shape[0] // 32 dit_config["cross_attention_dim"] = shape[1] if metadata is not None and "config" in metadata: @@ -350,11 +353,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): y_key = "{}y_embedder.y_embedding".format(key_prefix) if y_key in state_dict_keys: - dit_config["model_max_length"] = state_dict[y_key].shape[0] + dit_config["model_max_length"] = sd_shape(state_dict, y_key)[0] pe_key = "{}pos_embed".format(key_prefix) if pe_key in state_dict_keys: - dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size + dit_config["input_size"] = int(math.sqrt(sd_shape(state_dict, pe_key)[1])) * patch_size dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix) @@ -373,11 +376,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["max_img_w"] = 240 dit_config["max_frames"] = 128 concat_padding_mask = True - dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask) + dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask) dit_config["out_channels"] = 16 dit_config["patch_spatial"] = 2 dit_config["patch_temporal"] = 1 - dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0] + dit_config["model_channels"] = sd_shape(state_dict, '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix))[0] dit_config["block_config"] = "FA-CA-MLP" 
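Aside (not part of the patch): the detection refactor above only works if `comfy.utils.state_dict_meta` can resolve shapes without reading tensor data. A minimal sketch of that contract, assuming the `.meta()` protocol exposed by `StreamStateDict` and its view wrappers later in this diff; the helper name `state_dict_meta_sketch` is hypothetical.

```python
from types import SimpleNamespace

def state_dict_meta_sketch(sd, key):
    # Streaming mappings expose .meta(key) -> TensorMeta (dtype/shape/nbytes)
    # parsed from the safetensors header, so no tensor bytes are read.
    if hasattr(sd, "meta"):
        return sd.meta(key)
    # Plain dicts already hold materialized tensors; report the same fields.
    t = sd[key]
    return SimpleNamespace(dtype=t.dtype, shape=tuple(t.shape),
                           numel=t.numel(), nbytes=t.numel() * t.element_size())
```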
dit_config["concat_padding_mask"] = concat_padding_mask dit_config["pos_emb_cls"] = "rope3d" @@ -416,9 +419,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "lumina2" dit_config["patch_size"] = 2 dit_config["in_channels"] = 16 - w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)] - dit_config["dim"] = w.shape[0] - dit_config["cap_feat_dim"] = w.shape[1] + w_shape = sd_shape(state_dict, '{}cap_embedder.1.weight'.format(key_prefix)) + dit_config["dim"] = w_shape[0] + dit_config["cap_feat_dim"] = w_shape[1] dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.') dit_config["qk_norm"] = True @@ -429,9 +432,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["axes_lens"] = [300, 512, 512] dit_config["rope_theta"] = 10000.0 dit_config["ffn_dim_multiplier"] = 4.0 - ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None) - if ctd_weight is not None: # NewBie - dit_config["clip_text_dim"] = ctd_weight.shape[0] + ctd_key = '{}clip_text_pooled_proj.0.weight'.format(key_prefix) + if ctd_key in state_dict_keys: # NewBie + dit_config["clip_text_dim"] = sd_shape(state_dict, ctd_key)[0] # NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI elif dit_config["dim"] == 3840: # Z image dit_config["n_heads"] = 30 @@ -450,12 +453,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 dit_config = {} dit_config["image_model"] = "wan2.1" - dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1] - out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4 + dim = sd_shape(state_dict, '{}head.modulation'.format(key_prefix))[-1] + out_dim = sd_shape(state_dict, '{}head.head.weight'.format(key_prefix))[0] // 4 dit_config["dim"] = dim dit_config["out_dim"] = out_dim dit_config["num_heads"] = dim // 128 - dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0] + dit_config["ffn_dim"] = sd_shape(state_dict, '{}blocks.0.ffn.0.weight'.format(key_prefix))[0] dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') dit_config["patch_size"] = (1, 2, 2) dit_config["freq_dim"] = 256 @@ -463,10 +466,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["qk_norm"] = True dit_config["cross_attn_norm"] = True dit_config["eps"] = 1e-6 - dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1] + dit_config["in_dim"] = sd_shape(state_dict, '{}patch_embedding.weight'.format(key_prefix))[1] if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "vace" - dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1] + dit_config["vace_in_dim"] = sd_shape(state_dict, '{}vace_patch_embedding.weight'.format(key_prefix))[1] dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.') elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: @@ -484,22 +487,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "i2v" else: dit_config["model_type"] = "t2v" - flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) - if flf_weight is not None: - 
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] + flf_key = '{}img_emb.emb_pos'.format(key_prefix) + if flf_key in state_dict_keys: + dit_config["flf_pos_embed_token_number"] = sd_shape(state_dict, flf_key)[1] - ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix)) - if ref_conv_weight is not None: - dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1] + ref_conv_key = '{}ref_conv.weight'.format(key_prefix) + if ref_conv_key in state_dict_keys: + dit_config["in_dim_ref_conv"] = sd_shape(state_dict, ref_conv_key)[1] return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D - in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape + in_shape = sd_shape(state_dict, '{}latent_in.weight'.format(key_prefix)) dit_config = {} dit_config["image_model"] = "hunyuan3d2" dit_config["in_channels"] = in_shape[1] - dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1] + dit_config["context_in_dim"] = sd_shape(state_dict, '{}cond_in.weight'.format(key_prefix))[1] dit_config["hidden_size"] = in_shape[0] dit_config["mlp_ratio"] = 4.0 dit_config["num_heads"] = 16 @@ -513,9 +516,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config = {} dit_config["image_model"] = "hunyuan3d2_1" - dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1] + dit_config["in_channels"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[1] dit_config["context_dim"] = 1024 - dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0] + dit_config["hidden_size"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[0] dit_config["mlp_ratio"] = 4.0 dit_config["num_heads"] = 16 dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}") @@ -549,11 +552,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["max_img_w"] = 240 dit_config["max_frames"] = 128 concat_padding_mask = True - dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask) + dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask) dit_config["out_channels"] = 16 dit_config["patch_spatial"] = 2 dit_config["patch_temporal"] = 1 - dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0] + dit_config["model_channels"] = sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[0] dit_config["concat_padding_mask"] = concat_padding_mask dit_config["crossattn_emb_channels"] = 1024 dit_config["pos_emb_cls"] = "rope3d" @@ -617,7 +620,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image dit_config = {} dit_config["image_model"] = "qwen_image" - dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1] + dit_config["in_channels"] = sd_shape(state_dict, '{}img_in.weight'.format(key_prefix))[1] dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511 dit_config["default_ref_method"] = "index_timestep_zero" @@ -628,7 +631,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # 
Kandinsky 5 dit_config = {} - model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0] + model_dim = sd_shape(state_dict, '{}visual_embeddings.in_layer.bias'.format(key_prefix))[0] dit_config["model_dim"] = model_dim if model_dim in [4096, 2560]: # pro video and lite image dit_config["axes_dims"] = (32, 48, 48) @@ -636,10 +639,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0) elif model_dim == 1792: # lite video dit_config["axes_dims"] = (16, 24, 24) - dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0] + dit_config["time_dim"] = sd_shape(state_dict, '{}time_embeddings.in_layer.bias'.format(key_prefix))[0] dit_config["image_model"] = "kandinsky5" - dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0] - dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1] + dit_config["ff_dim"] = sd_shape(state_dict, '{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix))[0] + dit_config["visual_embed_dim"] = sd_shape(state_dict, '{}visual_embeddings.in_layer.weight'.format(key_prefix))[1] dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.') dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.') return dit_config @@ -657,16 +660,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): y_input = '{}label_emb.0.0.weight'.format(key_prefix) if y_input in state_dict_keys: unet_config["num_classes"] = "sequential" - unet_config["adm_in_channels"] = state_dict[y_input].shape[1] + unet_config["adm_in_channels"] = sd_shape(state_dict, y_input)[1] else: unet_config["adm_in_channels"] = None - model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] - in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] + model_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[0] + in_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[1] out_key = '{}out.2.weight'.format(key_prefix) if out_key in state_dict: - out_channels = state_dict[out_key].shape[0] + out_channels = sd_shape(state_dict, out_key)[0] else: out_channels = 4 @@ -713,7 +716,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): res_block_prefix = "{}0.in_layers.0.weight".format(prefix) if res_block_prefix in block_keys: last_res_blocks += 1 - last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels + last_channel_mult = sd_shape(state_dict, "{}0.out_layers.3.weight".format(prefix))[0] // model_channels out = calculate_transformer_depth(prefix, state_dict_keys, state_dict) if out is not None: @@ -867,7 +870,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}') transformer_depth.append(transformer_count) if transformer_count > 0: - match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1] + match["context_dim"] = sd_shape(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab))[1] attn_res *= 2 if attn_blocks == 0: @@ -876,13 
+879,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): match["transformer_depth"] = transformer_depth - match["model_channels"] = state_dict["conv_in.weight"].shape[0] - match["in_channels"] = state_dict["conv_in.weight"].shape[1] + match["model_channels"] = sd_shape(state_dict, "conv_in.weight")[0] + match["in_channels"] = sd_shape(state_dict, "conv_in.weight")[1] match["adm_in_channels"] = None if "class_embedding.linear_1.weight" in state_dict: - match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1] + match["adm_in_channels"] = sd_shape(state_dict, "class_embedding.linear_1.weight")[1] elif "add_embedding.linear_1.weight" in state_dict: - match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1] + match["adm_in_channels"] = sd_shape(state_dict, "add_embedding.linear_1.weight")[1] SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, @@ -1023,11 +1026,11 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): elif 'x_embedder.weight' in state_dict: #Flux depth = count_blocks(state_dict, 'transformer_blocks.{}.') depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') - hidden_size = state_dict["x_embedder.bias"].shape[0] + hidden_size = sd_shape(state_dict, "x_embedder.bias")[0] sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix) elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3 num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') - depth = state_dict["pos_embed.proj.weight"].shape[0] // 64 + depth = sd_shape(state_dict, "pos_embed.proj.weight")[0] // 64 sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix) else: return None diff --git a/comfy/model_management.py b/comfy/model_management.py index e5de4a5b5..c681bbd69 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,7 @@ import platform import weakref import gc import os +import comfy.disk_weights class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -540,7 +541,12 @@ class LoadedModel: freed = self.model.partially_unload(self.model.offload_device, memory_to_free) if freed >= memory_to_free: return False - self.model.detach(unpatch_weights) + offload_device = None + if comfy.disk_weights.disk_weights_enabled(): + offload_device = torch.device("meta") + self.model.detach(unpatch_weights, offload_device=offload_device) + if offload_device is not None and offload_device.type == "meta": + logging.info(f"Unloaded {self.model.model.__class__.__name__} to disk") self.model_finalizer.detach() self.model_finalizer = None self.real_model = None @@ -594,6 +600,11 @@ def minimum_inference_memory(): def free_memory(memory_required, device, keep_loaded=[]): cleanup_models_gc() + if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled(): + logging.info("RAM pressure: requested %.2f MB, free %.2f MB", memory_required / (1024 * 1024), get_free_memory(device) / (1024 * 1024)) + freed_cache = comfy.disk_weights.evict_ram_cache(memory_required) + if freed_cache < memory_required: + evict_ram_to_disk(memory_required - freed_cache) unloaded_model = [] can_unload = [] 
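Aside (not part of the patch): the RAM-pressure branch added to `free_memory` above composes two eviction steps in a fixed order. A minimal usage sketch assuming the patched `comfy.disk_weights.evict_ram_cache` and `comfy.model_management.evict_ram_to_disk` from this diff; the wrapper name `relieve_ram_pressure_sketch` is hypothetical.

```python
import comfy.disk_weights
import comfy.model_management

def relieve_ram_pressure_sketch(bytes_needed):
    # 1) Drop RAM-cached copies of disk-resident weights first; they are cheap
    #    to discard because they can be re-read from the safetensors file.
    freed = comfy.disk_weights.evict_ram_cache(bytes_needed)
    # 2) If that is not enough, demote loaded model weights to meta tensors
    #    backed by DiskRef entries via evict_ram_to_disk().
    if freed < bytes_needed:
        freed += comfy.model_management.evict_ram_to_disk(bytes_needed - freed)
    return freed
```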
unloaded_models = [] @@ -629,6 +640,34 @@ def free_memory(memory_required, device, keep_loaded=[]): soft_empty_cache() return unloaded_models + +def evict_ram_to_disk(memory_to_free, keep_loaded=[]): + if memory_to_free <= 0: + return 0 + if not comfy.disk_weights.disk_weights_enabled(): + return 0 + + freed = 0 + can_unload = [] + for i in range(len(current_loaded_models) - 1, -1, -1): + shift_model = current_loaded_models[i] + if shift_model not in keep_loaded and not shift_model.is_dead(): + loaded_memory = shift_model.model_loaded_memory() + if loaded_memory > 0: + can_unload.append((-loaded_memory, sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) + + for x in sorted(can_unload): + i = x[-1] + memory_needed = memory_to_free - freed + if memory_needed <= 0: + break + logging.debug(f"Offloading {current_loaded_models[i].model.model.__class__.__name__} to disk") + freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed) + + if freed > 0: + logging.info("RAM evicted to disk: {:.2f} MB freed".format(freed / (1024 * 1024))) + return freed + def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): cleanup_models_gc() global vram_state @@ -1135,6 +1174,16 @@ if not args.disable_pinned_memory: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) +WEIGHTS_RAM_CACHE_BYTES = 0 +WEIGHTS_GDS_ENABLED = bool(args.weights_gds) +if args.weights_ram_cache_gb is not None: + WEIGHTS_RAM_CACHE_BYTES = int(max(0.0, args.weights_ram_cache_gb) * (1024 ** 3)) + comfy.disk_weights.configure( + WEIGHTS_RAM_CACHE_BYTES, + allow_gds=WEIGHTS_GDS_ENABLED, + pin_if_cpu=not args.disable_pinned_memory, + ) + PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) def discard_cuda_async_error(): @@ -1291,7 +1340,10 @@ def get_free_memory(dev=None, torch_free_too=False): if dev is None: dev = get_torch_device() - if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): + if hasattr(dev, 'type') and dev.type == "meta": + mem_free_total = sys.maxsize + mem_free_torch = mem_free_total + elif hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total else: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f6b80a40f..26fd30fa8 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -34,6 +34,7 @@ import comfy.lora import comfy.model_management import comfy.patcher_extension import comfy.utils +import comfy.disk_weights from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP @@ -269,6 +270,8 @@ class ModelPatcher: if not hasattr(self.model, 'model_offload_buffer_memory'): self.model.model_offload_buffer_memory = 0 + comfy.disk_weights.attach_disk_weight_hooks(self.model) + def model_size(self): if self.size > 0: return self.size @@ -783,7 +786,7 @@ class ModelPatcher: m.comfy_patched_weights = True for x in load_completely: - x[2].to(device_to) + comfy.disk_weights.module_to(x[2], device_to) for x in offloaded: n = x[1] @@ -799,7 +802,7 @@ class ModelPatcher: logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load)) self.model.model_lowvram = False if full_load: - self.model.to(device_to) 
+ comfy.disk_weights.module_to(self.model, device_to) mem_counter = self.model_size() self.model.lowvram_patch_counter += patch_counter @@ -856,7 +859,7 @@ class ModelPatcher: self.backup.clear() if device_to is not None: - self.model.to(device_to) + comfy.disk_weights.module_to(self.model, device_to, allow_materialize=False) self.model.device = device_to self.model.model_loaded_weight_memory = 0 self.model.model_offload_buffer_memory = 0 @@ -883,6 +886,9 @@ class ModelPatcher: if len(unload_list) > 0: NS = comfy.model_management.NUM_STREAMS offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS + remaining_ram = None + if device_to is not None and comfy.model_management.is_device_cpu(device_to): + remaining_ram = comfy.model_management.get_free_memory(device_to) for unload in unload_list: if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed: @@ -916,7 +922,24 @@ class ModelPatcher: bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights - m.to(device_to) + freed_bytes = module_mem + if device_to is not None and device_to.type == "meta" and comfy.disk_weights.disk_weights_enabled(): + freed_bytes = comfy.disk_weights.offload_module_weights(m) + if freed_bytes == 0: + freed_bytes = module_mem + else: + if remaining_ram is not None and remaining_ram < module_mem and comfy.disk_weights.disk_weights_enabled(): + logging.info("Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.", n, module_mem / (1024 * 1024), remaining_ram / (1024 * 1024)) + freed_bytes = comfy.disk_weights.offload_module_weights(m) + if freed_bytes == 0: + freed_bytes = module_mem + else: + if comfy.disk_weights.disk_weights_enabled(): + comfy.disk_weights.move_module_tensors(m, device_to) + else: + m.to(device_to) + if remaining_ram is not None: + remaining_ram = max(0, remaining_ram - module_mem) module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: @@ -939,7 +962,7 @@ class ModelPatcher: m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True m.comfy_patched_weights = False - memory_freed += module_mem + memory_freed += freed_bytes offload_buffer = max(offload_buffer, potential_offload) offload_weight_factor.append(module_mem) offload_weight_factor.pop(0) @@ -953,7 +976,8 @@ class ModelPatcher: self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed self.model.model_offload_buffer_memory = offload_buffer - logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter)) + target_label = "disk" if device_to is not None and device_to.type == "meta" else device_to + logging.info("Unloaded partially to {}: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(target_label, memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter)) return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): @@ -984,11 +1008,12 @@ class ModelPatcher: return self.model.model_loaded_weight_memory - current_used - def detach(self, unpatch_all=True): + def detach(self, unpatch_all=True, offload_device=None): self.eject_model() 
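Aside (not part of the patch): an illustrative caller of the new `detach(offload_device=...)` parameter, mirroring how `LoadedModel.model_unload` in this diff picks the offload target. When the disk tier is enabled, detaching to the meta device leaves parameters as meta tensors whose data stays on disk instead of occupying CPU RAM. The function name `detach_with_disk_tier_sketch` is hypothetical.

```python
import torch
import comfy.disk_weights

def detach_with_disk_tier_sketch(patcher):
    # Passing offload_device=None keeps the legacy behaviour (offload to the
    # patcher's configured offload_device); the meta device routes to disk.
    offload_device = None
    if comfy.disk_weights.disk_weights_enabled():
        offload_device = torch.device("meta")
    return patcher.detach(unpatch_all=True, offload_device=offload_device)
```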
self.model_patches_to(self.offload_device) + target_device = self.offload_device if offload_device is None else offload_device if unpatch_all: - self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) + self.unpatch_model(target_device, unpatch_weights=unpatch_all) for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH): callback(self, unpatch_all) return self.model @@ -1358,4 +1383,3 @@ class ModelPatcher: def __del__(self): self.unpin_all_weights() self.detach(unpatch_all=False) - diff --git a/comfy/ops.py b/comfy/ops.py index 8156c42ff..d7a0c9ab2 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -19,6 +19,7 @@ import torch import logging import comfy.model_management +import comfy.disk_weights from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm @@ -98,11 +99,35 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of weight_has_function = len(s.weight_function) > 0 bias_has_function = len(s.bias_function) > 0 - weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) + weight_source = s.weight + bias_source = s.bias + if comfy.disk_weights.disk_weights_enabled(): + if weight_source.device.type == "meta": + loaded = comfy.disk_weights.load_module_tensor( + s, + "weight", + device, + temporary=True, + dtype_override=dtype, + ) + if loaded is not None: + weight_source = loaded + if bias_source is not None and bias_source.device.type == "meta": + loaded_bias = comfy.disk_weights.load_module_tensor( + s, + "bias", + device, + temporary=True, + dtype_override=bias_dtype, + ) + if loaded_bias is not None: + bias_source = loaded_bias + + weight = comfy.model_management.cast_to(weight_source, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) bias = None - if s.bias is not None: - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) + if bias_source is not None: + bias = comfy.model_management.cast_to(bias_source, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) comfy.model_management.sync_stream(device, offload_stream) @@ -532,9 +557,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec key = f"{prefix}{param_name}" value = state_dict.pop(key, None) if value is not None: - value = value.to(device=device) - if dtype is not None: - value = value.view(dtype=dtype) + if value.device.type != "meta": + value = value.to(device=device) + if dtype is not None: + value = value.view(dtype=dtype) manually_loaded_keys.append(key) return value @@ -551,11 +577,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec manually_loaded_keys = [weight_key] layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) - if layer_conf is not None: + if layer_conf is not None and layer_conf.device.type != "meta": layer_conf = json.loads(layer_conf.numpy().tobytes()) + elif layer_conf is not None: + layer_conf = None if layer_conf is None: - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) + if weight.device.type == "meta": + self.weight = torch.nn.Parameter(weight, requires_grad=False) + else: + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: self.quant_format = 
layer_conf.get("format", None) self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) @@ -601,10 +632,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec else: raise ValueError(f"Unsupported quantization format: {self.quant_format}") - self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params), - requires_grad=False - ) + if weight.device.type == "meta": + self.weight = torch.nn.Parameter(weight, requires_grad=False) + else: + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params), + requires_grad=False + ) for param_name in qconfig["parameters"]: if param_name in {"weight_scale", "weight_scale_2"}: @@ -614,7 +648,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec _v = state_dict.pop(param_key, None) if _v is None: continue - self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + if _v.device.type == "meta": + self.register_parameter(param_name, torch.nn.Parameter(_v, requires_grad=False)) + else: + self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) manually_loaded_keys.append(param_key) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) diff --git a/comfy/safetensors_stream.py b/comfy/safetensors_stream.py new file mode 100644 index 000000000..61943e53f --- /dev/null +++ b/comfy/safetensors_stream.py @@ -0,0 +1,934 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +from __future__ import annotations + +import collections +import ctypes +import importlib +import importlib.util +import os +import threading +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Sequence, Tuple + +import torch + + +_FST_MODULE = None +_FST_LOCK = threading.Lock() +_FST_LOADED = False +_GDS_INITIALIZED = False +_MISSING = object() +_NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024 + + +def _require_fastsafetensors(): + global _FST_MODULE + with _FST_LOCK: + if _FST_MODULE is None: + if importlib.util.find_spec("fastsafetensors") is None: + raise ImportError( + "fastsafetensors is required for safetensors streaming. 
" + "Install it with: pip install 'fastsafetensors @ https://github.com/" + "foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip'" + ) + _FST_MODULE = importlib.import_module("fastsafetensors") + return _FST_MODULE + + +def _init_fastsafetensors_lib(): + global _FST_LOADED + fst = _require_fastsafetensors() + if not _FST_LOADED: + fst.cpp.load_library_functions() + _FST_LOADED = True + return fst + + +def _init_gds(): + global _GDS_INITIALIZED + fst = _init_fastsafetensors_lib() + if not _GDS_INITIALIZED: + if fst.cpp.init_gds() != 0: + raise RuntimeError("fastsafetensors init_gds() failed") + _GDS_INITIALIZED = True + + +@dataclass(frozen=True) +class TensorMeta: + dtype: torch.dtype + shape: Tuple[int, ...] + numel: int + nbytes: int + data_offsets: Tuple[int, int] + filename: str + fst_dtype: object + strides: Tuple[int, ...] + + +class SafeTensorIndex: + def __init__(self, filename: str): + fst = _init_fastsafetensors_lib() + framework = fst.frameworks.get_framework_op("pytorch") + metadata = fst.common.SafeTensorsMetadata.from_file(filename, framework) + self._filename = filename + self._metadata = metadata + self._framework = framework + from fastsafetensors.frameworks import _torch as fst_torch + self._dtype_map = fst_torch.dtype_convert + self._tensor_meta: Dict[str, TensorMeta] = {} + for key, frame in metadata.tensors.items(): + torch_dtype = self._dtype_map.get(frame.dtype, None) + if torch_dtype is None: + raise ValueError(f"Unsupported safetensors dtype {frame.dtype} in {filename}") + numel = 1 + for s in frame.shape: + numel *= s + nbytes = numel * framework.get_dtype_size(frame.dtype) + self._tensor_meta[key] = TensorMeta( + dtype=torch_dtype, + shape=tuple(frame.shape), + numel=numel, + nbytes=nbytes, + data_offsets=(frame.data_offsets[0], frame.data_offsets[1]), + filename=filename, + fst_dtype=frame.dtype, + strides=tuple(frame.strides), + ) + + def keys(self) -> Iterable[str]: + return self._tensor_meta.keys() + + def has(self, key: str) -> bool: + return key in self._tensor_meta + + def meta(self, key: str) -> TensorMeta: + return self._tensor_meta[key] + + def metadata(self): + return self._metadata.metadata + + @property + def header_length(self) -> int: + return self._metadata.header_length + + @property + def size_bytes(self) -> int: + return self._metadata.size_bytes + + +class _SafeTensorFile: + def __init__(self, filename: str, index: SafeTensorIndex): + self.filename = filename + self.index = index + self._fd: Optional[int] = None + self._gds_handle = None + self._gds_reader = None + self._nogds_reader = None + self._refcount = 1 + + def acquire(self) -> "_SafeTensorFile": + self._refcount += 1 + return self + + def release(self): + self._refcount -= 1 + if self._refcount <= 0: + self.close() + + def close(self): + if self._fd is not None: + os.close(self._fd) + self._fd = None + self._gds_handle = None + + def _ensure_fd(self) -> int: + if self._fd is None: + self._fd = os.open(self.filename, os.O_RDONLY, 0o644) + return self._fd + + def _ensure_nogds_reader(self, use_cuda: bool): + fst = _init_fastsafetensors_lib() + if self._nogds_reader is None: + self._nogds_reader = fst.cpp.nogds_file_reader( + False, 16 * 1024, 16, use_cuda + ) + return self._nogds_reader + + def _ensure_gds_reader(self, use_cuda: bool): + fst = _init_fastsafetensors_lib() + if self._gds_reader is None: + self._gds_reader = fst.cpp.gds_file_reader(16, use_cuda) + return self._gds_reader + + def _ensure_gds_handle(self, use_cuda: bool): + if self._gds_handle is None: + fst = 
_init_fastsafetensors_lib() + framework = fst.frameworks.get_framework_op("pytorch") + o_direct = _get_gds_o_direct(framework) + self._gds_handle = fst.cpp.gds_file_handle(self.filename, o_direct, use_cuda) + return self._gds_handle + + def read_tensor( + self, + meta: TensorMeta, + device: torch.device, + dtype: Optional[torch.dtype], + allow_gds: bool, + pin_if_cpu: bool, + ) -> torch.Tensor: + fst = _init_fastsafetensors_lib() + framework = fst.frameworks.get_framework_op("pytorch") + device_is_cuda = device.type == "cuda" + if device_is_cuda and allow_gds: + _ensure_gds_ready(device) + tensor = self._read_tensor_gds( + fst, framework, meta, device, dtype + ) + return tensor + + cpu_tensor = self._read_tensor_nogds( + fst, framework, meta, torch.device("cpu"), dtype + ) + if device_is_cuda: + if pin_if_cpu: + cpu_tensor = cpu_tensor.pin_memory() + gpu_tensor = torch.empty_like(cpu_tensor, device=device) + gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu) + return gpu_tensor + return cpu_tensor + + def _aligned_range(self, abs_start: int, length: int) -> Tuple[int, int, int]: + fst = _init_fastsafetensors_lib() + align = fst.cpp.get_alignment_size() + aligned_offset = (abs_start // align) * align + head = abs_start - aligned_offset + aligned_length = length + head + tail = aligned_length % align + if tail: + aligned_length += align - tail + return aligned_offset, aligned_length, head + + def _read_tensor_nogds( + self, + fst, + framework, + meta: TensorMeta, + device: torch.device, + dtype: Optional[torch.dtype], + ) -> torch.Tensor: + fd = self._ensure_fd() + reader = self._ensure_nogds_reader(use_cuda=False) + abs_start = self.index.header_length + meta.data_offsets[0] + length = meta.data_offsets[1] - meta.data_offsets[0] + chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT)) + chunk_bytes = max(1, chunk_bytes) + ptr_align = framework.get_device_ptr_align() + dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu") + buffer_length = 0 + buf_ptr = None + gbuf = None + try: + chunk_offset = 0 + while chunk_offset < length: + chunk_len = min(length - chunk_offset, chunk_bytes) + aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len) + needed = aligned_length + ptr_align + if buf_ptr is None or needed > buffer_length: + if buf_ptr is not None: + fst.cpp.cpu_free(buf_ptr) + buffer_length = needed + buf_ptr = fst.cpp.cpu_malloc(buffer_length) + gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False) + ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align + req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off) + if reader.wait_read(req) < 0: + raise RuntimeError("nogds_file_reader read failed") + src_ptr = gbuf.get_base_address() + ptr_off + head + dest_ptr = dest_tensor.data_ptr() + chunk_offset + ctypes.memmove(dest_ptr, src_ptr, chunk_len) + chunk_offset += chunk_len + except Exception: + if buf_ptr is not None: + fst.cpp.cpu_free(buf_ptr) + raise + if buf_ptr is not None: + fst.cpp.cpu_free(buf_ptr) + if dtype is not None and dtype != dest_tensor.dtype: + _validate_dtype_conversion(dest_tensor.dtype, dtype) + dest_tensor = dest_tensor.to(dtype=dtype) + return dest_tensor + + def _read_tensor_gds( + self, + fst, + framework, + meta: TensorMeta, + device: torch.device, + dtype: Optional[torch.dtype], + ) -> torch.Tensor: + reader = self._ensure_gds_reader(use_cuda=True) + handle = self._ensure_gds_handle(use_cuda=True) + abs_start = 
self.index.header_length + meta.data_offsets[0] + length = meta.data_offsets[1] - meta.data_offsets[0] + aligned_offset, aligned_length, head = self._aligned_range(abs_start, length) + ptr_align = framework.get_device_ptr_align() + buffer_length = aligned_length + ptr_align + fst_device = _fst_device_from_torch(fst, device) + gbuf = framework.alloc_tensor_memory(buffer_length, fst_device) + ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align + file_length = self.index.size_bytes + req = reader.submit_read( + handle, gbuf, aligned_offset, aligned_length, ptr_off, file_length + ) + if reader.wait_read(req) < 0: + framework.free_tensor_memory(gbuf, fst_device) + raise RuntimeError("gds_file_reader read failed") + owner = _BufferOwner(lambda: framework.free_tensor_memory(gbuf, fst_device)) + tensor = _dlpack_tensor_from_buffer( + fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner + ) + if dtype is not None and dtype != tensor.dtype: + _validate_dtype_conversion(tensor.dtype, dtype) + tensor = tensor.to(dtype=dtype) + return tensor + + +def _fst_device_from_torch(fst, device: torch.device): + if device.type == "cuda" and device.index is not None: + return fst.st_types.Device.from_str(f"cuda:{device.index}") + return fst.st_types.Device.from_str(device.type) + + +class _BufferOwner: + def __init__(self, free_fn): + self._free_fn = free_fn + + def __del__(self): + try: + self._free_fn() + except Exception: + pass + + +def _dlpack_tensor_from_buffer( + fst, + framework, + ptr: int, + meta: TensorMeta, + device: torch.device, + owner: Optional[_BufferOwner], +) -> torch.Tensor: + disk_dtype = framework.as_workaround_dtype(meta.fst_dtype) + dev = _fst_device_from_torch(fst, device) + dl_tensor = fst.dlpack.from_cuda_buffer(ptr, list(meta.shape), list(meta.strides), disk_dtype, dev) + torch_tensor = framework.from_dlpack(dl_tensor, dev, disk_dtype).real_tensor + if disk_dtype != meta.fst_dtype: + torch_tensor = torch_tensor.view(meta.dtype) + if owner is not None: + torch_tensor._comfy_disk_buffer_owner = owner + return torch_tensor + + +def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype): + if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size(): + raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})") + + +def _get_gds_o_direct(framework) -> bool: + cuda_ver = framework.get_cuda_ver() + if cuda_ver and cuda_ver != "0.0": + ver_parts = cuda_ver.split("-", 1) + if len(ver_parts) == 2: + cudavers = list(map(int, ver_parts[1].split("."))) + if ver_parts[0] == "cuda": + return not (cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2)) + return True + return True + + +def _ensure_gds_ready(device: torch.device): + fst = _init_fastsafetensors_lib() + if not fst.common.is_gpu_found(): + raise RuntimeError( + "GPUDirect requested but GPU runtime library is missing. " + "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU." + ) + gds_supported = fst.cpp.is_gds_supported(device.index if device.index is not None else 0) + if gds_supported < 0: + raise RuntimeError( + "GPUDirect requested but is_gds_supported() failed. " + "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU." + ) + if not fst.cpp.is_cufile_found(): + raise RuntimeError( + "GPUDirect requested but libcufile is missing. " + "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU." 
+ ) + if gds_supported == 0: + raise RuntimeError( + "GPUDirect requested but GDS is unsupported on this platform. " + "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU." + ) + _init_gds() + + +class StreamStateDict(collections.abc.MutableMapping): + is_stream_state_dict = True + + def __init__( + self, + index: SafeTensorIndex, + file: _SafeTensorFile, + device: torch.device, + allow_gds: bool = False, + ): + self._index = index + self._file = file + self._device = device + self._allow_gds = allow_gds + self._overrides: Dict[str, torch.Tensor] = {} + self._deleted: set[str] = set() + + @classmethod + def from_file(cls, filename: str, device: torch.device, allow_gds: bool = False) -> "StreamStateDict": + index = SafeTensorIndex(filename) + file = _SafeTensorFile(filename, index) + return cls(index, file, device, allow_gds=allow_gds) + + def close(self): + if self._file is not None: + self._file.release() + self._file = None + + def __del__(self): + try: + self.close() + except Exception: + pass + + def meta(self, key: str) -> TensorMeta: + if key in self._overrides: + t = self._overrides[key] + numel = t.numel() + return TensorMeta( + dtype=t.dtype, + shape=tuple(t.shape), + numel=numel, + nbytes=numel * t.element_size(), + data_offsets=(0, numel * t.element_size()), + filename="", + fst_dtype=None, + strides=tuple(t.stride()), + ) + if key in self._deleted: + raise KeyError(key) + return self._index.meta(key) + + def get_tensor( + self, + key: str, + *, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + allow_gds: Optional[bool] = None, + pin_if_cpu: bool = False, + ) -> torch.Tensor: + if key in self._overrides: + t = self._overrides[key] + if device is not None and t.device != device: + t = t.to(device=device) + if dtype is not None and t.dtype != dtype: + _validate_dtype_conversion(t.dtype, dtype) + t = t.to(dtype=dtype) + return t + if key in self._deleted: + raise KeyError(key) + if device is None: + device = self._device + if device.type == "meta": + meta = self._index.meta(key) + target_dtype = dtype or meta.dtype + if dtype is not None and dtype != meta.dtype: + _validate_dtype_conversion(meta.dtype, dtype) + return torch.empty(meta.shape, dtype=target_dtype, device="meta") + if allow_gds is None: + allow_gds = self._allow_gds + meta = self._index.meta(key) + return self._file.read_tensor(meta, device, dtype, allow_gds, pin_if_cpu) + + def __getitem__(self, key: str) -> torch.Tensor: + return self.get_tensor(key) + + def __setitem__(self, key: str, value: torch.Tensor) -> None: + self._overrides[key] = value + self._deleted.discard(key) + + def __delitem__(self, key: str) -> None: + if key in self._overrides: + del self._overrides[key] + return + if key in self._deleted: + raise KeyError(key) + if self._index.has(key): + self._deleted.add(key) + return + raise KeyError(key) + + def __iter__(self) -> Iterator[str]: + for k in self._index.keys(): + if k in self._deleted: + continue + if k in self._overrides: + continue + yield k + for k in self._overrides.keys(): + yield k + + def __len__(self) -> int: + base = len(self._index.keys()) + return base - len(self._deleted) + len(self._overrides) + + def __contains__(self, key: object) -> bool: + if not isinstance(key, str): + return False + if key in self._deleted: + return False + if key in self._overrides: + return True + return self._index.has(key) + + def pop(self, key: str, default: object = _MISSING) -> torch.Tensor: + if key in self._overrides: + return self._overrides.pop(key) + if 
key in self._deleted: + if default is _MISSING: + raise KeyError(key) + return default + if self._index.has(key): + self._deleted.add(key) + return self.get_tensor(key) + if default is _MISSING: + raise KeyError(key) + return default + + def copy(self) -> "StreamStateDict": + new = StreamStateDict(self._index, self._file.acquire(), self._device, allow_gds=self._allow_gds) + new._overrides = dict(self._overrides) + new._deleted = set(self._deleted) + return new + + def metadata(self): + return self._index.metadata() + + +class _BaseViewStateDict(MutableMapping): + is_stream_state_dict = True + + def __init__(self, base: MutableMapping, mutate_base: bool = False): + self._base = base + self._mutate_base = mutate_base + self._overrides: Dict[str, torch.Tensor] = {} + self._deleted: set[str] = set() + + def _resolve_base_key(self, key: str) -> Optional[str]: + return key + + def _iter_base_keys(self) -> Iterable[str]: + return self._base.keys() + + def get_tensor( + self, + key: str, + *, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + allow_gds: Optional[bool] = None, + pin_if_cpu: bool = False, + ) -> torch.Tensor: + if key in self._overrides: + t = self._overrides[key] + if device is not None and t.device != device: + t = t.to(device=device) + if dtype is not None and t.dtype != dtype: + _validate_dtype_conversion(t.dtype, dtype) + t = t.to(dtype=dtype) + return t + base_key = self._resolve_base_key(key) + if base_key is None or key in self._deleted: + raise KeyError(key) + if hasattr(self._base, "get_tensor"): + return self._base.get_tensor( + base_key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu + ) + t = self._base[base_key] + if device is not None and t.device != device: + t = t.to(device=device) + if dtype is not None and t.dtype != dtype: + _validate_dtype_conversion(t.dtype, dtype) + t = t.to(dtype=dtype) + return t + + def __getitem__(self, key: str) -> torch.Tensor: + return self.get_tensor(key) + + def __setitem__(self, key: str, value: torch.Tensor) -> None: + base_key = self._resolve_base_key(key) + if self._mutate_base and base_key is not None and base_key in self._base: + self._base[base_key] = value + else: + self._overrides[key] = value + self._deleted.discard(key) + + def __delitem__(self, key: str) -> None: + if key in self._overrides: + del self._overrides[key] + return + base_key = self._resolve_base_key(key) + if base_key is None or key in self._deleted: + raise KeyError(key) + if self._mutate_base and base_key in self._base: + del self._base[base_key] + else: + self._deleted.add(key) + + def __iter__(self) -> Iterator[str]: + for k in self._iter_base_keys(): + if k in self._deleted: + continue + yield k + for k in self._overrides.keys(): + yield k + + def __len__(self) -> int: + base_keys = list(self._iter_base_keys()) + return len(base_keys) - len(self._deleted) + len(self._overrides) + + def pop(self, key: str, default: object = _MISSING) -> torch.Tensor: + if key in self._overrides: + return self._overrides.pop(key) + base_key = self._resolve_base_key(key) + if base_key is None or key in self._deleted: + if default is _MISSING: + raise KeyError(key) + return default + if self._mutate_base: + try: + return self._base.pop(base_key) + except KeyError: + if default is _MISSING: + raise + return default + self._deleted.add(key) + return self.get_tensor(key) + + def meta(self, key: str): + if key in self._overrides: + t = self._overrides[key] + numel = t.numel() + return SimpleNamespace( + dtype=t.dtype, + 
shape=tuple(t.shape), + numel=numel, + nbytes=numel * t.element_size(), + ) + base_key = self._resolve_base_key(key) + if base_key is None or key in self._deleted: + raise KeyError(key) + if hasattr(self._base, "meta"): + return self._base.meta(base_key) + t = self._base[base_key] + numel = t.numel() + return SimpleNamespace( + dtype=t.dtype, + shape=tuple(t.shape), + numel=numel, + nbytes=numel * t.element_size(), + ) + + +class DeviceViewStateDict(_BaseViewStateDict): + def __init__( + self, + base: MutableMapping, + device: torch.device, + allow_gds: Optional[bool] = None, + pin_if_cpu: bool = False, + mutate_base: bool = False, + ): + super().__init__(base, mutate_base=mutate_base) + self._device = device + self._allow_gds = allow_gds + self._pin_if_cpu = pin_if_cpu + + def get_tensor( + self, + key: str, + *, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + allow_gds: Optional[bool] = None, + pin_if_cpu: bool = False, + ) -> torch.Tensor: + device = self._device if device is None else device + allow_gds = self._allow_gds if allow_gds is None else allow_gds + pin_if_cpu = self._pin_if_cpu if not pin_if_cpu else pin_if_cpu + return super().get_tensor( + key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu + ) + + def meta(self, key: str): + if key in self._overrides: + t = self._overrides[key] + numel = t.numel() + return SimpleNamespace( + dtype=t.dtype, + shape=tuple(t.shape), + numel=numel, + nbytes=numel * t.element_size(), + ) + base_key = self._resolve_base_key(key) + if base_key is None or key in self._deleted: + raise KeyError(key) + if hasattr(self._base, "meta"): + return self._base.meta(base_key) + t = self._base[base_key] + numel = t.numel() + return SimpleNamespace( + dtype=t.dtype, + shape=tuple(t.shape), + numel=numel, + nbytes=numel * t.element_size(), + ) + + def __getitem__(self, key: str) -> torch.Tensor: + return self.get_tensor(key) + + def __setitem__(self, key: str, value: torch.Tensor) -> None: + base_key = self._resolve_base_key(key) + if self._mutate_base and base_key is not None and base_key in self._base: + self._base[base_key] = value + else: + self._overrides[key] = value + self._deleted.discard(key) + + def __delitem__(self, key: str) -> None: + if key in self._overrides: + del self._overrides[key] + return + base_key = self._resolve_base_key(key) + if base_key is None or key in self._deleted: + raise KeyError(key) + if self._mutate_base and base_key in self._base: + del self._base[base_key] + else: + self._deleted.add(key) + + def __iter__(self) -> Iterator[str]: + for k in self._iter_base_keys(): + if k in self._deleted: + continue + yield k + for k in self._overrides.keys(): + yield k + + def __len__(self) -> int: + base_keys = list(self._iter_base_keys()) + return len(base_keys) - len(self._deleted) + len(self._overrides) + + def pop(self, key: str, default: object = _MISSING) -> torch.Tensor: + if key in self._overrides: + return self._overrides.pop(key) + base_key = self._resolve_base_key(key) + if base_key is None or key in self._deleted: + if default is _MISSING: + raise KeyError(key) + return default + if self._mutate_base: + try: + return self._base.pop(base_key) + except KeyError: + if default is _MISSING: + raise + return default + self._deleted.add(key) + return self.get_tensor(key) + + +class FilterViewStateDict(_BaseViewStateDict): + def __init__(self, base: MutableMapping, predicate, mutate_base: bool = False): + super().__init__(base, mutate_base=mutate_base) + self._predicate = 
predicate + + def _resolve_base_key(self, key: str) -> Optional[str]: + if self._predicate(key): + return key + return None + + def _iter_base_keys(self) -> Iterable[str]: + for k in self._base.keys(): + if self._predicate(k): + yield k + + +class PrefixViewStateDict(_BaseViewStateDict): + def __init__(self, base: MutableMapping, source_prefix: str, target_prefix: str = "", mutate_base: bool = False): + super().__init__(base, mutate_base=mutate_base) + self._source_prefix = source_prefix + self._target_prefix = target_prefix + self._mapping: Dict[str, str] = {} + self._reverse: Dict[str, str] = {} + for k in base.keys(): + if not k.startswith(source_prefix): + continue + view_key = f"{target_prefix}{k[len(source_prefix):]}" + self._mapping[k] = view_key + self._reverse[view_key] = k + + def _resolve_base_key(self, key: str) -> Optional[str]: + return self._reverse.get(key) + + def _iter_base_keys(self) -> Iterable[str]: + return self._reverse.keys() + + +class RenameViewStateDict(_BaseViewStateDict): + def __init__( + self, + base: MutableMapping, + replace_prefix: Mapping[str, str], + filter_keys: bool = False, + mutate_base: bool = False, + ): + super().__init__(base, mutate_base=mutate_base) + self._filter_keys = filter_keys + self._replace = list(replace_prefix.items()) + self._mapping: Dict[str, str] = {} + self._reverse: Dict[str, str] = {} + for k in base.keys(): + view_key = self._replace_key(k) + if view_key is None: + continue + self._mapping[k] = view_key + self._reverse[view_key] = k + + def _replace_key(self, key: str) -> Optional[str]: + for rp, dst in self._replace: + if key.startswith(rp): + return f"{dst}{key[len(rp):]}" + if self._filter_keys: + return None + return key + + def _resolve_base_key(self, key: str) -> Optional[str]: + return self._reverse.get(key) + + def _iter_base_keys(self) -> Iterable[str]: + return self._reverse.keys() + + +class MergedStateDict(MutableMapping): + is_stream_state_dict = True + + def __init__(self, *mappings: MutableMapping): + self._mappings = list(mappings) + self._overrides: Dict[str, torch.Tensor] = {} + self._deleted: set[str] = set() + + def __getitem__(self, key: str) -> torch.Tensor: + if key in self._overrides: + return self._overrides[key] + if key in self._deleted: + raise KeyError(key) + for mapping in reversed(self._mappings): + if key in mapping: + if hasattr(mapping, "get_tensor"): + return mapping.get_tensor(key) + return mapping[key] + raise KeyError(key) + + def __setitem__(self, key: str, value: torch.Tensor) -> None: + self._overrides[key] = value + self._deleted.discard(key) + + def __delitem__(self, key: str) -> None: + if key in self._overrides: + del self._overrides[key] + return + if key in self._deleted: + raise KeyError(key) + if any(key in mapping for mapping in self._mappings): + self._deleted.add(key) + return + raise KeyError(key) + + def __iter__(self) -> Iterator[str]: + seen = set() + for mapping in self._mappings: + for key in mapping.keys(): + if key in self._deleted or key in seen: + continue + seen.add(key) + yield key + for key in self._overrides.keys(): + if key not in seen: + yield key + + def __len__(self) -> int: + return len(list(self.__iter__())) + + def meta(self, key: str): + if key in self._overrides: + t = self._overrides[key] + numel = t.numel() + return SimpleNamespace( + dtype=t.dtype, + shape=tuple(t.shape), + numel=numel, + nbytes=numel * t.element_size(), + ) + if key in self._deleted: + raise KeyError(key) + for mapping in reversed(self._mappings): + if key in mapping: + if 
hasattr(mapping, "meta"): + return mapping.meta(key) + t = mapping[key] + numel = t.numel() + return SimpleNamespace( + dtype=t.dtype, + shape=tuple(t.shape), + numel=numel, + nbytes=numel * t.element_size(), + ) + raise KeyError(key) + + +class MappedStateDict(_BaseViewStateDict): + def __init__(self, base: MutableMapping, key_map: Mapping[str, str], mutate_base: bool = False): + super().__init__(base, mutate_base=mutate_base) + self._base_to_view = dict(key_map) + self._view_to_base = {v: k for k, v in key_map.items()} + + def _resolve_base_key(self, key: str) -> Optional[str]: + return self._view_to_base.get(key) + + def _iter_base_keys(self) -> Iterable[str]: + return self._view_to_base.keys() diff --git a/comfy/sd.py b/comfy/sd.py index 5a7221620..4f4b7298b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -25,6 +25,8 @@ import math import os import comfy.utils +import comfy.safetensors_stream +import comfy.disk_weights from . import clip_vision from . import gligen @@ -124,7 +126,7 @@ class CLIP: if not model_management.supports_cast(load_device, dt): load_device = offload_device if params['device'] != offload_device: - self.cond_stage_model.to(offload_device) + comfy.disk_weights.module_to(self.cond_stage_model, offload_device) logging.warning("Had to shift TE back.") self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) @@ -288,7 +290,7 @@ class CLIP: def load_sd(self, sd, full_model=False): if full_model: - return self.cond_stage_model.load_state_dict(sd, strict=False) + return comfy.utils.load_state_dict(self.cond_stage_model, sd, strict=False) else: return self.cond_stage_model.load_sd(sd) @@ -349,7 +351,7 @@ class VAE: encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config}, decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) elif "taesd_decoder.1.weight" in sd: - self.latent_channels = sd["taesd_decoder.1.weight"].shape[1] + self.latent_channels = sd_shape(sd, "taesd_decoder.1.weight")[1] self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels) elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade self.first_stage_model = StageA() @@ -364,25 +366,19 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 - new_sd = {} - for k in sd: - new_sd["encoder.{}".format(k)] = sd[k] - sd = new_sd + sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."}) elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade self.first_stage_model = StageC_coder() self.latent_channels = 16 - new_sd = {} - for k in sd: - new_sd["previewer.{}".format(k)] = sd[k] - sd = new_sd + sd = comfy.utils.state_dict_prefix_replace(sd, {"": "previewer."}) elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 elif "decoder.conv_in.weight" in sd: - if sd['decoder.conv_in.weight'].shape[1] == 64: + if sd_shape(sd, 'decoder.conv_in.weight')[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.latent_channels = 
ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1] self.downscale_ratio = 32 self.upscale_ratio = 32 self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] @@ -392,9 +388,9 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) - elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5: + elif sd_shape(sd, 'decoder.conv_in.weight')[1] == 32 and len(sd_shape(sd, 'decoder.conv_in.weight')) == 5: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False} - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1] self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) self.upscale_index_formula = (4, 16, 16) @@ -417,7 +413,7 @@ class VAE: self.downscale_ratio = 4 self.upscale_ratio = 4 - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1] if 'decoder.post_quant_conv.weight' in sd: sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."}) @@ -430,7 +426,7 @@ class VAE: self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0 if 'post_quant_conv.weight' in sd: - self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1]) else: self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, @@ -465,11 +461,11 @@ class VAE: self.downscale_index_formula = (6, 8, 8) self.working_dtypes = [torch.float16, torch.float32] elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv - tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"] + tensor_conv1_shape = sd_shape(sd, "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight") version = 0 - if tensor_conv1.shape[0] == 512: + if tensor_conv1_shape[0] == 512: version = 0 - elif tensor_conv1.shape[0] == 1024: + elif tensor_conv1_shape[0] == 1024: version = 1 if "encoder.down_blocks.1.conv.conv.bias" in sd: version = 2 @@ -486,9 +482,9 @@ class VAE: self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.downscale_index_formula = (8, 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] - elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: + elif "decoder.conv_in.conv.weight" in sd and sd_shape(sd, 'decoder.conv_in.conv.weight')[1] == 32: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, 
"upsample_match_channel": True} - ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] + ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1] self.latent_channels = 32 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) self.upscale_index_formula = (4, 16, 16) @@ -512,8 +508,8 @@ class VAE: self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_index_formula = (4, 8, 8) self.latent_dim = 3 - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] - self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) + self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1] + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1]) #This is likely to significantly over-estimate with single image or low frame counts as the #implementation is able to completely skip caching. Rework if used as an image only VAE self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype) @@ -546,14 +542,14 @@ class VAE: self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype) else: # Wan 2.1 VAE - dim = sd["decoder.head.0.gamma"].shape[0] + dim = sd_shape(sd, "decoder.head.0.gamma")[0] self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) self.upscale_index_formula = (4, 8, 8) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_index_formula = (4, 8, 8) self.latent_dim = 3 self.latent_channels = 16 - self.output_channels = sd["encoder.conv1.weight"].shape[1] + self.output_channels = sd_shape(sd, "encoder.conv1.weight")[1] self.pad_channel_value = 1.0 ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) @@ -629,7 +625,7 @@ class VAE: self.working_dtypes = [torch.float32] self.crop_input = False elif "decoder.22.bias" in sd: # taehv, taew and lighttae - self.latent_channels = sd["decoder.1.weight"].shape[1] + self.latent_channels = sd_shape(sd, "decoder.1.weight")[1] self.latent_dim = 3 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) self.upscale_index_formula = (4, 16, 16) @@ -640,12 +636,12 @@ class VAE: self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) self.process_output = lambda image: image self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)) - elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15 + elif self.latent_channels == 32 and sd_shape(sd, "decoder.22.bias")[0] == 12: # lighttae_hv15 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15) self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) self.memory_used_decode = lambda shape, dtype: 
(1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) else: - if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical + if comfy.utils.state_dict_meta(sd, "decoder.1.weight").dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical latent_format=comfy.latent_formats.HunyuanVideo else: latent_format=None # lighttaew2_1 doesn't need scaling @@ -665,7 +661,7 @@ class VAE: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() - m, u = self.first_stage_model.load_state_dict(sd, strict=False) + m, u = comfy.utils.load_state_dict(self.first_stage_model, sd, strict=False) if len(m) > 0: logging.warning("Missing VAE keys {}".format(m)) @@ -679,7 +675,7 @@ class VAE: if dtype is None: dtype = model_management.vae_dtype(self.device, self.working_dtypes) self.vae_dtype = dtype - self.first_stage_model.to(self.vae_dtype) + comfy.disk_weights.module_to(self.first_stage_model, dtype=self.vae_dtype) self.output_device = model_management.intermediate_device() self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) @@ -986,9 +982,12 @@ def load_style_model(ckpt_path): model = comfy.ldm.flux.redux.ReduxImageEncoder() else: raise Exception("invalid style model {}".format(ckpt_path)) - model.load_state_dict(model_data) + comfy.utils.load_state_dict(model, model_data, strict=True) return StyleModel(model) +def sd_shape(state_dict, key): + return comfy.utils.state_dict_meta(state_dict, key).shape + class CLIPType(Enum): STABLE_DIFFUSION = 1 STABLE_CASCADE = 2 @@ -1058,16 +1057,16 @@ def detect_te_model(sd): if "model.encoder.layers.0.mixer.Wqkv.weight" in sd: return TEModel.JINA_CLIP_2 if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: - weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] - if weight.shape[-1] == 4096: + weight_shape = sd_shape(sd, "encoder.block.23.layer.1.DenseReluDense.wi_1.weight") + if weight_shape[-1] == 4096: return TEModel.T5_XXL - elif weight.shape[-1] == 2048: + elif weight_shape[-1] == 2048: return TEModel.T5_XL if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd: return TEModel.T5_XXL_OLD if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd: - weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight'] - if weight.shape[0] == 384: + weight_shape = sd_shape(sd, 'encoder.block.0.layer.0.SelfAttention.k.weight') + if weight_shape[0] == 384: return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: @@ -1077,19 +1076,19 @@ def detect_te_model(sd): return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B if 'model.layers.0.self_attn.k_proj.bias' in sd: - weight = sd['model.layers.0.self_attn.k_proj.bias'] - if weight.shape[0] == 256: + weight_shape = sd_shape(sd, 'model.layers.0.self_attn.k_proj.bias') + if weight_shape[0] == 256: return TEModel.QWEN25_3B - if weight.shape[0] == 512: + if weight_shape[0] == 512: return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: - weight = sd['model.layers.0.post_attention_layernorm.weight'] + weight_shape = sd_shape(sd, 'model.layers.0.post_attention_layernorm.weight') if 'model.layers.0.self_attn.q_norm.weight' in sd: - if weight.shape[0] == 
2560: + if weight_shape[0] == 2560: return TEModel.QWEN3_4B - elif weight.shape[0] == 2048: + elif weight_shape[0] == 2048: return TEModel.QWEN3_2B - if weight.shape[0] == 5120: + if weight_shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: return TEModel.MISTRAL3_24B else: @@ -1418,19 +1417,29 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c scaled_fp8_list.append(k[:-len("scaled_fp8")]) if len(scaled_fp8_list) > 0: - out_sd = {} - for k in sd: - skip = False + if comfy.utils.is_stream_state_dict(sd): + def _keep_key(k, prefixes=tuple(scaled_fp8_list)): + return not any(k.startswith(pref) for pref in prefixes) + out_sd = comfy.safetensors_stream.FilterViewStateDict(sd, _keep_key, mutate_base=False) + merged = out_sd for pref in scaled_fp8_list: - skip = skip or k.startswith(pref) - if not skip: - out_sd[k] = sd[k] + quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={}) + merged = comfy.safetensors_stream.MergedStateDict(merged, quant_sd) + sd = merged + else: + out_sd = {} + for k in sd: + skip = False + for pref in scaled_fp8_list: + skip = skip or k.startswith(pref) + if not skip: + out_sd[k] = sd[k] - for pref in scaled_fp8_list: - quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={}) - for k in quant_sd: - out_sd[k] = quant_sd[k] - sd = out_sd + for pref in scaled_fp8_list: + quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={}) + for k in quant_sd: + out_sd[k] = quant_sd[k] + sd = out_sd clip_target = model_config.clip_target(state_dict=sd) if clip_target is not None: @@ -1508,12 +1517,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) - new_sd = {} - for k in diffusers_keys: - if k in sd: - new_sd[diffusers_keys[k]] = sd.pop(k) - else: - logging.warning("{} {}".format(diffusers_keys[k], k)) + if comfy.utils.is_stream_state_dict(sd): + new_sd = comfy.safetensors_stream.MappedStateDict(sd, diffusers_keys) + else: + new_sd = {} + for k in diffusers_keys: + if k in sd: + new_sd[diffusers_keys[k]] = sd.pop(k) + else: + logging.warning("{} {}".format(diffusers_keys[k], k)) offload_device = model_management.unet_offload_device() unet_weight_dtype = list(model_config.supported_inference_dtypes) @@ -1538,7 +1550,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - model = model.to(offload_device) + model = comfy.disk_weights.module_to(model, offload_device) model.load_model_weights(new_sd, "") left_over = sd.keys() if len(left_over) > 0: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index c512ca5d0..9fb8a8f5c 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return self(tokens) def load_sd(self, sd): - return self.transformer.load_state_dict(sd, strict=False) + return comfy.utils.load_state_dict(self.transformer, sd, strict=False) def parse_parentheses(string): result = [] @@ -430,8 +430,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No try: if embed_path.lower().endswith(".safetensors"): - import safetensors.torch - embed = safetensors.torch.load_file(embed_path, device="cpu") + embed = comfy.utils.load_torch_file(embed_path, safe_load=True) else: try: embed = torch.load(embed_path, weights_only=True, map_location="cpu") diff --git 
a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index ce36f1a84..7cf040588 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -56,9 +56,9 @@ class TAESD(nn.Module): self.vae_scale = torch.nn.Parameter(torch.tensor(1.0)) self.vae_shift = torch.nn.Parameter(torch.tensor(0.0)) if encoder_path is not None: - self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True)) + comfy.utils.load_state_dict(self.taesd_encoder, comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True) if decoder_path is not None: - self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True)) + comfy.utils.load_state_dict(self.taesd_decoder, comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True) @staticmethod def scale_latents(x): diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 776e25e97..6d6e319dc 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -119,7 +119,7 @@ class LTXAVTEModel(torch.nn.Module): if len(sdo) == 0: sdo = sd - return self.load_state_dict(sdo, strict=False) + return comfy.utils.load_state_dict(self, sdo, strict=False) def memory_estimation_function(self, token_weight_pairs, device=None): constant = 6.0 diff --git a/comfy/utils.py b/comfy/utils.py index ffa98c9b1..dda041625 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -26,10 +26,13 @@ import numpy as np from PIL import Image import logging import itertools +from types import SimpleNamespace from torch.nn.functional import interpolate from einops import rearrange from comfy.cli_args import args import json +from . import safetensors_stream +import comfy.disk_weights MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -61,15 +64,9 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): metadata = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: - with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: - sd = {} - for k in f.keys(): - tensor = f.get_tensor(k) - if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues - tensor = tensor.to(device=device, copy=True) - sd[k] = tensor - if return_metadata: - metadata = f.metadata() + sd = safetensors_stream.StreamStateDict.from_file(ckpt, device=device) + if return_metadata: + metadata = sd.metadata() except Exception as e: if len(e.args) > 0: message = e.args[0] @@ -110,16 +107,16 @@ def calculate_parameters(sd, prefix=""): params = 0 for k in sd.keys(): if k.startswith(prefix): - w = sd[k] - params += w.nelement() + meta = state_dict_meta(sd, k) + params += meta.numel return params def weight_dtype(sd, prefix=""): dtypes = {} for k in sd.keys(): if k.startswith(prefix): - w = sd[k] - dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel() + meta = state_dict_meta(sd, k) + dtypes[meta.dtype] = dtypes.get(meta.dtype, 0) + meta.numel if len(dtypes) == 0: return None @@ -133,6 +130,13 @@ def state_dict_key_replace(state_dict, keys_to_replace): return state_dict def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False): + if is_stream_state_dict(state_dict): + return safetensors_stream.RenameViewStateDict( + state_dict, + replace_prefix, + filter_keys=filter_keys, + mutate_base=not filter_keys, + ) if filter_keys: out = {} else: @@ -145,6 +149,79 @@ def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False): return out +def is_stream_state_dict(state_dict) -> bool: + return 
getattr(state_dict, "is_stream_state_dict", False) + + +def state_dict_meta(state_dict, key): + if hasattr(state_dict, "meta"): + return state_dict.meta(key) + w = state_dict[key] + numel = w.numel() + return SimpleNamespace( + dtype=w.dtype, + shape=tuple(w.shape), + numel=numel, + nbytes=numel * w.element_size(), + ) + + +def load_state_dict(model, state_dict, strict=False, assign=False): + if is_stream_state_dict(state_dict): + if comfy.disk_weights.disk_weights_enabled(): + return comfy.disk_weights.lazy_load_state_dict(model, state_dict, strict=strict) + comfy.disk_weights.register_module_weights(model, state_dict) + comfy.disk_weights.attach_disk_weight_hooks(model) + missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign) + return missing, unexpected + return model.load_state_dict(state_dict, strict=strict) + + +def stream_load_state_dict(model, state_dict, strict=False, assign=False): + if is_stream_state_dict(state_dict) and hasattr(state_dict, "copy"): + state_dict = state_dict.copy() + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + metadata = getattr(state_dict, "_metadata", None) + + def load(module, local_state_dict, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + if assign: + local_metadata["assign_to_params_buffers"] = assign + module._load_from_state_dict( + local_state_dict, + prefix, + local_metadata, + True, + missing_keys, + unexpected_keys, + error_msgs, + ) + for name, child in module._modules.items(): + if child is not None: + child_prefix = f"{prefix}{name}." + child_state_dict = safetensors_stream.FilterViewStateDict( + local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False + ) + load(child, child_state_dict, child_prefix) + incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys) + for hook in module._load_state_dict_post_hooks.values(): + out = hook(module, incompatible) + if out is not None: + raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.") + + load(model, state_dict) + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys))) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs))) + return missing_keys, unexpected_keys + + def transformers_convert(sd, prefix_from, prefix_to, number): keys_to_replace = { "{}positional_embedding": "{}embeddings.position_embedding.weight", @@ -825,7 +902,10 @@ def copy_to_param(obj, attr, value): for name in attrs[:-1]: obj = getattr(obj, name) prev = getattr(obj, attrs[-1]) - prev.data.copy_(value) + if prev.device.type == "meta": + setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=prev.requires_grad)) + else: + prev.data.copy_(value) def get_attr(obj, attr: str): """Retrieves a nested attribute from an object using dot notation. 
@@ -1217,46 +1297,82 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): scaled_fp8_key = "{}scaled_fp8".format(model_prefix) if scaled_fp8_key in state_dict: - scaled_fp8_weight = state_dict[scaled_fp8_key] - scaled_fp8_dtype = scaled_fp8_weight.dtype + if is_stream_state_dict(state_dict): + scaled_meta = state_dict_meta(state_dict, scaled_fp8_key) + scaled_fp8_dtype = scaled_meta.dtype + scaled_fp8_weight_nelements = scaled_meta.numel + else: + scaled_fp8_weight = state_dict[scaled_fp8_key] + scaled_fp8_dtype = scaled_fp8_weight.dtype + scaled_fp8_weight_nelements = scaled_fp8_weight.nelement() if scaled_fp8_dtype == torch.float32: scaled_fp8_dtype = torch.float8_e4m3fn - if scaled_fp8_weight.nelement() == 2: + if scaled_fp8_weight_nelements == 2: full_precision_matrix_mult = True else: full_precision_matrix_mult = False - out_sd = {} layers = {} - for k in list(state_dict.keys()): - if k == scaled_fp8_key: - continue - if not k.startswith(model_prefix): - out_sd[k] = state_dict[k] - continue - k_out = k - w = state_dict.pop(k) - layer = None - if k_out.endswith(".scale_weight"): - layer = k_out[:-len(".scale_weight")] - k_out = "{}.weight_scale".format(layer) - - if layer is not None: - layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints - if full_precision_matrix_mult: - layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult - layers[layer] = layer_conf - - if k_out.endswith(".scale_input"): - layer = k_out[:-len(".scale_input")] - k_out = "{}.input_scale".format(layer) - if w.item() == 1.0: + if is_stream_state_dict(state_dict): + key_map = {} + for k in list(state_dict.keys()): + if k == scaled_fp8_key: continue + if not k.startswith(model_prefix): + key_map[k] = k + continue + k_out = k + layer = None + if k_out.endswith(".scale_weight"): + layer = k_out[:-len(".scale_weight")] + k_out = "{}.weight_scale".format(layer) - out_sd[k_out] = w + if layer is not None: + layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints + if full_precision_matrix_mult: + layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult + layers[layer] = layer_conf - state_dict = out_sd + if k_out.endswith(".scale_input"): + layer = k_out[:-len(".scale_input")] + k_out = "{}.input_scale".format(layer) + scale_val = state_dict[k] + if scale_val.item() == 1.0: + continue + + key_map[k] = k_out + state_dict = safetensors_stream.MappedStateDict(state_dict, key_map) + else: + out_sd = {} + for k in list(state_dict.keys()): + if k == scaled_fp8_key: + continue + if not k.startswith(model_prefix): + out_sd[k] = state_dict[k] + continue + k_out = k + w = state_dict.pop(k) + layer = None + if k_out.endswith(".scale_weight"): + layer = k_out[:-len(".scale_weight")] + k_out = "{}.weight_scale".format(layer) + + if layer is not None: + layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints + if full_precision_matrix_mult: + layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult + layers[layer] = layer_conf + + if k_out.endswith(".scale_input"): + layer = k_out[:-len(".scale_input")] + k_out = "{}.input_scale".format(layer) + if w.item() == 1.0: + continue + + out_sd[k_out] = w + + state_dict = out_sd quant_metadata = {"layers": layers} else: quant_metadata = json.loads(metadata["_quantization_metadata"]) diff --git a/nodes.py b/nodes.py index 5a9d42d4a..a96f93e47 100644 --- a/nodes.py +++ b/nodes.py @@ -17,7 +17,6 @@ 
from PIL import Image, ImageOps, ImageSequence from PIL.PngImagePlugin import PngInfo import numpy as np -import safetensors.torch sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) @@ -526,7 +525,7 @@ class LoadLatent: def load(self, latent): latent_path = folder_paths.get_annotated_filepath(latent) - latent = safetensors.torch.load_file(latent_path, device="cpu") + latent = comfy.utils.load_torch_file(latent_path, safe_load=True) multiplier = 1.0 if "latent_format_version_0" not in latent: multiplier = 1.0 / 0.18215 diff --git a/requirements.txt b/requirements.txt index 7686a5f8a..11ec57e07 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,7 @@ transformers>=4.50.3 tokenizers>=0.13.3 sentencepiece safetensors>=0.4.2 +fastsafetensors @ https://github.com/foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip aiohttp>=3.11.8 yarl>=1.18.0 pyyaml diff --git a/tests-unit/utils/safetensors_stream_test.py b/tests-unit/utils/safetensors_stream_test.py new file mode 100644 index 000000000..60d36142d --- /dev/null +++ b/tests-unit/utils/safetensors_stream_test.py @@ -0,0 +1,182 @@ +import os + +import pytest +import importlib +import importlib.util + +torch = pytest.importorskip("torch") + + +def _write_safetensors(tmp_path, tensors): + import safetensors.torch + path = os.path.join(tmp_path, "test.safetensors") + safetensors.torch.save_file(tensors, path) + return path + + +def test_stream_state_dict_meta_is_lazy(tmp_path, monkeypatch): + if torch is None: + pytest.skip("torch not installed") + import comfy.utils + path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)}) + sd = comfy.utils.load_torch_file(path, safe_load=True) + calls = [] + + original = sd._file.read_tensor + + def wrapped(meta, device, dtype, allow_gds, pin_if_cpu): + calls.append(meta) + return original(meta, device, dtype, allow_gds, pin_if_cpu) + + monkeypatch.setattr(sd._file, "read_tensor", wrapped) + meta = sd.meta("a") + assert meta.shape == (2, 3) + assert meta.dtype == torch.float32 + assert meta.numel == 6 + assert calls == [] + + +def test_stream_state_dict_getitem_loads_single_tensor(tmp_path, monkeypatch): + if torch is None: + pytest.skip("torch not installed") + import comfy.utils + path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)}) + sd = comfy.utils.load_torch_file(path, safe_load=True) + calls = [] + + original = sd._file.read_tensor + + def wrapped(meta, device, dtype, allow_gds, pin_if_cpu): + calls.append(meta) + return original(meta, device, dtype, allow_gds, pin_if_cpu) + + monkeypatch.setattr(sd._file, "read_tensor", wrapped) + _ = sd["a"] + assert len(calls) == 1 + assert calls[0].shape == (2, 3) + + +def test_stream_views_do_not_materialize(tmp_path, monkeypatch): + if torch is None: + pytest.skip("torch not installed") + import comfy.utils + path = _write_safetensors(tmp_path, {"prefix.a": torch.zeros((2, 3)), "other": torch.ones((4,))}) + sd = comfy.utils.load_torch_file(path, safe_load=True) + calls = [] + + original = sd._file.read_tensor + + def wrapped(meta, device, dtype, allow_gds, pin_if_cpu): + calls.append(meta) + return original(meta, device, dtype, allow_gds, pin_if_cpu) + + monkeypatch.setattr(sd._file, "read_tensor", wrapped) + view = comfy.utils.state_dict_prefix_replace(sd, {"prefix.": ""}, filter_keys=True) + _ = list(view.keys()) + assert calls == [] + + +def 
test_stream_load_rss_small(tmp_path): + if torch is None: + pytest.skip("torch not installed") + import comfy.utils + psutil = pytest.importorskip("psutil") + process = psutil.Process() + size_elems = 4_000_000 # ~16MB float32 + tensor = torch.zeros((size_elems,), dtype=torch.float32) + path = _write_safetensors(tmp_path, {"big": tensor}) + rss_before = process.memory_info().rss + sd = comfy.utils.load_torch_file(path, safe_load=True) + rss_after = process.memory_info().rss + expected_size = tensor.numel() * tensor.element_size() + assert (rss_after - rss_before) < expected_size + _ = sd.meta("big") + + +def test_gds_path_errors_without_support(tmp_path, monkeypatch): + if torch is None: + pytest.skip("torch not installed") + import comfy.utils + path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32)}) + sd = comfy.utils.load_torch_file(path, safe_load=True) + device = torch.device("cuda") + + if importlib.util.find_spec("fastsafetensors") is None: + fst = None + else: + fst = importlib.import_module("fastsafetensors") + + gds_available = False + if fst is not None and torch.cuda.is_available(): + gds_supported = fst.cpp.is_gds_supported(torch.cuda.current_device()) + gds_available = bool(fst.cpp.is_cufile_found()) and gds_supported == 1 + + if not gds_available: + with pytest.raises(RuntimeError, match="GPUDirect requested"): + sd.get_tensor("a", device=device, allow_gds=True) + else: + def fail_nogds(*args, **kwargs): + raise AssertionError("nogds path used during GDS request") + + monkeypatch.setattr(sd._file, "_read_tensor_nogds", fail_nogds) + t = sd.get_tensor("a", device=device, allow_gds=True) + assert t.device.type == "cuda" + + +def test_stream_load_without_disk_cache_keeps_cpu_weights(tmp_path): + if torch is None: + pytest.skip("torch not installed") + import comfy.utils + import comfy.disk_weights + + prev_cache = comfy.disk_weights.CACHE.max_bytes + prev_gds = comfy.disk_weights.ALLOW_GDS + prev_pin = comfy.disk_weights.PIN_IF_CPU + prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED + comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True) + + try: + path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)}) + sd = comfy.utils.load_torch_file(path, safe_load=True) + model = torch.nn.Linear(4, 4, bias=True) + comfy.utils.load_state_dict(model, sd, strict=False) + assert model.weight.device.type == "cpu" + assert model.weight.device.type != "meta" + finally: + comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled) + + +def test_lazy_disk_weights_loads_on_demand(tmp_path, monkeypatch): + if importlib.util.find_spec("fastsafetensors") is None: + pytest.skip("fastsafetensors not installed") + import comfy.utils + import comfy.disk_weights + + prev_cache = comfy.disk_weights.CACHE.max_bytes + prev_gds = comfy.disk_weights.ALLOW_GDS + prev_pin = comfy.disk_weights.PIN_IF_CPU + prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED + comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True) + + try: + path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)}) + sd = comfy.utils.load_torch_file(path, safe_load=True) + model = torch.nn.Linear(4, 4, bias=True) + calls = [] + + original = sd._file.read_tensor + + def wrapped(meta, device, dtype, allow_gds, pin_if_cpu): + calls.append(meta) + return 
original(meta, device, dtype, allow_gds, pin_if_cpu) + + monkeypatch.setattr(sd._file, "read_tensor", wrapped) + comfy.utils.load_state_dict(model, sd, strict=True) + assert model.weight.device.type == "meta" + assert calls == [] + + comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu")) + assert model.weight.device.type == "cpu" + assert len(calls) == 2 + finally: + comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
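
For readers who want the intended call pattern without tracing the diff, below is a minimal, illustrative sketch built only from APIs the unit tests above exercise (`comfy.utils.load_torch_file`, `state_dict_meta`, `state_dict_prefix_replace`, `get_tensor`, `comfy.utils.load_state_dict`, and `comfy.disk_weights.configure` / `ensure_module_materialized`). The file path and tensor names are invented for the example, and the GDS fallback shows one way a caller might react to the hard error; none of this is part of the patch itself.

```python
# Illustrative only: mirrors the call pattern used in the unit tests above.
# The checkpoint path and tensor names are made up for this sketch.
import torch
import safetensors.torch

import comfy.utils
import comfy.disk_weights

path = "/tmp/example.safetensors"  # hypothetical file
safetensors.torch.save_file({"weight": torch.zeros((4, 4)),
                             "bias": torch.zeros((4,))}, path)

# Header-only open: returns a StreamStateDict; no tensor data is read yet.
sd = comfy.utils.load_torch_file(path, safe_load=True)

# Metadata queries (shape/dtype/numel/nbytes) stay lazy.
meta = comfy.utils.state_dict_meta(sd, "weight")
assert meta.shape == (4, 4) and meta.numel == 16

# Prefix renames produce a lazy view rather than a new dict.
view = comfy.utils.state_dict_prefix_replace(sd, {"": "linear."})

# Single-tensor range read. With allow_gds=True this raises a RuntimeError
# when GPUDirect Storage is unavailable, so a caller can fall back explicitly.
if torch.cuda.is_available():
    try:
        w = sd.get_tensor("weight", device=torch.device("cuda"), allow_gds=True)
    except RuntimeError:
        w = sd.get_tensor("weight", device=torch.device("cuda"), allow_gds=False)
else:
    w = sd.get_tensor("weight", device=torch.device("cpu"), allow_gds=False)

# Streaming load into a module. With disk weights enabled and a zero-byte RAM
# cache, parameters stay on the meta device until they are materialized.
comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
model = torch.nn.Linear(4, 4, bias=True)
comfy.utils.load_state_dict(model, sd, strict=True)
comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"))
assert model.weight.device.type == "cpu"
```

The zero-byte cache setting mirrors `test_lazy_disk_weights_loads_on_demand`: it forces every materialization to come from the disk tier, which makes the meta-device staging easy to observe before `ensure_module_materialized` pulls the weights into RAM.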