diff --git a/DESIGN.md b/DESIGN.md
new file mode 100644
index 000000000..19a1bfc5d
--- /dev/null
+++ b/DESIGN.md
@@ -0,0 +1,57 @@
+# Disk tier safetensors streaming design audit (ComfyUI)
+
+## Mandatory research audit (verified call sites)
+
+### ComfyUI load path + eager materialization sites
+- `comfy/utils.py:load_torch_file` currently uses `safetensors.safe_open` and iterates all keys to build a full `sd` dict (eager tensor materialization). It also returns metadata only after reading all tensors.【F:comfy/utils.py†L58-L93】
+- `comfy/utils.py:calculate_parameters` and `weight_dtype` iterate `sd.keys()` and then access `sd[k]` to compute `nelement()`/`dtype` (loads tensors).【F:comfy/utils.py†L109-L128】
+- `comfy/utils.py:state_dict_prefix_replace` mutates dicts via `pop`+assignment (which would materialize every tensor if used on a streaming mapping).【F:comfy/utils.py†L135-L144】
+- `comfy/model_base.py:BaseModel.load_model_weights` builds `to_load = {}` by iterating keys and popping tensors, then passes a fully materialized dict to `load_state_dict` (RAM spike).【F:comfy/model_base.py†L301-L318】
+- `comfy/model_detection.py` reads `state_dict[key].shape` in many branches for detection (must be metadata-only). Example: `calculate_transformer_depth` and numerous `detect_unet_config` branches read shapes directly from `state_dict` values.【F:comfy/model_detection.py†L21-L200】
+- `comfy/sd.py` loads checkpoints, then slices, renames, and computes parameters/dtypes by reading tensors (e.g., `calculate_parameters`, `weight_dtype`, `process_*_state_dict`, and special scaled-FP8 conversion that builds new dicts).【F:comfy/sd.py†L1304-L1519】
+- Direct safetensors load outside `load_torch_file`: `comfy/sd1_clip.py:load_embed` and `nodes.py:LoadLatent.load` use `safetensors.torch.load_file`, bypassing the core loader.【F:comfy/sd1_clip.py†L432-L434】【F:nodes.py†L521-L529】
+
+### FastSafeTensors (fastsafetensors) capability audit
+- Header parsing and metadata:
+ - `fastsafetensors/common.py:SafeTensorsMetadata` parses the header and builds per-tensor `TensorFrame` with `dtype`, `shape`, and `data_offsets` (no tensor allocation).【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L63-L187】
+ - `TensorFrame` stores dtype/shape/offsets and supports slicing metadata.【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L238-L338】
+- GDS + no-GDS low-level readers:
+ - `fastsafetensors/cpp.pyi` exposes `gds_file_reader`, `gds_file_handle`, `nogds_file_reader`, `cpu_malloc`, `gpu_malloc`, and alignment helpers such as `get_alignment_size()`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L1-L43】
+ - GDS availability checks are in `fastsafetensors/cpp.pyi`: `is_gds_supported`, `is_cufile_found`, `cufile_version`, and `init_gds`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L36-L43】
+- DLPack wrapping:
+ - `fastsafetensors/dlpack.py` provides `from_cuda_buffer()` which creates DLPack capsules for both CPU and GPU buffers via a device descriptor and is used for `torch.from_dlpack`.【F:../third_party/fastsafetensors-main/fastsafetensors/dlpack.py†L232-L239】
+- Torch framework interop:
+ - `fastsafetensors/frameworks/_torch.py:TorchOp` provides `alloc_tensor_memory`/`free_tensor_memory`, dtype mapping, and uses `torch.from_dlpack` for wrapping raw pointers into tensors.【F:../third_party/fastsafetensors-main/fastsafetensors/frameworks/_torch.py†L131-L205】
+
+### VRAM/RAM offload logic (for extension)
+- `comfy/model_management.py` handles VRAM/RAM offload via `free_memory` and keeps track of loaded/offloaded memory (needs integration with the RAM→disk tier).【F:comfy/model_management.py†L584-L612】
+- `comfy/model_patcher.py` implements module-by-module offload/low-VRAM weight casting (`comfy_cast_weights`) and partial unload/load (needs to integrate the disk tier for RAM eviction).【F:comfy/model_patcher.py†L663-L955】
+
+## Strategy summary (implemented)
+
+### Streaming safetensors mapping (no full dict materialization)
+- [x] Introduce a new module `comfy/safetensors_stream.py` with:
+ - [x] `TensorMeta` and `SafeTensorIndex` (metadata-only parsing with `fastsafetensors.SafeTensorsMetadata`).
+ - [x] `StreamStateDict` as a mapping backed by `SafeTensorIndex`, exposing metadata-only `keys()`/`__iter__` and loading tensors on demand.
+ - [x] Lightweight mapping views: `PrefixViewStateDict`, `FilterViewStateDict`, `RenameViewStateDict` for lazy prefix/filter/rename without eager loading.
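+
+A minimal sketch of the mapping contract these classes follow (illustrative only: the helper names here are assumptions, while the real implementation lives in `comfy/safetensors_stream.py` with `fastsafetensors` doing the range reads):
+
+```python
+# Hedged sketch: metadata-only mapping that defers tensor reads until __getitem__.
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import Callable, Dict, Tuple
+
+import torch
+
+
+@dataclass(frozen=True)
+class TensorMeta:
+    dtype: torch.dtype
+    shape: Tuple[int, ...]
+    data_offsets: Tuple[int, int]  # byte range inside the safetensors file
+
+
+class StreamStateDict(Mapping):
+    """keys()/__iter__/__len__ touch only header metadata; __getitem__ reads one tensor range."""
+
+    def __init__(self, metas: Dict[str, TensorMeta], read_range: Callable[[TensorMeta], torch.Tensor]):
+        self._metas = metas
+        self._read_range = read_range  # e.g. a fastsafetensors-backed range reader
+
+    def __iter__(self):
+        return iter(self._metas)
+
+    def __len__(self):
+        return len(self._metas)
+
+    def meta(self, key: str) -> TensorMeta:
+        return self._metas[key]  # metadata access never touches tensor data
+
+    def __getitem__(self, key: str) -> torch.Tensor:
+        return self._read_range(self._metas[key])  # data is read from disk only here
+```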
+
+### Range reads and tiering
+- [x] Disk→RAM: use `fastsafetensors.cpp.nogds_file_reader` for range reads and wrap with DLPack.
+- [x] Disk→GPU (GDS): use `gds_file_reader` + `gds_file_handle` to read the aligned range directly into GPU memory. If GDS is requested but not supported (e.g., `is_gds_supported==0` or libcufile missing), raise a hard error with instructions to disable GDS.
+- [x] Disk→RAM→GPU: read only the tensor range into (optionally pinned) CPU memory, copy to GPU, then release CPU buffer unless RAM cache policy keeps it.
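+
+A sketch of the disk→RAM→GPU staging path, assuming plain pinned host memory and a CUDA destination (the shipped code goes through the fastsafetensors readers and DLPack instead; `offset`/`nbytes` come from the tensor's header metadata):
+
+```python
+# Hedged sketch: read one tensor's byte range into pinned RAM, copy to GPU, drop the staging buffer.
+import torch
+
+
+def stage_tensor_to_gpu(path: str, offset: int, nbytes: int, dtype: torch.dtype,
+                        shape, device: torch.device, keep_cpu_copy: bool = False) -> torch.Tensor:
+    cpu_buf = torch.empty(nbytes, dtype=torch.uint8, pin_memory=True)  # pinned staging buffer
+    with open(path, "rb") as f:
+        f.seek(offset)
+        f.readinto(cpu_buf.numpy())  # fill the pinned buffer with only this tensor's bytes
+    cpu_tensor = cpu_buf.view(dtype).reshape(shape)
+    gpu_tensor = cpu_tensor.to(device, non_blocking=True)  # async H2D copy from pinned memory
+    if not keep_cpu_copy:
+        torch.cuda.synchronize(device)  # wait for the copy before releasing the pinned buffer
+        del cpu_buf, cpu_tensor
+    return gpu_tensor
+```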
+
+### Disk tier integration
+- [x] Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`.
+- [x] Add an LRU cache for RAM-resident weights loaded from disk with configurable max bytes. Eviction replaces RAM tensors with meta tensors and keeps `DiskRef` for reload.
+- [x] Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`.
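+
+A sketch of that hook (the version in `comfy/disk_weights.py` additionally handles buffers, dtype overrides, and memory budgets; `DISK_REFS` and the per-key loader are illustrative stand-ins for the `DiskRef` registry):
+
+```python
+# Hedged sketch: re-materialize meta-device parameters from their disk refs before compute.
+import torch
+
+DISK_REFS = {}  # (module, param_name) -> callable(device) that range-reads the tensor from disk
+
+
+def disk_weight_pre_hook(module: torch.nn.Module, args):
+    device = next((a.device for a in args if torch.is_tensor(a)), torch.device("cpu"))
+    for name, param in module.named_parameters(recurse=False):
+        if param.device.type != "meta":
+            continue  # already resident in RAM or VRAM
+        loader = DISK_REFS.get((module, name))
+        if loader is None:
+            continue
+        tensor = loader(device)  # load just this tensor onto the device of the incoming activations
+        module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=param.requires_grad)
+
+
+# Attached once per module so layers that bypass comfy.ops are still covered:
+# module.register_forward_pre_hook(disk_weight_pre_hook)
+```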
+
+### Pipeline refactors
+- [x] Update `load_torch_file` to return a `StreamStateDict` for `.safetensors`/`.sft` files and to return metadata without loading tensor data.
+- [x] Update helpers (`calculate_parameters`, `weight_dtype`, `state_dict_prefix_replace`) to be metadata-aware and lazy.
+- [x] Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead.
+- [x] Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads).
+- [x] Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader.
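+
+For example, parameter counting can stay metadata-only (sketch; assumes the streaming mapping exposes `meta(key)` as described above and falls back to tensor access for plain dicts):
+
+```python
+# Hedged sketch of a metadata-aware parameter count: no tensor data is read for streaming mappings.
+import math
+
+
+def calculate_parameters(sd, prefix: str = "") -> int:
+    params = 0
+    for k in sd.keys():
+        if not k.startswith(prefix):
+            continue
+        if hasattr(sd, "meta"):
+            params += math.prod(sd.meta(k).shape)  # header metadata only
+        else:
+            params += sd[k].nelement()  # plain dict fallback: reads the tensor
+    return params
+```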
+
+### Tests and docs
+- [x] Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and GDS failure path.
+- [x] Document new flags for RAM cache size and GPUDirect enablement and how to disable GDS when unsupported.
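+
+A minimal test sketch for the lazy-view guarantee (pytest style; the dict below is a test double standing in for `StreamStateDict`, and the `FilterViewStateDict` call mirrors its usage in `comfy/disk_weights.py`):
+
+```python
+# Hedged test sketch: iterating a lazy view must not materialize any tensors.
+import torch
+
+
+class CountingStateDict(dict):
+    """Test double that counts tensor materializations."""
+    reads = 0
+
+    def __getitem__(self, key):
+        type(self).reads += 1
+        return super().__getitem__(key)
+
+
+def test_filter_view_is_lazy():
+    from comfy import safetensors_stream
+
+    base = CountingStateDict({"model.w": torch.zeros(2, 2), "other.b": torch.zeros(2)})
+    view = safetensors_stream.FilterViewStateDict(base, lambda k: k.startswith("model."), mutate_base=False)
+    assert list(view.keys()) == ["model.w"]  # filtering/iteration should use keys only
+    assert CountingStateDict.reads == 0      # no __getitem__ calls, hence no tensor loads
+```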
diff --git a/README.md b/README.md
index 6d09758c0..550482e5c 100644
--- a/README.md
+++ b/README.md
@@ -349,6 +349,14 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
| `--enable-manager` | Enable ComfyUI-Manager |
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
+| `--weights-ram-cache-gb` | Enable a disk tier for model weights and keep up to N GB in RAM. Set to `0` to disable RAM caching while still allowing disk streaming. |
+| `--weights-gds` | Enable GPUDirect Storage (GDS) for disk→GPU weight loads. Requires libcufile and GDS support. |
+
+### Disk tier for model weights
+
+When `--weights-ram-cache-gb` is set, ComfyUI streams safetensors weights from disk and keeps a bounded RAM cache (for example, `--weights-ram-cache-gb 8` keeps at most 8 GB of weights in RAM). If the cache limit would be exceeded, weights are evicted back to disk and reloaded on demand.
+
+If `--weights-gds` is enabled, ComfyUI attempts disk→GPU reads via GPUDirect Storage. If GDS is not available (missing libcufile or unsupported platform), the load will fail with a clear error. Disable GDS by omitting `--weights-gds` to use disk→RAM→GPU staging instead.
# Running
diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py
index 46ef21c95..bf3abf834 100644
--- a/comfy/audio_encoders/audio_encoders.py
+++ b/comfy/audio_encoders/audio_encoders.py
@@ -29,7 +29,7 @@ class AudioEncoderModel():
self.model_sample_rate = 16000
def load_sd(self, sd):
- return self.model.load_state_dict(sd, strict=False)
+ return comfy.utils.load_state_dict(self.model, sd, strict=False)
def get_sd(self):
return self.model.state_dict()
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 1716c3de7..5024caf8d 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -114,6 +114,9 @@ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU cachi
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
+parser.add_argument("--weights-ram-cache-gb", type=float, default=None, help="Enable a disk tier for model weights by keeping up to N GB in RAM. Set to 0 to disable RAM caching while keeping disk tier enabled.")
+parser.add_argument("--weights-gds", action="store_true", help="Enable GPUDirect Storage (GDS) for disk->GPU weight loads. Requires libcufile and GDS support.")
+
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index d5fc53497..837eba013 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -48,7 +48,7 @@ class ClipVisionModel():
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
- return self.model.load_state_dict(sd, strict=False)
+ return comfy.utils.load_state_dict(self.model, sd, strict=False)
def get_sd(self):
return self.model.state_dict()
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 0b5e30f52..551f0bb18 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -25,6 +25,7 @@ import logging
import comfy.utils
import comfy.model_management
import comfy.model_detection
+import comfy.disk_weights
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats
@@ -385,7 +386,7 @@ class ControlLora(ControlNet):
controlnet_config["operations"] = control_lora_ops
controlnet_config["dtype"] = dtype
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
- self.control_model.to(comfy.model_management.get_torch_device())
+ comfy.disk_weights.module_to(self.control_model, comfy.model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
@@ -439,7 +440,7 @@ def controlnet_config(sd, model_options={}):
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd):
- missing, unexpected = control_model.load_state_dict(sd, strict=False)
+ missing, unexpected = comfy.utils.load_state_dict(control_model, sd, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
@@ -473,9 +474,9 @@ def load_controlnet_mmdit(sd, model_options={}):
class ControlNetSD35(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
if self.control_model.double_y_emb:
- missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
+ missing, unexpected = comfy.utils.load_state_dict(self.control_model.orig_y_embedder, model.diffusion_model.y_embedder.state_dict(), strict=False)
else:
- missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
+ missing, unexpected = comfy.utils.load_state_dict(self.control_model.x_embedder, model.diffusion_model.x_embedder.state_dict(), strict=False)
super().pre_run(model, percent_to_timestep_function)
def copy(self):
@@ -748,9 +749,9 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
pass
w = WeightsLoader()
w.control_model = control_model
- missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
+ missing, unexpected = comfy.utils.load_state_dict(w, controlnet_data, strict=False)
else:
- missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
+ missing, unexpected = comfy.utils.load_state_dict(control_model, controlnet_data, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
@@ -816,8 +817,8 @@ class T2IAdapter(ControlBase):
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
if self.control_input is None:
- self.t2i_model.to(x_noisy.dtype)
- self.t2i_model.to(self.device)
+ comfy.disk_weights.module_to(self.t2i_model, dtype=x_noisy.dtype)
+ comfy.disk_weights.module_to(self.t2i_model, self.device)
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
self.t2i_model.cpu()
@@ -874,7 +875,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
else:
return None
- missing, unexpected = model_ad.load_state_dict(t2i_data)
+ missing, unexpected = comfy.utils.load_state_dict(model_ad, t2i_data, strict=True)
if len(missing) > 0:
logging.warning("t2i missing {}".format(missing))
diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py
new file mode 100644
index 000000000..07e9b5f11
--- /dev/null
+++ b/comfy/disk_weights.py
@@ -0,0 +1,1115 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Comfy
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
+from __future__ import annotations
+
+import collections
+import logging
+import weakref
+from dataclasses import dataclass, field
+from types import SimpleNamespace
+from typing import Dict, MutableMapping, Optional, Set
+
+import torch
+
+from . import safetensors_stream
+
+
+ALLOW_GDS = False
+PIN_IF_CPU = False
+DISK_WEIGHTS_ENABLED = False
+BASE_LOAD_FROM_STATE_DICT = torch.nn.Module._load_from_state_dict
+LAZY_MODULE_STATE = weakref.WeakKeyDictionary()
+DISK_MATERIALIZATION_STATE = weakref.WeakKeyDictionary()
+_MISSING = object()
+
+
+@dataclass
+class DiskTensorRef:
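+    """A single disk-resident tensor: which state dict/key it comes from plus the metadata needed to reload it."""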
+ state_dict: object
+ key: str
+ meta: object
+ requires_grad: bool
+ is_buffer: bool
+
+ def load(
+ self,
+ device: torch.device,
+ allow_gds: bool,
+ pin_if_cpu: bool,
+ dtype_override: Optional[torch.dtype] = None,
+ ) -> torch.Tensor:
+ dtype = dtype_override or getattr(self.meta, "dtype", None)
+ if hasattr(self.state_dict, "get_tensor"):
+ return self.state_dict.get_tensor(
+ self.key,
+ device=device,
+ dtype=dtype,
+ allow_gds=allow_gds,
+ pin_if_cpu=pin_if_cpu,
+ )
+ tensor = self.state_dict[self.key]
+ if device is not None and tensor.device != device:
+ tensor = tensor.to(device=device)
+ if dtype is not None and tensor.dtype != dtype:
+ tensor = tensor.to(dtype=dtype)
+ return tensor
+
+
+class DiskWeightRegistry:
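+    """Weakly-keyed per-module registry mapping parameter/buffer names to their DiskTensorRef."""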
+ def __init__(self):
+ self._registry = weakref.WeakKeyDictionary()
+
+ def register(self, module: torch.nn.Module, name: str, ref: DiskTensorRef):
+ module_refs = self._registry.setdefault(module, {})
+ module_refs[name] = ref
+
+ def get(self, module: torch.nn.Module) -> Optional[Dict[str, DiskTensorRef]]:
+ return self._registry.get(module)
+
+ def has(self, module: torch.nn.Module) -> bool:
+ return module in self._registry
+
+
+@dataclass
+class CacheEntry:
+ module_ref: weakref.ReferenceType
+ name: str
+ size_bytes: int
+ is_buffer: bool
+
+
+class DiskWeightCache:
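+    """LRU accounting of CPU-resident disk-backed weights; evicting an entry swaps the weight back to a meta tensor."""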
+ def __init__(self, max_bytes: int = 0):
+ self.max_bytes = max_bytes
+ self.current_bytes = 0
+ self._entries: "collections.OrderedDict[tuple[int, str], CacheEntry]" = collections.OrderedDict()
+
+ def set_limit(self, max_bytes: int):
+ self.max_bytes = max_bytes
+ self._evict_if_needed()
+
+ def _entry_key(self, module: torch.nn.Module, name: str) -> tuple[int, str]:
+ return (id(module), name)
+
+ def record(self, module: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool):
+ if tensor.device.type != "cpu":
+ return
+ size_bytes = tensor.numel() * tensor.element_size()
+ key = self._entry_key(module, name)
+ if key in self._entries:
+ entry = self._entries.pop(key)
+ self.current_bytes -= entry.size_bytes
+ module_ref = weakref.ref(module, self._drop_module_entries)
+ self._entries[key] = CacheEntry(module_ref=module_ref, name=name, size_bytes=size_bytes, is_buffer=is_buffer)
+ self.current_bytes += size_bytes
+ self._evict_if_needed()
+
+ def touch(self, module: torch.nn.Module, name: str):
+ key = self._entry_key(module, name)
+ if key in self._entries:
+ entry = self._entries.pop(key)
+ self._entries[key] = entry
+
+ def evict_bytes(self, bytes_to_free: int):
+ freed = 0
+ while self._entries and freed < bytes_to_free:
+ _, entry = self._entries.popitem(last=False)
+ freed += entry.size_bytes
+ self.current_bytes -= entry.size_bytes
+ module = entry.module_ref()
+ if module is not None:
+ _evict_module_weight(module, entry.name, entry.is_buffer)
+ return freed
+
+ def remove_module(self, module: torch.nn.Module):
+ to_remove = []
+ for key, entry in self._entries.items():
+ if entry.module_ref() is module:
+ to_remove.append(key)
+ for key in to_remove:
+ entry = self._entries.pop(key)
+ self.current_bytes -= entry.size_bytes
+
+ def _drop_module_entries(self, module_ref: weakref.ReferenceType):
+ to_remove = []
+ for key, entry in self._entries.items():
+ if entry.module_ref is module_ref:
+ to_remove.append(key)
+ for key in to_remove:
+ entry = self._entries.pop(key)
+ self.current_bytes -= entry.size_bytes
+
+ def _evict_if_needed(self):
+ while self._entries and self.current_bytes > self.max_bytes:
+ _, entry = self._entries.popitem(last=False)
+ self.current_bytes -= entry.size_bytes
+ module = entry.module_ref()
+ if module is not None:
+ _evict_module_weight(module, entry.name, entry.is_buffer)
+
+
+REGISTRY = DiskWeightRegistry()
+CACHE = DiskWeightCache(0)
+LOGGER = logging.getLogger(__name__)
+
+
+def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
+ global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED
+ ALLOW_GDS = allow_gds
+ PIN_IF_CPU = pin_if_cpu
+ DISK_WEIGHTS_ENABLED = enabled
+ CACHE.set_limit(cache_bytes if enabled else 0)
+ if not enabled:
+ CACHE._entries.clear()
+ CACHE.current_bytes = 0
+
+
+def disk_weights_enabled() -> bool:
+ return DISK_WEIGHTS_ENABLED
+
+
+def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = ""):
+ if not disk_weights_enabled():
+ return
+ if not hasattr(state_dict, "meta") or not hasattr(state_dict, "get_tensor"):
+ return
+ for module_name, submodule in module.named_modules():
+ module_prefix = f"{prefix}{module_name}." if module_name else prefix
+ for name, param in submodule.named_parameters(recurse=False):
+ key = f"{module_prefix}{name}" if module_prefix else name
+ if key in state_dict:
+ meta = state_dict.meta(key)
+ ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False)
+ REGISTRY.register(submodule, name, ref)
+ if param.device.type == "cpu":
+ CACHE.record(submodule, name, param, is_buffer=False)
+ for name, buf in submodule.named_buffers(recurse=False):
+ key = f"{module_prefix}{name}" if module_prefix else name
+ if key in state_dict and buf is not None:
+ meta = state_dict.meta(key)
+ ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True)
+ REGISTRY.register(submodule, name, ref)
+ if buf.device.type == "cpu":
+ CACHE.record(submodule, name, buf, is_buffer=True)
+
+
+@dataclass
+class LazyModuleState:
+ state_dict: MutableMapping
+ prefix: str
+ loaded: bool = False
+
+
+@dataclass
+class DiskMaterializationState:
+ loaded_keys: Set[str] = field(default_factory=set)
+ deferred_keys: Set[str] = field(default_factory=set)
+ loaded_bytes: int = 0
+ deferred_bytes: int = 0
+ future_dtypes: Dict[str, torch.dtype] = field(default_factory=dict)
+
+
+def _get_materialization_state(module: torch.nn.Module) -> DiskMaterializationState:
+ state = DISK_MATERIALIZATION_STATE.get(module)
+ if state is None:
+ state = DiskMaterializationState()
+ DISK_MATERIALIZATION_STATE[module] = state
+ return state
+
+
+def _set_future_dtype(module: torch.nn.Module, name: str, dtype: Optional[torch.dtype]):
+ state = _get_materialization_state(module)
+ if dtype is None:
+ state.future_dtypes.pop(name, None)
+ else:
+ state.future_dtypes[name] = dtype
+
+
+def _get_future_dtype(module: torch.nn.Module, name: str) -> Optional[torch.dtype]:
+ state = DISK_MATERIALIZATION_STATE.get(module)
+ if state is None:
+ return None
+ return state.future_dtypes.get(name)
+
+
+def _update_disk_state_attrs(module: torch.nn.Module, state: DiskMaterializationState):
+ module.disk_loaded_weight_memory = state.loaded_bytes
+ module.disk_offload_buffer_memory = state.deferred_bytes
+
+
+def _tensor_nbytes(tensor: torch.Tensor) -> int:
+ return tensor.numel() * tensor.element_size()
+
+
+def _meta_nbytes(meta) -> Optional[int]:
+ return getattr(meta, "nbytes", None)
+
+
+def _meta_tensor(meta, dtype_override: Optional[torch.dtype] = None) -> torch.Tensor:
+ dtype = dtype_override or getattr(meta, "dtype", None)
+ shape = getattr(meta, "shape", None)
+ if dtype is None or shape is None:
+ raise KeyError("Missing metadata for meta tensor")
+ return torch.empty(shape, dtype=dtype, device="meta")
+
+
+def _state_dict_meta(state_dict: MutableMapping, key: str):
+ if hasattr(state_dict, "meta"):
+ return state_dict.meta(key)
+ if hasattr(state_dict, "get_tensor"):
+ t = state_dict.get_tensor(key, device=torch.device("meta"))
+ else:
+ t = state_dict[key]
+ numel = t.numel()
+ return SimpleNamespace(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=numel,
+ nbytes=numel * t.element_size(),
+ )
+
+
+def _rebuild_materialization_state(module: torch.nn.Module, refs: Dict[str, DiskTensorRef], state: DiskMaterializationState):
+ state.loaded_keys.clear()
+ state.deferred_keys.clear()
+ state.loaded_bytes = 0
+ state.deferred_bytes = 0
+ for name, ref in refs.items():
+ if name in module._parameters:
+ tensor = module._parameters[name]
+ elif name in module._buffers:
+ tensor = module._buffers[name]
+ else:
+ continue
+ if tensor is None:
+ continue
+ nbytes = _meta_nbytes(ref.meta) or _tensor_nbytes(tensor)
+ if tensor.device.type == "meta":
+ state.deferred_keys.add(name)
+ state.deferred_bytes += nbytes
+ else:
+ state.loaded_keys.add(name)
+ state.loaded_bytes += nbytes
+ _update_disk_state_attrs(module, state)
+
+
+def _summarize_module_bytes(module: torch.nn.Module, refs: Dict[str, DiskTensorRef]):
+ cpu_bytes = 0
+ gpu_bytes = 0
+ meta_bytes = 0
+ total_bytes = 0
+ for name, ref in refs.items():
+ tensor = None
+ if name in module._parameters:
+ tensor = module._parameters[name]
+ elif name in module._buffers:
+ tensor = module._buffers[name]
+ if tensor is None:
+ continue
+ nbytes = _meta_nbytes(ref.meta)
+ if nbytes is None:
+ nbytes = _tensor_nbytes(tensor)
+ total_bytes += nbytes
+ if tensor.device.type == "meta":
+ meta_bytes += nbytes
+ elif tensor.device.type == "cpu":
+ cpu_bytes += nbytes
+ else:
+ gpu_bytes += nbytes
+ return total_bytes, cpu_bytes, gpu_bytes, meta_bytes
+
+
+def _log_materialization(
+ module: torch.nn.Module,
+ target_device: torch.device,
+ free_mem: int,
+ refs: Dict[str, DiskTensorRef],
+ state: DiskMaterializationState,
+ context: str,
+):
+ total_bytes, cpu_bytes, gpu_bytes, meta_bytes = _summarize_module_bytes(module, refs)
+ if total_bytes == 0:
+ return
+ partial = meta_bytes > 0
+ LOGGER.info(
+ "%s: module=%s dest=%s load=%0.2fMB free=%0.2fMB partial=%s "
+ "loaded=%0.2fMB meta=%0.2fMB cpu=%0.2fMB gpu=%0.2fMB full_load=%s",
+ context,
+ module.__class__.__name__,
+ target_device,
+ total_bytes / (1024 * 1024),
+ free_mem / (1024 * 1024),
+ partial,
+ state.loaded_bytes / (1024 * 1024),
+ state.deferred_bytes / (1024 * 1024),
+ cpu_bytes / (1024 * 1024),
+ gpu_bytes / (1024 * 1024),
+ not partial,
+ )
+
+
+def _device_free_memory(device: torch.device) -> int:
+ from . import model_management
+ return int(model_management.get_free_memory(device))
+
+
+def _evict_ram_for_budget(required_bytes: int) -> int:
+ if required_bytes <= 0:
+ return 0
+ freed = evict_ram_cache(required_bytes)
+ if freed < required_bytes:
+ from . import model_management
+ freed += model_management.evict_ram_to_disk(required_bytes - freed)
+ return freed
+
+
+def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int:
+ free_mem = _device_free_memory(device)
+ if device.type == "cpu" and free_mem < required_bytes:
+ _evict_ram_for_budget(required_bytes - free_mem)
+ free_mem = _device_free_memory(device)
+ return free_mem
+
+
+def _choose_alternate_device(device: torch.device) -> Optional[torch.device]:
+ from . import model_management
+ if device.type == "cpu":
+ alt = model_management.get_torch_device()
+ if alt.type != "cpu":
+ return alt
+ else:
+ return torch.device("cpu")
+ return None
+
+
+class _BudgetedStateDict(MutableMapping):
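+    """Mapping view that materializes only keys in ``allowed_keys``; other keys are served as meta tensors so budget-limited loads can defer them."""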
+ is_stream_state_dict = True
+
+ def __init__(
+ self,
+ base: MutableMapping,
+ allowed_keys: Set[str],
+ device: torch.device,
+ allow_gds: Optional[bool] = None,
+ pin_if_cpu: bool = False,
+ dtype_override: Optional[torch.dtype] = None,
+ overrides: Optional[Dict[str, torch.Tensor]] = None,
+ ):
+ self._base = base
+ self._allowed_keys = allowed_keys
+ self._device = device
+ self._allow_gds = allow_gds
+ self._pin_if_cpu = pin_if_cpu
+ self._dtype_override = dtype_override
+ self._overrides = overrides or {}
+ self._deleted: Set[str] = set()
+
+ def _get_meta(self, key: str):
+ if key in self._overrides:
+ t = self._overrides[key]
+ return safetensors_stream.TensorMeta(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=t.numel(),
+ nbytes=_tensor_nbytes(t),
+ data_offsets=(0, _tensor_nbytes(t)),
+ filename="",
+ fst_dtype=None,
+ strides=tuple(t.stride()),
+ )
+ if hasattr(self._base, "meta"):
+ return self._base.meta(key)
+ if hasattr(self._base, "get_tensor"):
+ t = self._base.get_tensor(key, device=torch.device("meta"))
+ else:
+ t = self._base[key]
+ return safetensors_stream.TensorMeta(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=t.numel(),
+ nbytes=_tensor_nbytes(t),
+ data_offsets=(0, _tensor_nbytes(t)),
+ filename="",
+ fst_dtype=None,
+ strides=tuple(t.stride()),
+ )
+
+ def get_tensor(
+ self,
+ key: str,
+ *,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ allow_gds: Optional[bool] = None,
+ pin_if_cpu: bool = False,
+ ) -> torch.Tensor:
+ requested_dtype = dtype if dtype is not None else self._dtype_override
+ if key in self._overrides:
+ t = self._overrides[key]
+ if device is not None and t.device != device:
+ t = t.to(device=device)
+ if requested_dtype is not None and t.dtype != requested_dtype:
+ t = t.to(dtype=requested_dtype)
+ return t
+ if key in self._deleted:
+ raise KeyError(key)
+ if key not in self._allowed_keys:
+ meta = self._get_meta(key)
+ target_dtype = requested_dtype or meta.dtype
+ return _meta_tensor(meta, dtype_override=target_dtype)
+ if hasattr(self._base, "get_tensor"):
+ return self._base.get_tensor(
+ key,
+ device=self._device if device is None else device,
+ dtype=requested_dtype,
+ allow_gds=self._allow_gds if allow_gds is None else allow_gds,
+ pin_if_cpu=self._pin_if_cpu if not pin_if_cpu else pin_if_cpu,
+ )
+ t = self._base[key]
+ if device is not None and t.device != device:
+ t = t.to(device=device)
+ if requested_dtype is not None and t.dtype != requested_dtype:
+ t = t.to(dtype=requested_dtype)
+ return t
+
+ def __getitem__(self, key: str) -> torch.Tensor:
+ return self.get_tensor(key)
+
+ def __setitem__(self, key: str, value: torch.Tensor) -> None:
+ self._overrides[key] = value
+ self._deleted.discard(key)
+
+ def __delitem__(self, key: str) -> None:
+ if key in self._overrides:
+ del self._overrides[key]
+ return
+ if key in self._deleted:
+ raise KeyError(key)
+ self._deleted.add(key)
+
+    def __iter__(self):
+        yielded = set()
+        for k in self._base.keys():
+            if k in self._deleted:
+                continue
+            yielded.add(k)
+            yield k
+        for k in self._overrides.keys():
+            if k not in self._deleted and k not in yielded:
+                yield k
+
+    def __len__(self) -> int:
+        keys = set(self._base.keys()) | set(self._overrides.keys())
+        return len(keys - self._deleted)
+
+    def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
+        if key in self._overrides:
+            value = self._overrides.pop(key)
+            self._deleted.add(key)  # hide any base entry stored under the same key
+            return value
+        if key in self._deleted:
+            if default is _MISSING:
+                raise KeyError(key)
+            return default
+        if key not in self._base:
+            if default is _MISSING:
+                raise KeyError(key)
+            return default
+        value = self.get_tensor(key)
+        self._deleted.add(key)
+        return value
+
+ def meta(self, key: str):
+ return self._get_meta(key)
+
+def _has_custom_load(module: torch.nn.Module) -> bool:
+ return module.__class__._load_from_state_dict is not BASE_LOAD_FROM_STATE_DICT
+
+
+def register_lazy_modules(model: torch.nn.Module, state_dict):
+ if not hasattr(state_dict, "keys"):
+ return
+ for name, module in model.named_modules():
+ if not _has_custom_load(module):
+ continue
+ prefix = f"{name}." if name else ""
+ if prefix:
+ has_key = False
+ for param_name in module._parameters.keys():
+ if f"{prefix}{param_name}" in state_dict:
+ has_key = True
+ break
+ if not has_key:
+ for buf_name in module._buffers.keys():
+ if f"{prefix}{buf_name}" in state_dict:
+ has_key = True
+ break
+ if not has_key:
+ continue
+ view = safetensors_stream.FilterViewStateDict(
+ state_dict, lambda k, p=prefix: k.startswith(p), mutate_base=False
+ )
+ LAZY_MODULE_STATE[module] = LazyModuleState(state_dict=view, prefix=prefix)
+
+
+def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
+ lazy_state = LAZY_MODULE_STATE.get(module)
+ if lazy_state is not None:
+ CACHE.remove_module(module)
+ refs = REGISTRY.get(module)
+ if refs:
+ state = _get_materialization_state(module)
+ for ref_name, disk_ref in refs.items():
+ shape = getattr(disk_ref.meta, "shape", None)
+ dtype = _get_future_dtype(module, ref_name) or getattr(disk_ref.meta, "dtype", None)
+ if shape is None or dtype is None:
+ continue
+ meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
+ if disk_ref.is_buffer:
+ module._buffers[ref_name] = meta_tensor
+ else:
+ module._parameters[ref_name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+ nbytes = _meta_nbytes(disk_ref.meta)
+ if nbytes is not None:
+ state.loaded_keys.discard(ref_name)
+ if ref_name not in state.deferred_keys:
+ state.deferred_keys.add(ref_name)
+ state.deferred_bytes += nbytes
+ state.loaded_bytes = max(0, state.loaded_bytes - nbytes)
+ _update_disk_state_attrs(module, state)
+ lazy_state.loaded = False
+ return
+ ref = REGISTRY.get(module)
+ if not ref or name not in ref:
+ return
+ disk_ref = ref[name]
+ shape = getattr(disk_ref.meta, "shape", None)
+ dtype = _get_future_dtype(module, name) or getattr(disk_ref.meta, "dtype", None)
+ if shape is None or dtype is None:
+ return
+ meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
+ if is_buffer:
+ module._buffers[name] = meta_tensor
+ else:
+ module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+ state = _get_materialization_state(module)
+ nbytes = _meta_nbytes(disk_ref.meta)
+ if nbytes is not None:
+ state.loaded_keys.discard(name)
+ if name not in state.deferred_keys:
+ state.deferred_keys.add(name)
+ state.deferred_bytes += nbytes
+ state.loaded_bytes = max(0, state.loaded_bytes - nbytes)
+ _update_disk_state_attrs(module, state)
+
+
+def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
+ def check(obj):
+ if torch.is_tensor(obj):
+ return obj.device
+ if isinstance(obj, (list, tuple)):
+ for item in obj:
+ dev = check(item)
+ if dev is not None:
+ return dev
+ if isinstance(obj, dict):
+ for item in obj.values():
+ dev = check(item)
+ if dev is not None:
+ return dev
+ return None
+
+ dev = check(args)
+ if dev is not None:
+ return dev
+ return check(kwargs)
+
+
+def _find_tensor_dtype(args, kwargs) -> Optional[torch.dtype]:
+ def check(obj):
+ if torch.is_tensor(obj):
+ return obj.dtype
+ if isinstance(obj, (list, tuple)):
+ for item in obj:
+ dtype = check(item)
+ if dtype is not None:
+ return dtype
+ if isinstance(obj, dict):
+ for item in obj.values():
+ dtype = check(item)
+ if dtype is not None:
+ return dtype
+ return None
+
+ dtype = check(args)
+ if dtype is not None:
+ return dtype
+ return check(kwargs)
+
+
+def _select_weight_dtype(input_dtype: Optional[torch.dtype], manual_cast_dtype: Optional[torch.dtype]) -> Optional[torch.dtype]:
+ if manual_cast_dtype is not None:
+ return manual_cast_dtype
+ if input_dtype is None:
+ return None
+ if torch.is_floating_point(torch.empty((), dtype=input_dtype)):
+ return input_dtype
+ return None
+
+
+def ensure_module_materialized(
+ module: torch.nn.Module,
+ target_device: torch.device,
+ fallback_device: Optional[torch.device] = None,
+ dtype_override: Optional[torch.dtype] = None,
+):
+ lazy_state = LAZY_MODULE_STATE.get(module)
+ if lazy_state is not None:
+ _materialize_module_from_state_dict(
+ module,
+ lazy_state,
+ target_device,
+ dtype_override=dtype_override,
+ )
+ return
+ refs = REGISTRY.get(module)
+ if not refs:
+ return
+ state = _get_materialization_state(module)
+ if dtype_override is not None:
+ for name in refs.keys():
+ _set_future_dtype(module, name, dtype_override)
+ _rebuild_materialization_state(module, refs, state)
+ free_mem_start = _device_free_memory(target_device)
+ remaining_budget = free_mem_start
+ for name in sorted(refs.keys()):
+ disk_ref = refs[name]
+ if name in module._parameters:
+ current = module._parameters[name]
+ is_buffer = False
+ elif name in module._buffers:
+ current = module._buffers[name]
+ is_buffer = True
+ else:
+ continue
+ if current is None:
+ continue
+ target_dtype = dtype_override or _get_future_dtype(module, name)
+ if current.device.type != "meta" and current.device == target_device and (
+ target_dtype is None or current.dtype == target_dtype
+ ):
+ if current.device.type == "cpu":
+ CACHE.touch(module, name)
+ continue
+ meta_nbytes = _meta_nbytes(disk_ref.meta)
+ if meta_nbytes is None:
+ continue
+ required_bytes = meta_nbytes
+ if target_device.type == "cpu":
+ free_mem = _maybe_free_ram_budget(target_device, required_bytes)
+ remaining_budget = min(remaining_budget, free_mem)
+ if required_bytes > remaining_budget:
+ if fallback_device is not None and fallback_device != target_device:
+ fallback_free = _maybe_free_ram_budget(fallback_device, required_bytes)
+ if fallback_free >= required_bytes:
+ target_for_load = fallback_device
+ else:
+ continue
+ else:
+ continue
+ else:
+ target_for_load = target_device
+ if current.device.type == "meta":
+ tensor = disk_ref.load(
+ target_for_load,
+ ALLOW_GDS,
+ PIN_IF_CPU,
+ dtype_override=target_dtype,
+ )
+ else:
+ if target_dtype is not None and current.dtype != target_dtype:
+ tensor = current.to(device=target_for_load, dtype=target_dtype)
+ else:
+ tensor = current.to(device=target_for_load)
+ if is_buffer:
+ module._buffers[name] = tensor
+ else:
+ module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
+ if tensor.device.type == "cpu":
+ CACHE.record(module, name, tensor, is_buffer=is_buffer)
+ remaining_budget = max(0, remaining_budget - required_bytes)
+ _rebuild_materialization_state(module, refs, state)
+ _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
+
+
+def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
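+    """Forward pre-hook: materialize disk-backed (meta) weights of ``module`` on a device/dtype compatible with the incoming activations."""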
+ if not REGISTRY.has(module) and module not in LAZY_MODULE_STATE:
+ return
+ input_dtype = _find_tensor_dtype(args, kwargs)
+ manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
+ dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
+ if getattr(module, "comfy_cast_weights", False):
+ target_device = torch.device("cpu")
+ fallback_device = _find_tensor_device(args, kwargs)
+ else:
+ target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
+ fallback_device = None
+ ensure_module_materialized(
+ module,
+ target_device,
+ fallback_device=fallback_device,
+ dtype_override=dtype_override,
+ )
+
+
+def attach_disk_weight_hooks(model: torch.nn.Module):
+ if not disk_weights_enabled():
+ return
+ for module in model.modules():
+ if getattr(module, "_disk_weight_hook_attached", False):
+ continue
+ module.register_forward_pre_hook(disk_weight_pre_hook)
+ module._disk_weight_hook_attached = True
+
+
+def evict_ram_cache(bytes_to_free: int):
+ if bytes_to_free <= 0:
+ return 0
+ return CACHE.evict_bytes(bytes_to_free)
+
+
+def materialize_module_tree(module: torch.nn.Module, target_device: torch.device):
+ if not disk_weights_enabled():
+ return
+ for submodule in module.modules():
+ ensure_module_materialized(submodule, target_device)
+
+
+def _extract_to_device(args, kwargs) -> Optional[torch.device]:
+ if "device" in kwargs and kwargs["device"] is not None:
+ return torch.device(kwargs["device"])
+ for arg in args:
+ if isinstance(arg, torch.device):
+ return arg
+ if isinstance(arg, str):
+ return torch.device(arg)
+ return None
+
+
+def _extract_to_dtype(args, kwargs) -> Optional[torch.dtype]:
+ if "dtype" in kwargs and kwargs["dtype"] is not None:
+ return kwargs["dtype"]
+ for arg in args:
+ if isinstance(arg, torch.dtype):
+ return arg
+ return None
+
+
+def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
+ for param in module.parameters(recurse=True):
+ if param is not None and param.device.type != "meta":
+ return param.device
+ for buf in module.buffers(recurse=True):
+ if buf is not None and buf.device.type != "meta":
+ return buf.device
+ return None
+
+
+def move_module_tensors(module: torch.nn.Module, device_to: torch.device, dtype_override: Optional[torch.dtype] = None):
+ def _move(tensor):
+ if tensor is None:
+ return None
+ if tensor.device.type == "meta":
+ return tensor
+ if dtype_override is not None and tensor.dtype != dtype_override:
+ return tensor.to(device=device_to, dtype=dtype_override)
+ return tensor.to(device=device_to)
+
+ module._apply(_move)
+ return module
+
+
+def offload_module_weights(module: torch.nn.Module) -> int:
+ if not disk_weights_enabled():
+ return 0
+ refs = REGISTRY.get(module)
+ if not refs:
+ return 0
+ offloaded_bytes = 0
+ if module in LAZY_MODULE_STATE:
+ ref_name = next(iter(refs.keys()), None)
+ if ref_name is not None:
+ _evict_module_weight(module, ref_name, False)
+ for disk_ref in refs.values():
+ nbytes = _meta_nbytes(disk_ref.meta)
+ if nbytes is not None:
+ offloaded_bytes += nbytes
+ return offloaded_bytes
+ for name, disk_ref in refs.items():
+ _evict_module_weight(module, name, disk_ref.is_buffer)
+ nbytes = _meta_nbytes(disk_ref.meta)
+ if nbytes is not None:
+ offloaded_bytes += nbytes
+ return offloaded_bytes
+
+
+def module_to(module: torch.nn.Module, *args, **kwargs):
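+    """Disk-tier aware replacement for ``Module.to``: moving to ``meta`` offloads weights to disk; otherwise disk-backed weights are materialized first (or moved lazily when ``allow_materialize=False``)."""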
+ allow_materialize = kwargs.pop("allow_materialize", True)
+ if disk_weights_enabled():
+ target_device = _extract_to_device(args, kwargs)
+ if target_device is None:
+ target_device = _find_existing_device(module) or torch.device("cpu")
+ if target_device.type == "meta":
+ offload_module_weights(module)
+ return module
+ if allow_materialize:
+ materialize_module_tree(module, target_device)
+ return module.to(*args, **kwargs)
+ dtype_override = _extract_to_dtype(args, kwargs)
+ return move_module_tensors(module, target_device, dtype_override=dtype_override)
+ return module.to(*args, **kwargs)
+
+
+def load_module_tensor(
+ module: torch.nn.Module,
+ name: str,
+ device: torch.device,
+ *,
+ allow_alternate: bool = True,
+ record_cache: bool = True,
+ temporary: bool = False,
+ dtype_override: Optional[torch.dtype] = None,
+) -> Optional[torch.Tensor]:
+ refs = REGISTRY.get(module)
+ if not refs or name not in refs:
+ return None
+ if name in module._parameters:
+ current = module._parameters[name]
+ is_buffer = False
+ elif name in module._buffers:
+ current = module._buffers[name]
+ is_buffer = True
+ else:
+ return None
+ if current is None:
+ return None
+ target_dtype = dtype_override or _get_future_dtype(module, name)
+ if dtype_override is not None:
+ _set_future_dtype(module, name, dtype_override)
+ if current.device.type != "meta":
+ if current.device != device or (target_dtype is not None and current.dtype != target_dtype):
+ if target_dtype is not None and current.dtype != target_dtype:
+ tensor = current.to(device=device, dtype=target_dtype)
+ else:
+ tensor = current.to(device=device)
+ if not temporary:
+ if is_buffer:
+ module._buffers[name] = tensor
+ else:
+ module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=refs[name].requires_grad)
+ _rebuild_materialization_state(module, refs, _get_materialization_state(module))
+ return tensor
+ return current
+
+ disk_ref = refs[name]
+ required_bytes = _meta_nbytes(disk_ref.meta)
+ if required_bytes is None:
+ return current
+ free_mem_start = _device_free_memory(device)
+ free_mem = _maybe_free_ram_budget(device, required_bytes)
+ load_device = device
+ if free_mem < required_bytes and allow_alternate:
+ alt = _choose_alternate_device(device)
+ if alt is not None:
+ alt_free = _maybe_free_ram_budget(alt, required_bytes)
+ if alt_free >= required_bytes:
+ load_device = alt
+ else:
+ state = _get_materialization_state(module)
+ if name not in state.deferred_keys:
+ state.deferred_keys.add(name)
+ state.deferred_bytes += required_bytes
+ _update_disk_state_attrs(module, state)
+ _log_materialization(module, device, free_mem_start, refs, _get_materialization_state(module), "Disk weight deferred")
+ return current
+ else:
+ state = _get_materialization_state(module)
+ if name not in state.deferred_keys:
+ state.deferred_keys.add(name)
+ state.deferred_bytes += required_bytes
+ _update_disk_state_attrs(module, state)
+ _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
+ return current
+ elif free_mem < required_bytes:
+ state = _get_materialization_state(module)
+ if name not in state.deferred_keys:
+ state.deferred_keys.add(name)
+ state.deferred_bytes += required_bytes
+ _update_disk_state_attrs(module, state)
+ _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
+ return current
+
+ tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
+ if temporary:
+ return tensor
+ if is_buffer:
+ module._buffers[name] = tensor
+ else:
+ module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
+ if tensor.device.type == "cpu" and record_cache:
+ CACHE.record(module, name, tensor, is_buffer=is_buffer)
+ state = _get_materialization_state(module)
+ _rebuild_materialization_state(module, refs, state)
+ _log_materialization(module, load_device, free_mem_start, refs, state, "Disk weight loaded")
+ return tensor
+
+
+def _replace_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool, requires_grad: bool):
+ parts = name.split(".")
+ module = model
+ for part in parts[:-1]:
+ module = getattr(module, part)
+ attr = parts[-1]
+ if is_buffer:
+ module._buffers[attr] = tensor
+ else:
+ module._parameters[attr] = torch.nn.Parameter(tensor, requires_grad=requires_grad)
+
+
+def _materialize_module_from_state_dict(
+ module: torch.nn.Module,
+ lazy_state: LazyModuleState,
+ target_device: torch.device,
+ dtype_override: Optional[torch.dtype] = None,
+):
+ missing_keys = []
+ unexpected_keys = []
+ error_msgs = []
+ metadata = getattr(lazy_state.state_dict, "_metadata", None)
+ local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
+ refs = REGISTRY.get(module) or {}
+ if dtype_override is not None:
+ for name in refs.keys():
+ _set_future_dtype(module, name, dtype_override)
+ state = _get_materialization_state(module)
+ _rebuild_materialization_state(module, refs, state)
+ keys = sorted(lazy_state.state_dict.keys())
+ existing = {}
+ for name, param in module.named_parameters(recurse=False):
+ key = f"{lazy_state.prefix}{name}"
+ if key in lazy_state.state_dict and param is not None and param.device.type != "meta":
+ existing[key] = param
+ for name, buf in module.named_buffers(recurse=False):
+ key = f"{lazy_state.prefix}{name}"
+ if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
+ existing[key] = buf
+ free_mem_start = _device_free_memory(target_device)
+ remaining_budget = free_mem_start
+ allowed = set(existing.keys())
+ for key in keys:
+ if key in allowed:
+ continue
+ meta = _state_dict_meta(lazy_state.state_dict, key)
+ required = _meta_nbytes(meta)
+ if required is None:
+ continue
+ if target_device.type == "cpu":
+ free_mem = _maybe_free_ram_budget(target_device, required)
+ remaining_budget = min(remaining_budget, free_mem)
+ if required <= remaining_budget:
+ allowed.add(key)
+ remaining_budget = max(0, remaining_budget - required)
+ deferred_state_dict_keys = {key for key in keys if key not in allowed}
+ state_dict = _BudgetedStateDict(
+ lazy_state.state_dict,
+ allowed_keys=allowed,
+ device=target_device,
+ allow_gds=ALLOW_GDS,
+ pin_if_cpu=PIN_IF_CPU,
+ dtype_override=dtype_override,
+ overrides=existing,
+ )
+ factory_device = None
+ if hasattr(module, "factory_kwargs") and "device" in module.factory_kwargs:
+ factory_device = module.factory_kwargs["device"]
+ module.factory_kwargs["device"] = target_device
+ try:
+ module._load_from_state_dict(
+ state_dict,
+ lazy_state.prefix,
+ local_metadata,
+ False,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ )
+ incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
+ for hook in module._load_state_dict_post_hooks.values():
+ out = hook(module, incompatible)
+ if out is not None:
+ raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
+ finally:
+ if factory_device is not None:
+ module.factory_kwargs["device"] = factory_device
+ if len(error_msgs) > 0:
+ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
+ _rebuild_materialization_state(module, refs, state)
+ lazy_state.loaded = len(deferred_state_dict_keys) == 0
+ _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed")
+ for name, param in module.named_parameters(recurse=False):
+ if param.device.type == "cpu":
+ CACHE.record(module, name, param, is_buffer=False)
+ for name, buf in module.named_buffers(recurse=False):
+ if buf is not None and buf.device.type == "cpu":
+ CACHE.record(module, name, buf, is_buffer=True)
+
+
+def lazy_load_state_dict(model: torch.nn.Module, state_dict, strict: bool = False):
+ model_keys = set()
+ for name, _ in model.named_parameters(recurse=True):
+ model_keys.add(name)
+ for name, _ in model.named_buffers(recurse=True):
+ model_keys.add(name)
+
+ state_keys = set(state_dict.keys())
+ missing_keys = [k for k in model_keys if k not in state_keys]
+ unexpected_keys = [k for k in state_keys if k not in model_keys]
+
+ if strict:
+ error_msgs = []
+ if len(unexpected_keys) > 0:
+ error_msgs.append('Unexpected key(s) in state_dict: {}.'.format(', '.join(f'"{k}"' for k in unexpected_keys)))
+ if len(missing_keys) > 0:
+ error_msgs.append('Missing key(s) in state_dict: {}.'.format(', '.join(f'"{k}"' for k in missing_keys)))
+ if error_msgs:
+ raise RuntimeError("Error(s) in loading state_dict:\n\t{}".format("\n\t".join(error_msgs)))
+
+ for name, param in model.named_parameters(recurse=True):
+ if name not in state_keys:
+ continue
+ meta = state_dict.meta(name)
+ meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
+ _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
+
+ for name, buf in model.named_buffers(recurse=True):
+ if buf is None or name not in state_keys:
+ continue
+ meta = state_dict.meta(name)
+ meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
+ _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
+
+ register_module_weights(model, state_dict)
+ register_lazy_modules(model, state_dict)
+ attach_disk_weight_hooks(model)
+ return missing_keys, unexpected_keys
diff --git a/comfy/gligen.py b/comfy/gligen.py
index 1d7b6c2f4..c2cf7c6db 100644
--- a/comfy/gligen.py
+++ b/comfy/gligen.py
@@ -3,6 +3,7 @@ import torch
from torch import nn
from .ldm.modules.attention import CrossAttention, FeedForward
import comfy.ops
+import comfy.utils
ops = comfy.ops.manual_cast
@@ -282,7 +283,7 @@ def load_gligen(sd):
gated = GatedSelfAttentionDense(
query_dim, key_dim, n_heads, d_head)
- gated.load_state_dict(n_sd, strict=False)
+ comfy.utils.load_state_dict(gated, n_sd, strict=False)
output_list.append(gated)
if "position_net.null_positive_feature" in sd_k:
@@ -293,7 +294,7 @@ def load_gligen(sd):
pass
w = WeightsLoader()
w.position_net = PositionNet(in_dim, out_dim)
- w.load_state_dict(sd, strict=False)
+ comfy.utils.load_state_dict(w, sd, strict=False)
gligen = Gligen(output_list, w.position_net, key_dim)
return gligen
diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py
index 51b6d1da8..19f0dafb4 100644
--- a/comfy/ldm/hunyuan_video/upsampler.py
+++ b/comfy/ldm/hunyuan_video/upsampler.py
@@ -1,4 +1,5 @@
import torch
+import comfy.utils
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
@@ -112,7 +113,7 @@ class HunyuanVideo15SRModel():
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
- return self.model.load_state_dict(sd, strict=True)
+ return comfy.utils.load_state_dict(self.model, sd, strict=True)
def get_sd(self):
return self.model.state_dict()
diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py
index a9111d3bd..c61883ba3 100644
--- a/comfy/ldm/lightricks/vae/audio_vae.py
+++ b/comfy/ldm/lightricks/vae/audio_vae.py
@@ -2,6 +2,7 @@ import json
from dataclasses import dataclass
import math
import torch
+import comfy.utils
import torchaudio
import comfy.model_management
@@ -153,8 +154,8 @@ class AudioVAE(torch.nn.Module):
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
- self.autoencoder.load_state_dict(vae_sd, strict=False)
- self.vocoder.load_state_dict(vocoder_sd, strict=False)
+ comfy.utils.load_state_dict(self.autoencoder, vae_sd, strict=False)
+ comfy.utils.load_state_dict(self.vocoder, vocoder_sd, strict=False)
autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer(
diff --git a/comfy/ldm/mmaudio/vae/vae.py b/comfy/ldm/mmaudio/vae/vae.py
index 62f24606c..1009d2d76 100644
--- a/comfy/ldm/mmaudio/vae/vae.py
+++ b/comfy/ldm/mmaudio/vae/vae.py
@@ -2,6 +2,7 @@ import logging
from typing import Optional
import torch
+import comfy.utils
import torch.nn as nn
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
@@ -152,7 +153,7 @@ class VAE(nn.Module):
return dec, posterior
def load_weights(self, src_dict) -> None:
- self.load_state_dict(src_dict, strict=True)
+ comfy.utils.load_state_dict(self, src_dict, strict=True)
@property
def device(self) -> torch.device:
@@ -355,4 +356,3 @@ def get_my_vae(name: str, **kwargs) -> VAE:
if name == '44k':
return VAE_44k(**kwargs)
raise ValueError(f'Unknown model: {name}')
-
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 49efd700b..84766591b 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -56,6 +56,7 @@ import comfy.conds
import comfy.ops
from enum import Enum
from . import utils
+from . import safetensors_stream
import comfy.latent_formats
import comfy.model_sampling
import math
@@ -299,20 +300,21 @@ class BaseModel(torch.nn.Module):
return out
def load_model_weights(self, sd, unet_prefix=""):
- to_load = {}
- keys = list(sd.keys())
- for k in keys:
- if k.startswith(unet_prefix):
- to_load[k[len(unet_prefix):]] = sd.pop(k)
-
+ replace_prefix = {unet_prefix: ""} if unet_prefix else {}
+ if replace_prefix:
+ if utils.is_stream_state_dict(sd):
+ to_load = utils.state_dict_prefix_replace(sd, replace_prefix, filter_keys=True)
+ else:
+ to_load = safetensors_stream.RenameViewStateDict(sd, replace_prefix, filter_keys=True, mutate_base=False)
+ else:
+ to_load = sd
to_load = self.model_config.process_unet_state_dict(to_load)
- m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
+ m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))
if len(u) > 0:
logging.warning("unet unexpected: {}".format(u))
- del to_load
return self
def process_latent_in(self, latent):
@@ -751,8 +753,8 @@ class StableAudio1(BaseModel):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
- self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights)
- self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)
+ utils.load_state_dict(self.seconds_start_embedder, seconds_start_embedder_weights, strict=True)
+ utils.load_state_dict(self.seconds_total_embedder, seconds_total_embedder_weights, strict=True)
def extra_conds(self, **kwargs):
out = {}
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 0853b3aec..0b3a05659 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -19,6 +19,9 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
+def sd_shape(state_dict, key):
+ return comfy.utils.state_dict_meta(state_dict, key).shape
+
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
@@ -27,8 +30,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
if len(transformer_keys) > 0:
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
- context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
- use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
+ context_dim = sd_shape(state_dict, '{}0.attn2.to_k.weight'.format(transformer_prefix))[1]
+ use_linear_in_transformer = len(sd_shape(state_dict, '{}1.proj_in.weight'.format(prefix))) == 2
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
time_stack_cross = '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
@@ -39,27 +42,27 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
unet_config = {}
- unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
- patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
+ unet_config["in_channels"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[1]
+ patch_size = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[2]
unet_config["patch_size"] = patch_size
final_layer = '{}final_layer.linear.weight'.format(key_prefix)
if final_layer in state_dict:
- unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size)
+ unet_config["out_channels"] = sd_shape(state_dict, final_layer)[0] // (patch_size * patch_size)
- unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
+ unet_config["depth"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0] // 64
unet_config["input_size"] = None
y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix)
if y_key in state_dict_keys:
- unet_config["adm_in_channels"] = state_dict[y_key].shape[1]
+ unet_config["adm_in_channels"] = sd_shape(state_dict, y_key)[1]
context_key = '{}context_embedder.weight'.format(key_prefix)
if context_key in state_dict_keys:
- in_features = state_dict[context_key].shape[1]
- out_features = state_dict[context_key].shape[0]
+ in_features = sd_shape(state_dict, context_key)[1]
+ out_features = sd_shape(state_dict, context_key)[0]
unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}}
num_patches_key = '{}pos_embed'.format(key_prefix)
if num_patches_key in state_dict_keys:
- num_patches = state_dict[num_patches_key].shape[1]
+ num_patches = sd_shape(state_dict, num_patches_key)[1]
unet_config["num_patches"] = num_patches
unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))
@@ -83,23 +86,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
if text_mapper_name in state_dict_keys:
unet_config['stable_cascade_stage'] = 'c'
- w = state_dict[text_mapper_name]
- if w.shape[0] == 1536: #stage c lite
+ w_shape = sd_shape(state_dict, text_mapper_name)
+ if w_shape[0] == 1536: #stage c lite
unet_config['c_cond'] = 1536
unet_config['c_hidden'] = [1536, 1536]
unet_config['nhead'] = [24, 24]
unet_config['blocks'] = [[4, 12], [12, 4]]
- elif w.shape[0] == 2048: #stage c full
+ elif w_shape[0] == 2048: #stage c full
unet_config['c_cond'] = 2048
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
unet_config['stable_cascade_stage'] = 'b'
- w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
- if w.shape[-1] == 640:
+ w_shape = sd_shape(state_dict, '{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix))
+ if w_shape[-1] == 640:
unet_config['c_hidden'] = [320, 640, 1280, 1280]
unet_config['nhead'] = [-1, -1, 20, 20]
unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
- elif w.shape[-1] == 576: #stage b lite
+ elif w_shape[-1] == 576: #stage b lite
unet_config['c_hidden'] = [320, 576, 1152, 1152]
unet_config['nhead'] = [-1, 9, 18, 18]
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
@@ -113,8 +116,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
unet_config = {}
- unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
- unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
+ unet_config["max_seq"] = sd_shape(state_dict, '{}positional_encoding'.format(key_prefix))[1]
+ unet_config["cond_seq_dim"] = sd_shape(state_dict, '{}cond_seq_linear.weight'.format(key_prefix))[1]
double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
unet_config["n_double_layers"] = double_layers
@@ -125,10 +128,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config = {}
unet_config["image_model"] = "hydit"
unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
- unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
+ unet_config["hidden_size"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0]
if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2
unet_config["mlp_ratio"] = 4.3637
- if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
+ if sd_shape(state_dict, '{}extra_embedder.0.weight'.format(key_prefix))[1] == 3968:
unet_config["size_cond"] = True
unet_config["use_style_cond"] = True
unet_config["image_model"] = "hydit1"
@@ -136,12 +139,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
- in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
- out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
+ in_w_shape = sd_shape(state_dict, '{}img_in.proj.weight'.format(key_prefix))
+ out_w_shape = sd_shape(state_dict, '{}final_layer.linear.weight'.format(key_prefix))
dit_config["image_model"] = "hunyuan_video"
- dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
- dit_config["patch_size"] = list(in_w.shape[2:])
- dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
+ dit_config["in_channels"] = in_w_shape[1] #SkyReels img2video has 32 input channels
+ dit_config["patch_size"] = list(in_w_shape[2:])
+ dit_config["out_channels"] = out_w_shape[0] // math.prod(dit_config["patch_size"])
if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
dit_config["vec_in_dim"] = 768
else:
@@ -157,10 +160,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["meanflow"] = False
- dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
- dit_config["hidden_size"] = in_w.shape[0]
+ dit_config["context_in_dim"] = sd_shape(state_dict, '{}txt_in.input_embedder.weight'.format(key_prefix))[1]
+ dit_config["hidden_size"] = in_w_shape[0]
dit_config["mlp_ratio"] = 4.0
- dit_config["num_heads"] = in_w.shape[0] // 128
+ dit_config["num_heads"] = in_w_shape[0] // 128
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["theta"] = 256
@@ -179,7 +182,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["use_cond_type_embedding"] = False
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
- dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
+ dit_config["vision_in_dim"] = sd_shape(state_dict, '{}vision_in.proj.0.weight'.format(key_prefix))[0]
dit_config["meanflow_sum"] = True
else:
dit_config["vision_in_dim"] = None
@@ -221,19 +224,19 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["patch_size"] = patch_size
in_key = "{}img_in.weight".format(key_prefix)
if in_key in state_dict_keys:
- w = state_dict[in_key]
- dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
- dit_config["hidden_size"] = w.shape[0]
+ w_shape = sd_shape(state_dict, in_key)
+ dit_config["in_channels"] = w_shape[1] // (patch_size * patch_size)
+ dit_config["hidden_size"] = w_shape[0]
txt_in_key = "{}txt_in.weight".format(key_prefix)
if txt_in_key in state_dict_keys:
- w = state_dict[txt_in_key]
- dit_config["context_in_dim"] = w.shape[1]
- dit_config["hidden_size"] = w.shape[0]
+ w_shape = sd_shape(state_dict, txt_in_key)
+ dit_config["context_in_dim"] = w_shape[1]
+ dit_config["hidden_size"] = w_shape[0]
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys:
- dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
+ dit_config["vec_in_dim"] = sd_shape(state_dict, vec_in_key)[1]
else:
dit_config["vec_in_dim"] = None
@@ -307,7 +310,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config = {}
dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
- shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
+ shape = sd_shape(state_dict, '{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix))
dit_config["attention_head_dim"] = shape[0] // 32
dit_config["cross_attention_dim"] = shape[1]
if metadata is not None and "config" in metadata:
@@ -350,11 +353,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
y_key = "{}y_embedder.y_embedding".format(key_prefix)
if y_key in state_dict_keys:
- dit_config["model_max_length"] = state_dict[y_key].shape[0]
+ dit_config["model_max_length"] = sd_shape(state_dict, y_key)[0]
pe_key = "{}pos_embed".format(key_prefix)
if pe_key in state_dict_keys:
- dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
+ dit_config["input_size"] = int(math.sqrt(sd_shape(state_dict, pe_key)[1])) * patch_size
dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
@@ -373,11 +376,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
concat_padding_mask = True
- dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
+ dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
- dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
+ dit_config["model_channels"] = sd_shape(state_dict, '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix))[0]
dit_config["block_config"] = "FA-CA-MLP"
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["pos_emb_cls"] = "rope3d"
@@ -416,9 +419,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
- w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
- dit_config["dim"] = w.shape[0]
- dit_config["cap_feat_dim"] = w.shape[1]
+ w_shape = sd_shape(state_dict, '{}cap_embedder.1.weight'.format(key_prefix))
+ dit_config["dim"] = w_shape[0]
+ dit_config["cap_feat_dim"] = w_shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["qk_norm"] = True
@@ -429,9 +432,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
- ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
- if ctd_weight is not None: # NewBie
- dit_config["clip_text_dim"] = ctd_weight.shape[0]
+ ctd_key = '{}clip_text_pooled_proj.0.weight'.format(key_prefix)
+ if ctd_key in state_dict_keys: # NewBie
+ dit_config["clip_text_dim"] = sd_shape(state_dict, ctd_key)[0]
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
@@ -450,12 +453,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {}
dit_config["image_model"] = "wan2.1"
- dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
- out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
+ dim = sd_shape(state_dict, '{}head.modulation'.format(key_prefix))[-1]
+ out_dim = sd_shape(state_dict, '{}head.head.weight'.format(key_prefix))[0] // 4
dit_config["dim"] = dim
dit_config["out_dim"] = out_dim
dit_config["num_heads"] = dim // 128
- dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
+ dit_config["ffn_dim"] = sd_shape(state_dict, '{}blocks.0.ffn.0.weight'.format(key_prefix))[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
dit_config["patch_size"] = (1, 2, 2)
dit_config["freq_dim"] = 256
@@ -463,10 +466,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["qk_norm"] = True
dit_config["cross_attn_norm"] = True
dit_config["eps"] = 1e-6
- dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
+ dit_config["in_dim"] = sd_shape(state_dict, '{}patch_embedding.weight'.format(key_prefix))[1]
if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "vace"
- dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
+ dit_config["vace_in_dim"] = sd_shape(state_dict, '{}vace_patch_embedding.weight'.format(key_prefix))[1]
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
@@ -484,22 +487,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "i2v"
else:
dit_config["model_type"] = "t2v"
- flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
- if flf_weight is not None:
- dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
+ flf_key = '{}img_emb.emb_pos'.format(key_prefix)
+ if flf_key in state_dict_keys:
+ dit_config["flf_pos_embed_token_number"] = sd_shape(state_dict, flf_key)[1]
- ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
- if ref_conv_weight is not None:
- dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
+ ref_conv_key = '{}ref_conv.weight'.format(key_prefix)
+ if ref_conv_key in state_dict_keys:
+ dit_config["in_dim_ref_conv"] = sd_shape(state_dict, ref_conv_key)[1]
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
- in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
+ in_shape = sd_shape(state_dict, '{}latent_in.weight'.format(key_prefix))
dit_config = {}
dit_config["image_model"] = "hunyuan3d2"
dit_config["in_channels"] = in_shape[1]
- dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
+ dit_config["context_in_dim"] = sd_shape(state_dict, '{}cond_in.weight'.format(key_prefix))[1]
dit_config["hidden_size"] = in_shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 16
@@ -513,9 +516,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config = {}
dit_config["image_model"] = "hunyuan3d2_1"
- dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
+ dit_config["in_channels"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[1]
dit_config["context_dim"] = 1024
- dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
+ dit_config["hidden_size"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 16
dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
@@ -549,11 +552,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
concat_padding_mask = True
- dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
+ dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
- dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
+ dit_config["model_channels"] = sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[0]
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["crossattn_emb_channels"] = 1024
dit_config["pos_emb_cls"] = "rope3d"
@@ -617,7 +620,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"
- dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
+ dit_config["in_channels"] = sd_shape(state_dict, '{}img_in.weight'.format(key_prefix))[1]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
dit_config["default_ref_method"] = "index_timestep_zero"
@@ -628,7 +631,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
dit_config = {}
- model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
+ model_dim = sd_shape(state_dict, '{}visual_embeddings.in_layer.bias'.format(key_prefix))[0]
dit_config["model_dim"] = model_dim
if model_dim in [4096, 2560]: # pro video and lite image
dit_config["axes_dims"] = (32, 48, 48)
@@ -636,10 +639,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
elif model_dim == 1792: # lite video
dit_config["axes_dims"] = (16, 24, 24)
- dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
+ dit_config["time_dim"] = sd_shape(state_dict, '{}time_embeddings.in_layer.bias'.format(key_prefix))[0]
dit_config["image_model"] = "kandinsky5"
- dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
- dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
+ dit_config["ff_dim"] = sd_shape(state_dict, '{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix))[0]
+ dit_config["visual_embed_dim"] = sd_shape(state_dict, '{}visual_embeddings.in_layer.weight'.format(key_prefix))[1]
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
@@ -657,16 +660,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
y_input = '{}label_emb.0.0.weight'.format(key_prefix)
if y_input in state_dict_keys:
unet_config["num_classes"] = "sequential"
- unet_config["adm_in_channels"] = state_dict[y_input].shape[1]
+ unet_config["adm_in_channels"] = sd_shape(state_dict, y_input)[1]
else:
unet_config["adm_in_channels"] = None
- model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
- in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
+ model_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[0]
+ in_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[1]
out_key = '{}out.2.weight'.format(key_prefix)
if out_key in state_dict:
- out_channels = state_dict[out_key].shape[0]
+ out_channels = sd_shape(state_dict, out_key)[0]
else:
out_channels = 4
@@ -713,7 +716,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
if res_block_prefix in block_keys:
last_res_blocks += 1
- last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
+ last_channel_mult = sd_shape(state_dict, "{}0.out_layers.3.weight".format(prefix))[0] // model_channels
out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
if out is not None:
@@ -867,7 +870,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count)
if transformer_count > 0:
- match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
+ match["context_dim"] = sd_shape(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab))[1]
attn_res *= 2
if attn_blocks == 0:
@@ -876,13 +879,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
match["transformer_depth"] = transformer_depth
- match["model_channels"] = state_dict["conv_in.weight"].shape[0]
- match["in_channels"] = state_dict["conv_in.weight"].shape[1]
+ match["model_channels"] = sd_shape(state_dict, "conv_in.weight")[0]
+ match["in_channels"] = sd_shape(state_dict, "conv_in.weight")[1]
match["adm_in_channels"] = None
if "class_embedding.linear_1.weight" in state_dict:
- match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
+ match["adm_in_channels"] = sd_shape(state_dict, "class_embedding.linear_1.weight")[1]
elif "add_embedding.linear_1.weight" in state_dict:
- match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
+ match["adm_in_channels"] = sd_shape(state_dict, "add_embedding.linear_1.weight")[1]
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
@@ -1023,11 +1026,11 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
elif 'x_embedder.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
- hidden_size = state_dict["x_embedder.bias"].shape[0]
+ hidden_size = sd_shape(state_dict, "x_embedder.bias")[0]
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
- depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
+ depth = sd_shape(state_dict, "pos_embed.proj.weight")[0] // 64
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
else:
return None
diff --git a/comfy/model_management.py b/comfy/model_management.py
index e5de4a5b5..c681bbd69 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -26,6 +26,7 @@ import platform
import weakref
import gc
import os
+import comfy.disk_weights
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -540,7 +541,12 @@ class LoadedModel:
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
if freed >= memory_to_free:
return False
- self.model.detach(unpatch_weights)
+ offload_device = None
+ if comfy.disk_weights.disk_weights_enabled():
+ offload_device = torch.device("meta")
+ self.model.detach(unpatch_weights, offload_device=offload_device)
+ if offload_device is not None and offload_device.type == "meta":
+ logging.info(f"Unloaded {self.model.model.__class__.__name__} to disk")
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
@@ -594,6 +600,11 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
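+    # Disk tier: under CPU RAM pressure, first evict the cache of disk-backed weights, then push loaded
+    # model weights out to disk before resorting to normal unloading.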
+ if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
+ logging.info("RAM pressure: requested %.2f MB, free %.2f MB", memory_required / (1024 * 1024), get_free_memory(device) / (1024 * 1024))
+ freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
+ if freed_cache < memory_required:
+ evict_ram_to_disk(memory_required - freed_cache)
unloaded_model = []
can_unload = []
unloaded_models = []
@@ -629,6 +640,34 @@ def free_memory(memory_required, device, keep_loaded=[]):
soft_empty_cache()
return unloaded_models
+
+def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
+ if memory_to_free <= 0:
+ return 0
+ if not comfy.disk_weights.disk_weights_enabled():
+ return 0
+
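+    # Collect offload candidates (skipping keep_loaded and dead models), largest resident footprint first,
+    # and move their weights to the meta device (disk tier) until memory_to_free is satisfied.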
+ freed = 0
+ can_unload = []
+ for i in range(len(current_loaded_models) - 1, -1, -1):
+ shift_model = current_loaded_models[i]
+ if shift_model not in keep_loaded and not shift_model.is_dead():
+ loaded_memory = shift_model.model_loaded_memory()
+ if loaded_memory > 0:
+ can_unload.append((-loaded_memory, sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
+
+ for x in sorted(can_unload):
+ i = x[-1]
+ memory_needed = memory_to_free - freed
+ if memory_needed <= 0:
+ break
+ logging.debug(f"Offloading {current_loaded_models[i].model.model.__class__.__name__} to disk")
+ freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)
+
+ if freed > 0:
+ logging.info("RAM evicted to disk: {:.2f} MB freed".format(freed / (1024 * 1024)))
+ return freed
+
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@@ -1135,6 +1174,16 @@ if not args.disable_pinned_memory:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
+WEIGHTS_RAM_CACHE_BYTES = 0
+WEIGHTS_GDS_ENABLED = bool(args.weights_gds)
+if args.weights_ram_cache_gb is not None:
+ WEIGHTS_RAM_CACHE_BYTES = int(max(0.0, args.weights_ram_cache_gb) * (1024 ** 3))
+ comfy.disk_weights.configure(
+ WEIGHTS_RAM_CACHE_BYTES,
+ allow_gds=WEIGHTS_GDS_ENABLED,
+ pin_if_cpu=not args.disable_pinned_memory,
+ )
+
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def discard_cuda_async_error():
@@ -1291,7 +1340,10 @@ def get_free_memory(dev=None, torch_free_too=False):
if dev is None:
dev = get_torch_device()
- if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+ if hasattr(dev, 'type') and dev.type == "meta":
+ mem_free_total = sys.maxsize
+ mem_free_torch = mem_free_total
+ elif hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
else:
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index f6b80a40f..26fd30fa8 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -34,6 +34,7 @@ import comfy.lora
import comfy.model_management
import comfy.patcher_extension
import comfy.utils
+import comfy.disk_weights
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@@ -269,6 +270,8 @@ class ModelPatcher:
if not hasattr(self.model, 'model_offload_buffer_memory'):
self.model.model_offload_buffer_memory = 0
+ comfy.disk_weights.attach_disk_weight_hooks(self.model)
+
def model_size(self):
if self.size > 0:
return self.size
@@ -783,7 +786,7 @@ class ModelPatcher:
m.comfy_patched_weights = True
for x in load_completely:
- x[2].to(device_to)
+ comfy.disk_weights.module_to(x[2], device_to)
for x in offloaded:
n = x[1]
@@ -799,7 +802,7 @@ class ModelPatcher:
logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
- self.model.to(device_to)
+ comfy.disk_weights.module_to(self.model, device_to)
mem_counter = self.model_size()
self.model.lowvram_patch_counter += patch_counter
@@ -856,7 +859,7 @@ class ModelPatcher:
self.backup.clear()
if device_to is not None:
- self.model.to(device_to)
+ comfy.disk_weights.module_to(self.model, device_to, allow_materialize=False)
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
self.model.model_offload_buffer_memory = 0
@@ -883,6 +886,9 @@ class ModelPatcher:
if len(unload_list) > 0:
NS = comfy.model_management.NUM_STREAMS
offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
+ remaining_ram = None
+ if device_to is not None and comfy.model_management.is_device_cpu(device_to):
+ remaining_ram = comfy.model_management.get_free_memory(device_to)
for unload in unload_list:
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
@@ -916,7 +922,24 @@ class ModelPatcher:
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
- m.to(device_to)
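+                    # Disk tier: when the offload target is the meta device, or CPU RAM cannot hold this
+                    # module, push its weights to disk; otherwise fall back to a normal device move.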
+ freed_bytes = module_mem
+ if device_to is not None and device_to.type == "meta" and comfy.disk_weights.disk_weights_enabled():
+ freed_bytes = comfy.disk_weights.offload_module_weights(m)
+ if freed_bytes == 0:
+ freed_bytes = module_mem
+ else:
+ if remaining_ram is not None and remaining_ram < module_mem and comfy.disk_weights.disk_weights_enabled():
+ logging.info("Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.", n, module_mem / (1024 * 1024), remaining_ram / (1024 * 1024))
+ freed_bytes = comfy.disk_weights.offload_module_weights(m)
+ if freed_bytes == 0:
+ freed_bytes = module_mem
+ else:
+ if comfy.disk_weights.disk_weights_enabled():
+ comfy.disk_weights.move_module_tensors(m, device_to)
+ else:
+ m.to(device_to)
+ if remaining_ram is not None:
+ remaining_ram = max(0, remaining_ram - module_mem)
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
@@ -939,7 +962,7 @@ class ModelPatcher:
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
- memory_freed += module_mem
+ memory_freed += freed_bytes
offload_buffer = max(offload_buffer, potential_offload)
offload_weight_factor.append(module_mem)
offload_weight_factor.pop(0)
@@ -953,7 +976,8 @@ class ModelPatcher:
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
self.model.model_offload_buffer_memory = offload_buffer
- logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
+ target_label = "disk" if device_to is not None and device_to.type == "meta" else device_to
+ logging.info("Unloaded partially to {}: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(target_label, memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
@@ -984,11 +1008,12 @@ class ModelPatcher:
return self.model.model_loaded_weight_memory - current_used
- def detach(self, unpatch_all=True):
+ def detach(self, unpatch_all=True, offload_device=None):
self.eject_model()
self.model_patches_to(self.offload_device)
+ target_device = self.offload_device if offload_device is None else offload_device
if unpatch_all:
- self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
+ self.unpatch_model(target_device, unpatch_weights=unpatch_all)
for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
callback(self, unpatch_all)
return self.model
@@ -1358,4 +1383,3 @@ class ModelPatcher:
def __del__(self):
self.unpin_all_weights()
self.detach(unpatch_all=False)
-
diff --git a/comfy/ops.py b/comfy/ops.py
index 8156c42ff..d7a0c9ab2 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -19,6 +19,7 @@
import torch
import logging
import comfy.model_management
+import comfy.disk_weights
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
@@ -98,11 +99,35 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
- weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+ weight_source = s.weight
+ bias_source = s.bias
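+    # If disk weights are enabled and a parameter still lives on the meta device, stream it back from
+    # disk (comfy.disk_weights.load_module_tensor) just for this cast.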
+ if comfy.disk_weights.disk_weights_enabled():
+ if weight_source.device.type == "meta":
+ loaded = comfy.disk_weights.load_module_tensor(
+ s,
+ "weight",
+ device,
+ temporary=True,
+ dtype_override=dtype,
+ )
+ if loaded is not None:
+ weight_source = loaded
+ if bias_source is not None and bias_source.device.type == "meta":
+ loaded_bias = comfy.disk_weights.load_module_tensor(
+ s,
+ "bias",
+ device,
+ temporary=True,
+ dtype_override=bias_dtype,
+ )
+ if loaded_bias is not None:
+ bias_source = loaded_bias
+
+ weight = comfy.model_management.cast_to(weight_source, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
bias = None
- if s.bias is not None:
- bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+ if bias_source is not None:
+ bias = comfy.model_management.cast_to(bias_source, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
@@ -532,9 +557,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
key = f"{prefix}{param_name}"
value = state_dict.pop(key, None)
if value is not None:
- value = value.to(device=device)
- if dtype is not None:
- value = value.view(dtype=dtype)
+ if value.device.type != "meta":
+ value = value.to(device=device)
+ if dtype is not None:
+ value = value.view(dtype=dtype)
manually_loaded_keys.append(key)
return value
@@ -551,11 +577,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
manually_loaded_keys = [weight_key]
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
- if layer_conf is not None:
+ if layer_conf is not None and layer_conf.device.type != "meta":
layer_conf = json.loads(layer_conf.numpy().tobytes())
+ elif layer_conf is not None:
+ layer_conf = None
if layer_conf is None:
- self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+ if weight.device.type == "meta":
+ self.weight = torch.nn.Parameter(weight, requires_grad=False)
+ else:
+ self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
@@ -601,10 +632,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
- self.weight = torch.nn.Parameter(
- QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
- requires_grad=False
- )
+ if weight.device.type == "meta":
+ self.weight = torch.nn.Parameter(weight, requires_grad=False)
+ else:
+ self.weight = torch.nn.Parameter(
+ QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
+ requires_grad=False
+ )
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
@@ -614,7 +648,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
_v = state_dict.pop(param_key, None)
if _v is None:
continue
- self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+ if _v.device.type == "meta":
+ self.register_parameter(param_name, torch.nn.Parameter(_v, requires_grad=False))
+ else:
+ self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
diff --git a/comfy/safetensors_stream.py b/comfy/safetensors_stream.py
new file mode 100644
index 000000000..61943e53f
--- /dev/null
+++ b/comfy/safetensors_stream.py
@@ -0,0 +1,934 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Comfy
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
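+# Streaming safetensors access built on fastsafetensors: file headers are parsed for metadata only and
+# individual tensors are materialized on demand, either through a chunked CPU read path or, when enabled,
+# through GPUDirect Storage (GDS) straight into GPU memory.
+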
+from __future__ import annotations
+
+import collections
+import ctypes
+import importlib
+import importlib.util
+import os
+import threading
+from dataclasses import dataclass
+from types import SimpleNamespace
+from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Sequence, Tuple
+
+import torch
+
+
+_FST_MODULE = None
+_FST_LOCK = threading.Lock()
+_FST_LOADED = False
+_GDS_INITIALIZED = False
+_MISSING = object()
+_NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024
+
+
+def _require_fastsafetensors():
+ global _FST_MODULE
+ with _FST_LOCK:
+ if _FST_MODULE is None:
+ if importlib.util.find_spec("fastsafetensors") is None:
+ raise ImportError(
+ "fastsafetensors is required for safetensors streaming. "
+ "Install it with: pip install 'fastsafetensors @ https://github.com/"
+ "foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip'"
+ )
+ _FST_MODULE = importlib.import_module("fastsafetensors")
+ return _FST_MODULE
+
+
+def _init_fastsafetensors_lib():
+ global _FST_LOADED
+ fst = _require_fastsafetensors()
+ if not _FST_LOADED:
+ fst.cpp.load_library_functions()
+ _FST_LOADED = True
+ return fst
+
+
+def _init_gds():
+ global _GDS_INITIALIZED
+ fst = _init_fastsafetensors_lib()
+ if not _GDS_INITIALIZED:
+ if fst.cpp.init_gds() != 0:
+ raise RuntimeError("fastsafetensors init_gds() failed")
+ _GDS_INITIALIZED = True
+
+
+@dataclass(frozen=True)
+class TensorMeta:
+ dtype: torch.dtype
+ shape: Tuple[int, ...]
+ numel: int
+ nbytes: int
+ data_offsets: Tuple[int, int]
+ filename: str
+ fst_dtype: object
+ strides: Tuple[int, ...]
+
+
+class SafeTensorIndex:
+ def __init__(self, filename: str):
+ fst = _init_fastsafetensors_lib()
+ framework = fst.frameworks.get_framework_op("pytorch")
+ metadata = fst.common.SafeTensorsMetadata.from_file(filename, framework)
+ self._filename = filename
+ self._metadata = metadata
+ self._framework = framework
+ from fastsafetensors.frameworks import _torch as fst_torch
+ self._dtype_map = fst_torch.dtype_convert
+ self._tensor_meta: Dict[str, TensorMeta] = {}
+ for key, frame in metadata.tensors.items():
+ torch_dtype = self._dtype_map.get(frame.dtype, None)
+ if torch_dtype is None:
+ raise ValueError(f"Unsupported safetensors dtype {frame.dtype} in {filename}")
+ numel = 1
+ for s in frame.shape:
+ numel *= s
+ nbytes = numel * framework.get_dtype_size(frame.dtype)
+ self._tensor_meta[key] = TensorMeta(
+ dtype=torch_dtype,
+ shape=tuple(frame.shape),
+ numel=numel,
+ nbytes=nbytes,
+ data_offsets=(frame.data_offsets[0], frame.data_offsets[1]),
+ filename=filename,
+ fst_dtype=frame.dtype,
+ strides=tuple(frame.strides),
+ )
+
+ def keys(self) -> Iterable[str]:
+ return self._tensor_meta.keys()
+
+ def has(self, key: str) -> bool:
+ return key in self._tensor_meta
+
+ def meta(self, key: str) -> TensorMeta:
+ return self._tensor_meta[key]
+
+ def metadata(self):
+ return self._metadata.metadata
+
+ @property
+ def header_length(self) -> int:
+ return self._metadata.header_length
+
+ @property
+ def size_bytes(self) -> int:
+ return self._metadata.size_bytes
+
+
+class _SafeTensorFile:
+ def __init__(self, filename: str, index: SafeTensorIndex):
+ self.filename = filename
+ self.index = index
+ self._fd: Optional[int] = None
+ self._gds_handle = None
+ self._gds_reader = None
+ self._nogds_reader = None
+ self._refcount = 1
+
+ def acquire(self) -> "_SafeTensorFile":
+ self._refcount += 1
+ return self
+
+ def release(self):
+ self._refcount -= 1
+ if self._refcount <= 0:
+ self.close()
+
+ def close(self):
+ if self._fd is not None:
+ os.close(self._fd)
+ self._fd = None
+ self._gds_handle = None
+
+ def _ensure_fd(self) -> int:
+ if self._fd is None:
+ self._fd = os.open(self.filename, os.O_RDONLY, 0o644)
+ return self._fd
+
+ def _ensure_nogds_reader(self, use_cuda: bool):
+ fst = _init_fastsafetensors_lib()
+ if self._nogds_reader is None:
+ self._nogds_reader = fst.cpp.nogds_file_reader(
+ False, 16 * 1024, 16, use_cuda
+ )
+ return self._nogds_reader
+
+ def _ensure_gds_reader(self, use_cuda: bool):
+ fst = _init_fastsafetensors_lib()
+ if self._gds_reader is None:
+ self._gds_reader = fst.cpp.gds_file_reader(16, use_cuda)
+ return self._gds_reader
+
+ def _ensure_gds_handle(self, use_cuda: bool):
+ if self._gds_handle is None:
+ fst = _init_fastsafetensors_lib()
+ framework = fst.frameworks.get_framework_op("pytorch")
+ o_direct = _get_gds_o_direct(framework)
+ self._gds_handle = fst.cpp.gds_file_handle(self.filename, o_direct, use_cuda)
+ return self._gds_handle
+
+ def read_tensor(
+ self,
+ meta: TensorMeta,
+ device: torch.device,
+ dtype: Optional[torch.dtype],
+ allow_gds: bool,
+ pin_if_cpu: bool,
+ ) -> torch.Tensor:
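+        # Prefer a direct GDS read into GPU memory when allowed; otherwise do a chunked CPU read and,
+        # for CUDA targets, an (optionally pinned) host-to-device copy.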
+ fst = _init_fastsafetensors_lib()
+ framework = fst.frameworks.get_framework_op("pytorch")
+ device_is_cuda = device.type == "cuda"
+ if device_is_cuda and allow_gds:
+ _ensure_gds_ready(device)
+ tensor = self._read_tensor_gds(
+ fst, framework, meta, device, dtype
+ )
+ return tensor
+
+ cpu_tensor = self._read_tensor_nogds(
+ fst, framework, meta, torch.device("cpu"), dtype
+ )
+ if device_is_cuda:
+ if pin_if_cpu:
+ cpu_tensor = cpu_tensor.pin_memory()
+ gpu_tensor = torch.empty_like(cpu_tensor, device=device)
+ gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu)
+ return gpu_tensor
+ return cpu_tensor
+
+ def _aligned_range(self, abs_start: int, length: int) -> Tuple[int, int, int]:
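+        # O_DIRECT/GDS reads must start and end on the reader's alignment boundary; widen the requested
+        # window accordingly and report how many leading bytes to skip.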
+ fst = _init_fastsafetensors_lib()
+ align = fst.cpp.get_alignment_size()
+ aligned_offset = (abs_start // align) * align
+ head = abs_start - aligned_offset
+ aligned_length = length + head
+ tail = aligned_length % align
+ if tail:
+ aligned_length += align - tail
+ return aligned_offset, aligned_length, head
+
+ def _read_tensor_nogds(
+ self,
+ fst,
+ framework,
+ meta: TensorMeta,
+ device: torch.device,
+ dtype: Optional[torch.dtype],
+ ) -> torch.Tensor:
+ fd = self._ensure_fd()
+ reader = self._ensure_nogds_reader(use_cuda=False)
+ abs_start = self.index.header_length + meta.data_offsets[0]
+ length = meta.data_offsets[1] - meta.data_offsets[0]
+ chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT))
+ chunk_bytes = max(1, chunk_bytes)
+ ptr_align = framework.get_device_ptr_align()
+ dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu")
+ buffer_length = 0
+ buf_ptr = None
+ gbuf = None
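+        # Read the tensor in aligned chunks into a reusable bounce buffer and memmove each chunk's payload
+        # into the destination tensor, so large tensors never need a transfer buffer of their full size.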
+ try:
+ chunk_offset = 0
+ while chunk_offset < length:
+ chunk_len = min(length - chunk_offset, chunk_bytes)
+ aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len)
+ needed = aligned_length + ptr_align
+ if buf_ptr is None or needed > buffer_length:
+ if buf_ptr is not None:
+ fst.cpp.cpu_free(buf_ptr)
+ buffer_length = needed
+ buf_ptr = fst.cpp.cpu_malloc(buffer_length)
+ gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False)
+ ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
+ req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off)
+ if reader.wait_read(req) < 0:
+ raise RuntimeError("nogds_file_reader read failed")
+ src_ptr = gbuf.get_base_address() + ptr_off + head
+ dest_ptr = dest_tensor.data_ptr() + chunk_offset
+ ctypes.memmove(dest_ptr, src_ptr, chunk_len)
+ chunk_offset += chunk_len
+ except Exception:
+ if buf_ptr is not None:
+ fst.cpp.cpu_free(buf_ptr)
+ raise
+ if buf_ptr is not None:
+ fst.cpp.cpu_free(buf_ptr)
+ if dtype is not None and dtype != dest_tensor.dtype:
+ _validate_dtype_conversion(dest_tensor.dtype, dtype)
+ dest_tensor = dest_tensor.to(dtype=dtype)
+ return dest_tensor
+
+ def _read_tensor_gds(
+ self,
+ fst,
+ framework,
+ meta: TensorMeta,
+ device: torch.device,
+ dtype: Optional[torch.dtype],
+ ) -> torch.Tensor:
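+        # Single aligned GDS read directly into device memory; the buffer is wrapped as a torch tensor via
+        # DLPack and released when its _BufferOwner is garbage collected.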
+ reader = self._ensure_gds_reader(use_cuda=True)
+ handle = self._ensure_gds_handle(use_cuda=True)
+ abs_start = self.index.header_length + meta.data_offsets[0]
+ length = meta.data_offsets[1] - meta.data_offsets[0]
+ aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
+ ptr_align = framework.get_device_ptr_align()
+ buffer_length = aligned_length + ptr_align
+ fst_device = _fst_device_from_torch(fst, device)
+ gbuf = framework.alloc_tensor_memory(buffer_length, fst_device)
+ ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
+ file_length = self.index.size_bytes
+ req = reader.submit_read(
+ handle, gbuf, aligned_offset, aligned_length, ptr_off, file_length
+ )
+ if reader.wait_read(req) < 0:
+ framework.free_tensor_memory(gbuf, fst_device)
+ raise RuntimeError("gds_file_reader read failed")
+ owner = _BufferOwner(lambda: framework.free_tensor_memory(gbuf, fst_device))
+ tensor = _dlpack_tensor_from_buffer(
+ fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
+ )
+ if dtype is not None and dtype != tensor.dtype:
+ _validate_dtype_conversion(tensor.dtype, dtype)
+ tensor = tensor.to(dtype=dtype)
+ return tensor
+
+
+def _fst_device_from_torch(fst, device: torch.device):
+ if device.type == "cuda" and device.index is not None:
+ return fst.st_types.Device.from_str(f"cuda:{device.index}")
+ return fst.st_types.Device.from_str(device.type)
+
+
+class _BufferOwner:
+ def __init__(self, free_fn):
+ self._free_fn = free_fn
+
+ def __del__(self):
+ try:
+ self._free_fn()
+ except Exception:
+ pass
+
+
+def _dlpack_tensor_from_buffer(
+ fst,
+ framework,
+ ptr: int,
+ meta: TensorMeta,
+ device: torch.device,
+ owner: Optional[_BufferOwner],
+) -> torch.Tensor:
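+    # Wrap a raw buffer as a torch tensor via DLPack. as_workaround_dtype substitutes a DLPack-representable
+    # dtype of the same size where needed; the view() below restores the real dtype.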
+ disk_dtype = framework.as_workaround_dtype(meta.fst_dtype)
+ dev = _fst_device_from_torch(fst, device)
+ dl_tensor = fst.dlpack.from_cuda_buffer(ptr, list(meta.shape), list(meta.strides), disk_dtype, dev)
+ torch_tensor = framework.from_dlpack(dl_tensor, dev, disk_dtype).real_tensor
+ if disk_dtype != meta.fst_dtype:
+ torch_tensor = torch_tensor.view(meta.dtype)
+ if owner is not None:
+ torch_tensor._comfy_disk_buffer_owner = owner
+ return torch_tensor
+
+
+def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype):
+ if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size():
+ raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})")
+
+
+def _get_gds_o_direct(framework) -> bool:
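+    # Decide whether the cuFile handle should be opened with O_DIRECT; following fastsafetensors, O_DIRECT
+    # is only requested for CUDA versions older than 12.2.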
+ cuda_ver = framework.get_cuda_ver()
+ if cuda_ver and cuda_ver != "0.0":
+ ver_parts = cuda_ver.split("-", 1)
+ if len(ver_parts) == 2:
+ cudavers = list(map(int, ver_parts[1].split(".")))
+ if ver_parts[0] == "cuda":
+ return not (cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2))
+ return True
+ return True
+
+
+def _ensure_gds_ready(device: torch.device):
+ fst = _init_fastsafetensors_lib()
+ if not fst.common.is_gpu_found():
+ raise RuntimeError(
+ "GPUDirect requested but GPU runtime library is missing. "
+ "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
+ )
+ gds_supported = fst.cpp.is_gds_supported(device.index if device.index is not None else 0)
+ if gds_supported < 0:
+ raise RuntimeError(
+ "GPUDirect requested but is_gds_supported() failed. "
+ "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
+ )
+ if not fst.cpp.is_cufile_found():
+ raise RuntimeError(
+ "GPUDirect requested but libcufile is missing. "
+ "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
+ )
+ if gds_supported == 0:
+ raise RuntimeError(
+ "GPUDirect requested but GDS is unsupported on this platform. "
+ "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
+ )
+ _init_gds()
+
+
+class StreamStateDict(collections.abc.MutableMapping):
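+    # Lazily materializing MutableMapping over a safetensors file: reads go through get_tensor() on demand,
+    # while writes and deletions are tracked as in-memory overrides without touching the file.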
+ is_stream_state_dict = True
+
+ def __init__(
+ self,
+ index: SafeTensorIndex,
+ file: _SafeTensorFile,
+ device: torch.device,
+ allow_gds: bool = False,
+ ):
+ self._index = index
+ self._file = file
+ self._device = device
+ self._allow_gds = allow_gds
+ self._overrides: Dict[str, torch.Tensor] = {}
+ self._deleted: set[str] = set()
+
+ @classmethod
+ def from_file(cls, filename: str, device: torch.device, allow_gds: bool = False) -> "StreamStateDict":
+ index = SafeTensorIndex(filename)
+ file = _SafeTensorFile(filename, index)
+ return cls(index, file, device, allow_gds=allow_gds)
+
+ def close(self):
+ if self._file is not None:
+ self._file.release()
+ self._file = None
+
+ def __del__(self):
+ try:
+ self.close()
+ except Exception:
+ pass
+
+ def meta(self, key: str) -> TensorMeta:
+ if key in self._overrides:
+ t = self._overrides[key]
+ numel = t.numel()
+ return TensorMeta(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=numel,
+ nbytes=numel * t.element_size(),
+ data_offsets=(0, numel * t.element_size()),
+ filename="",
+ fst_dtype=None,
+ strides=tuple(t.stride()),
+ )
+ if key in self._deleted:
+ raise KeyError(key)
+ return self._index.meta(key)
+
+ def get_tensor(
+ self,
+ key: str,
+ *,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ allow_gds: Optional[bool] = None,
+ pin_if_cpu: bool = False,
+ ) -> torch.Tensor:
+ if key in self._overrides:
+ t = self._overrides[key]
+ if device is not None and t.device != device:
+ t = t.to(device=device)
+ if dtype is not None and t.dtype != dtype:
+ _validate_dtype_conversion(t.dtype, dtype)
+ t = t.to(dtype=dtype)
+ return t
+ if key in self._deleted:
+ raise KeyError(key)
+ if device is None:
+ device = self._device
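+        # A meta-device request returns a shape/dtype placeholder without doing any file I/O.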
+ if device.type == "meta":
+ meta = self._index.meta(key)
+ target_dtype = dtype or meta.dtype
+ if dtype is not None and dtype != meta.dtype:
+ _validate_dtype_conversion(meta.dtype, dtype)
+ return torch.empty(meta.shape, dtype=target_dtype, device="meta")
+ if allow_gds is None:
+ allow_gds = self._allow_gds
+ meta = self._index.meta(key)
+ return self._file.read_tensor(meta, device, dtype, allow_gds, pin_if_cpu)
+
+ def __getitem__(self, key: str) -> torch.Tensor:
+ return self.get_tensor(key)
+
+ def __setitem__(self, key: str, value: torch.Tensor) -> None:
+ self._overrides[key] = value
+ self._deleted.discard(key)
+
+ def __delitem__(self, key: str) -> None:
+ if key in self._overrides:
+ del self._overrides[key]
+ return
+ if key in self._deleted:
+ raise KeyError(key)
+ if self._index.has(key):
+ self._deleted.add(key)
+ return
+ raise KeyError(key)
+
+ def __iter__(self) -> Iterator[str]:
+ for k in self._index.keys():
+ if k in self._deleted:
+ continue
+ if k in self._overrides:
+ continue
+ yield k
+ for k in self._overrides.keys():
+ yield k
+
+    def __len__(self) -> int:
+        # Count keys the same way __iter__ yields them; a plain base-minus-deleted-plus-overrides count
+        # double-counts keys that exist in the file and have also been overridden.
+        return sum(1 for _ in self)
+
+ def __contains__(self, key: object) -> bool:
+ if not isinstance(key, str):
+ return False
+ if key in self._deleted:
+ return False
+ if key in self._overrides:
+ return True
+ return self._index.has(key)
+
+ def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
+ if key in self._overrides:
+ return self._overrides.pop(key)
+ if key in self._deleted:
+ if default is _MISSING:
+ raise KeyError(key)
+ return default
+ if self._index.has(key):
+ self._deleted.add(key)
+ return self.get_tensor(key)
+ if default is _MISSING:
+ raise KeyError(key)
+ return default
+
+ def copy(self) -> "StreamStateDict":
+ new = StreamStateDict(self._index, self._file.acquire(), self._device, allow_gds=self._allow_gds)
+ new._overrides = dict(self._overrides)
+ new._deleted = set(self._deleted)
+ return new
+
+ def metadata(self):
+ return self._index.metadata()
+
+
+class _BaseViewStateDict(MutableMapping):
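+    # Shared machinery for filtering/renaming views over another (possibly streaming) state dict: reads are
+    # delegated to the base mapping, while writes and deletions stay in an overlay unless mutate_base=True.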
+ is_stream_state_dict = True
+
+ def __init__(self, base: MutableMapping, mutate_base: bool = False):
+ self._base = base
+ self._mutate_base = mutate_base
+ self._overrides: Dict[str, torch.Tensor] = {}
+ self._deleted: set[str] = set()
+
+ def _resolve_base_key(self, key: str) -> Optional[str]:
+ return key
+
+ def _iter_base_keys(self) -> Iterable[str]:
+ return self._base.keys()
+
+ def get_tensor(
+ self,
+ key: str,
+ *,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ allow_gds: Optional[bool] = None,
+ pin_if_cpu: bool = False,
+ ) -> torch.Tensor:
+ if key in self._overrides:
+ t = self._overrides[key]
+ if device is not None and t.device != device:
+ t = t.to(device=device)
+ if dtype is not None and t.dtype != dtype:
+ _validate_dtype_conversion(t.dtype, dtype)
+ t = t.to(dtype=dtype)
+ return t
+ base_key = self._resolve_base_key(key)
+ if base_key is None or key in self._deleted:
+ raise KeyError(key)
+ if hasattr(self._base, "get_tensor"):
+ return self._base.get_tensor(
+ base_key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
+ )
+ t = self._base[base_key]
+ if device is not None and t.device != device:
+ t = t.to(device=device)
+ if dtype is not None and t.dtype != dtype:
+ _validate_dtype_conversion(t.dtype, dtype)
+ t = t.to(dtype=dtype)
+ return t
+
+ def __getitem__(self, key: str) -> torch.Tensor:
+ return self.get_tensor(key)
+
+ def __setitem__(self, key: str, value: torch.Tensor) -> None:
+ base_key = self._resolve_base_key(key)
+ if self._mutate_base and base_key is not None and base_key in self._base:
+ self._base[base_key] = value
+ else:
+ self._overrides[key] = value
+ self._deleted.discard(key)
+
+ def __delitem__(self, key: str) -> None:
+ if key in self._overrides:
+ del self._overrides[key]
+ return
+ base_key = self._resolve_base_key(key)
+ if base_key is None or key in self._deleted:
+ raise KeyError(key)
+ if self._mutate_base and base_key in self._base:
+ del self._base[base_key]
+ else:
+ self._deleted.add(key)
+
+    def __iter__(self) -> Iterator[str]:
+        for k in self._iter_base_keys():
+            if k in self._deleted or k in self._overrides:
+                continue
+            yield k
+        for k in self._overrides.keys():
+            yield k
+
+    def __len__(self) -> int:
+        # Match __iter__ exactly so overridden base keys are neither double-counted nor dropped.
+        return sum(1 for _ in self)
+
+ def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
+ if key in self._overrides:
+ return self._overrides.pop(key)
+ base_key = self._resolve_base_key(key)
+ if base_key is None or key in self._deleted:
+ if default is _MISSING:
+ raise KeyError(key)
+ return default
+ if self._mutate_base:
+ try:
+ return self._base.pop(base_key)
+ except KeyError:
+ if default is _MISSING:
+ raise
+ return default
+ self._deleted.add(key)
+ return self.get_tensor(key)
+
+ def meta(self, key: str):
+ if key in self._overrides:
+ t = self._overrides[key]
+ numel = t.numel()
+ return SimpleNamespace(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=numel,
+ nbytes=numel * t.element_size(),
+ )
+ base_key = self._resolve_base_key(key)
+ if base_key is None or key in self._deleted:
+ raise KeyError(key)
+ if hasattr(self._base, "meta"):
+ return self._base.meta(base_key)
+ t = self._base[base_key]
+ numel = t.numel()
+ return SimpleNamespace(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=numel,
+ nbytes=numel * t.element_size(),
+ )
+
+
+class DeviceViewStateDict(_BaseViewStateDict):
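+    # View that fixes a default target device, GDS policy and pin-memory behaviour for reads from the base mapping.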
+ def __init__(
+ self,
+ base: MutableMapping,
+ device: torch.device,
+ allow_gds: Optional[bool] = None,
+ pin_if_cpu: bool = False,
+ mutate_base: bool = False,
+ ):
+ super().__init__(base, mutate_base=mutate_base)
+ self._device = device
+ self._allow_gds = allow_gds
+ self._pin_if_cpu = pin_if_cpu
+
+ def get_tensor(
+ self,
+ key: str,
+ *,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ allow_gds: Optional[bool] = None,
+ pin_if_cpu: bool = False,
+ ) -> torch.Tensor:
+ device = self._device if device is None else device
+ allow_gds = self._allow_gds if allow_gds is None else allow_gds
+        pin_if_cpu = pin_if_cpu or self._pin_if_cpu
+ return super().get_tensor(
+ key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
+ )
+
+ def meta(self, key: str):
+ if key in self._overrides:
+ t = self._overrides[key]
+ numel = t.numel()
+ return SimpleNamespace(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=numel,
+ nbytes=numel * t.element_size(),
+ )
+ base_key = self._resolve_base_key(key)
+ if base_key is None or key in self._deleted:
+ raise KeyError(key)
+ if hasattr(self._base, "meta"):
+ return self._base.meta(base_key)
+ t = self._base[base_key]
+ numel = t.numel()
+ return SimpleNamespace(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=numel,
+ nbytes=numel * t.element_size(),
+ )
+
+ def __getitem__(self, key: str) -> torch.Tensor:
+ return self.get_tensor(key)
+
+ def __setitem__(self, key: str, value: torch.Tensor) -> None:
+ base_key = self._resolve_base_key(key)
+ if self._mutate_base and base_key is not None and base_key in self._base:
+ self._base[base_key] = value
+ else:
+ self._overrides[key] = value
+ self._deleted.discard(key)
+
+ def __delitem__(self, key: str) -> None:
+ if key in self._overrides:
+ del self._overrides[key]
+ return
+ base_key = self._resolve_base_key(key)
+ if base_key is None or key in self._deleted:
+ raise KeyError(key)
+ if self._mutate_base and base_key in self._base:
+ del self._base[base_key]
+ else:
+ self._deleted.add(key)
+
+ def __iter__(self) -> Iterator[str]:
+ for k in self._iter_base_keys():
+ if k in self._deleted or k in self._overrides:
+ continue
+ yield k
+ for k in self._overrides.keys():
+ yield k
+
+ def __len__(self) -> int:
+ base_keys = [k for k in self._iter_base_keys() if k not in self._deleted and k not in self._overrides]
+ return len(base_keys) + len(self._overrides)
+
+ def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
+ if key in self._overrides:
+ return self._overrides.pop(key)
+ base_key = self._resolve_base_key(key)
+ if base_key is None or key in self._deleted:
+ if default is _MISSING:
+ raise KeyError(key)
+ return default
+ if self._mutate_base:
+ try:
+ return self._base.pop(base_key)
+ except KeyError:
+ if default is _MISSING:
+ raise
+ return default
+ value = self.get_tensor(key)
+ self._deleted.add(key)
+ return value
+
+
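+# Lazy view exposing only the base keys accepted by the predicate; filtering touches
+# keys only, never tensor data.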
+class FilterViewStateDict(_BaseViewStateDict):
+ def __init__(self, base: MutableMapping, predicate, mutate_base: bool = False):
+ super().__init__(base, mutate_base=mutate_base)
+ self._predicate = predicate
+
+ def _resolve_base_key(self, key: str) -> Optional[str]:
+ if self._predicate(key):
+ return key
+ return None
+
+ def _iter_base_keys(self) -> Iterable[str]:
+ for k in self._base.keys():
+ if self._predicate(k):
+ yield k
+
+
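+# View that exposes keys under source_prefix with the prefix swapped for target_prefix;
+# the key mapping is built once from the base keys, without reading tensors.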
+class PrefixViewStateDict(_BaseViewStateDict):
+ def __init__(self, base: MutableMapping, source_prefix: str, target_prefix: str = "", mutate_base: bool = False):
+ super().__init__(base, mutate_base=mutate_base)
+ self._source_prefix = source_prefix
+ self._target_prefix = target_prefix
+ self._mapping: Dict[str, str] = {}
+ self._reverse: Dict[str, str] = {}
+ for k in base.keys():
+ if not k.startswith(source_prefix):
+ continue
+ view_key = f"{target_prefix}{k[len(source_prefix):]}"
+ self._mapping[k] = view_key
+ self._reverse[view_key] = k
+
+ def _resolve_base_key(self, key: str) -> Optional[str]:
+ return self._reverse.get(key)
+
+ def _iter_base_keys(self) -> Iterable[str]:
+ return self._reverse.keys()
+
+
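+# Prefix-rename view mirroring state_dict_prefix_replace: keys matching an entry in
+# replace_prefix are rewritten, others are kept or dropped depending on filter_keys.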
+class RenameViewStateDict(_BaseViewStateDict):
+ def __init__(
+ self,
+ base: MutableMapping,
+ replace_prefix: Mapping[str, str],
+ filter_keys: bool = False,
+ mutate_base: bool = False,
+ ):
+ super().__init__(base, mutate_base=mutate_base)
+ self._filter_keys = filter_keys
+ self._replace = list(replace_prefix.items())
+ self._mapping: Dict[str, str] = {}
+ self._reverse: Dict[str, str] = {}
+ for k in base.keys():
+ view_key = self._replace_key(k)
+ if view_key is None:
+ continue
+ self._mapping[k] = view_key
+ self._reverse[view_key] = k
+
+ def _replace_key(self, key: str) -> Optional[str]:
+ for rp, dst in self._replace:
+ if key.startswith(rp):
+ return f"{dst}{key[len(rp):]}"
+ if self._filter_keys:
+ return None
+ return key
+
+ def _resolve_base_key(self, key: str) -> Optional[str]:
+ return self._reverse.get(key)
+
+ def _iter_base_keys(self) -> Iterable[str]:
+ return self._reverse.keys()
+
+
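+# Read-through merge of several mappings; on key collisions the later mapping wins
+# (lookups scan the list in reverse).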
+class MergedStateDict(MutableMapping):
+ is_stream_state_dict = True
+
+ def __init__(self, *mappings: MutableMapping):
+ self._mappings = list(mappings)
+ self._overrides: Dict[str, torch.Tensor] = {}
+ self._deleted: set[str] = set()
+
+ def __getitem__(self, key: str) -> torch.Tensor:
+ if key in self._overrides:
+ return self._overrides[key]
+ if key in self._deleted:
+ raise KeyError(key)
+ for mapping in reversed(self._mappings):
+ if key in mapping:
+ if hasattr(mapping, "get_tensor"):
+ return mapping.get_tensor(key)
+ return mapping[key]
+ raise KeyError(key)
+
+ def __setitem__(self, key: str, value: torch.Tensor) -> None:
+ self._overrides[key] = value
+ self._deleted.discard(key)
+
+ def __delitem__(self, key: str) -> None:
+ if key in self._overrides:
+ del self._overrides[key]
+ return
+ if key in self._deleted:
+ raise KeyError(key)
+ if any(key in mapping for mapping in self._mappings):
+ self._deleted.add(key)
+ return
+ raise KeyError(key)
+
+ def __iter__(self) -> Iterator[str]:
+ seen = set()
+ for mapping in self._mappings:
+ for key in mapping.keys():
+ if key in self._deleted or key in seen:
+ continue
+ seen.add(key)
+ yield key
+ for key in self._overrides.keys():
+ if key not in seen:
+ yield key
+
+ def __len__(self) -> int:
+ return len(list(self.__iter__()))
+
+ def meta(self, key: str):
+ if key in self._overrides:
+ t = self._overrides[key]
+ numel = t.numel()
+ return SimpleNamespace(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=numel,
+ nbytes=numel * t.element_size(),
+ )
+ if key in self._deleted:
+ raise KeyError(key)
+ for mapping in reversed(self._mappings):
+ if key in mapping:
+ if hasattr(mapping, "meta"):
+ return mapping.meta(key)
+ t = mapping[key]
+ numel = t.numel()
+ return SimpleNamespace(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=numel,
+ nbytes=numel * t.element_size(),
+ )
+ raise KeyError(key)
+
+
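+# View that renames keys through an explicit base-key -> view-key map; keys absent
+# from the map are hidden.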
+class MappedStateDict(_BaseViewStateDict):
+ def __init__(self, base: MutableMapping, key_map: Mapping[str, str], mutate_base: bool = False):
+ super().__init__(base, mutate_base=mutate_base)
+ self._base_to_view = dict(key_map)
+ self._view_to_base = {v: k for k, v in key_map.items()}
+
+ def _resolve_base_key(self, key: str) -> Optional[str]:
+ return self._view_to_base.get(key)
+
+ def _iter_base_keys(self) -> Iterable[str]:
+ return self._view_to_base.keys()
diff --git a/comfy/sd.py b/comfy/sd.py
index 5a7221620..4f4b7298b 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -25,6 +25,8 @@ import math
import os
import comfy.utils
+import comfy.safetensors_stream
+import comfy.disk_weights
from . import clip_vision
from . import gligen
@@ -124,7 +126,7 @@ class CLIP:
if not model_management.supports_cast(load_device, dt):
load_device = offload_device
if params['device'] != offload_device:
- self.cond_stage_model.to(offload_device)
+ comfy.disk_weights.module_to(self.cond_stage_model, offload_device)
logging.warning("Had to shift TE back.")
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
@@ -288,7 +290,7 @@ class CLIP:
def load_sd(self, sd, full_model=False):
if full_model:
- return self.cond_stage_model.load_state_dict(sd, strict=False)
+ return comfy.utils.load_state_dict(self.cond_stage_model, sd, strict=False)
else:
return self.cond_stage_model.load_sd(sd)
@@ -349,7 +351,7 @@ class VAE:
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
- self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
+ self.latent_channels = sd_shape(sd, "taesd_decoder.1.weight")[1]
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
self.first_stage_model = StageA()
@@ -364,25 +366,19 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
- new_sd = {}
- for k in sd:
- new_sd["encoder.{}".format(k)] = sd[k]
- sd = new_sd
+ sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
self.first_stage_model = StageC_coder()
self.latent_channels = 16
- new_sd = {}
- for k in sd:
- new_sd["previewer.{}".format(k)] = sd[k]
- sd = new_sd
+ sd = comfy.utils.state_dict_prefix_replace(sd, {"": "previewer."})
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.conv_in.weight" in sd:
- if sd['decoder.conv_in.weight'].shape[1] == 64:
+ if sd_shape(sd, 'decoder.conv_in.weight')[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
- self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+ self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@@ -392,9 +388,9 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
- elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
+ elif sd_shape(sd, 'decoder.conv_in.weight')[1] == 32 and len(sd_shape(sd, 'decoder.conv_in.weight')) == 5:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
- self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+ self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
@@ -417,7 +413,7 @@ class VAE:
self.downscale_ratio = 4
self.upscale_ratio = 4
- self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+ self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
if 'decoder.post_quant_conv.weight' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
@@ -430,7 +426,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
if 'post_quant_conv.weight' in sd:
- self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
+ self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
@@ -465,11 +461,11 @@ class VAE:
self.downscale_index_formula = (6, 8, 8)
self.working_dtypes = [torch.float16, torch.float32]
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
- tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
+ tensor_conv1_shape = sd_shape(sd, "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight")
version = 0
- if tensor_conv1.shape[0] == 512:
+ if tensor_conv1_shape[0] == 512:
version = 0
- elif tensor_conv1.shape[0] == 1024:
+ elif tensor_conv1_shape[0] == 1024:
version = 1
if "encoder.down_blocks.1.conv.conv.bias" in sd:
version = 2
@@ -486,9 +482,9 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
- elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
+ elif "decoder.conv_in.conv.weight" in sd and sd_shape(sd, 'decoder.conv_in.conv.weight')[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
- ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
+ ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1]
self.latent_channels = 32
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
@@ -512,8 +508,8 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
- self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
- self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
+ self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1]
+ self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1])
#This is likely to significantly over-estimate with single image or low frame counts as the
#implementation is able to completely skip caching. Rework if used as an image only VAE
self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
@@ -546,14 +542,14 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
else: # Wan 2.1 VAE
- dim = sd["decoder.head.0.gamma"].shape[0]
+ dim = sd_shape(sd, "decoder.head.0.gamma")[0]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
- self.output_channels = sd["encoder.conv1.weight"].shape[1]
+ self.output_channels = sd_shape(sd, "encoder.conv1.weight")[1]
self.pad_channel_value = 1.0
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
@@ -629,7 +625,7 @@ class VAE:
self.working_dtypes = [torch.float32]
self.crop_input = False
elif "decoder.22.bias" in sd: # taehv, taew and lighttae
- self.latent_channels = sd["decoder.1.weight"].shape[1]
+ self.latent_channels = sd_shape(sd, "decoder.1.weight")[1]
self.latent_dim = 3
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
@@ -640,12 +636,12 @@ class VAE:
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.process_output = lambda image: image
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
- elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
+ elif self.latent_channels == 32 and sd_shape(sd, "decoder.22.bias")[0] == 12: # lighttae_hv15
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
- if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
+ if comfy.utils.state_dict_meta(sd, "decoder.1.weight").dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
latent_format=comfy.latent_formats.HunyuanVideo
else:
latent_format=None # lighttaew2_1 doesn't need scaling
@@ -665,7 +661,7 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
- m, u = self.first_stage_model.load_state_dict(sd, strict=False)
+ m, u = comfy.utils.load_state_dict(self.first_stage_model, sd, strict=False)
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
@@ -679,7 +675,7 @@ class VAE:
if dtype is None:
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype
- self.first_stage_model.to(self.vae_dtype)
+ comfy.disk_weights.module_to(self.first_stage_model, dtype=self.vae_dtype)
self.output_device = model_management.intermediate_device()
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@@ -986,9 +982,12 @@ def load_style_model(ckpt_path):
model = comfy.ldm.flux.redux.ReduxImageEncoder()
else:
raise Exception("invalid style model {}".format(ckpt_path))
- model.load_state_dict(model_data)
+ comfy.utils.load_state_dict(model, model_data, strict=True)
return StyleModel(model)
+def sd_shape(state_dict, key):
+ return comfy.utils.state_dict_meta(state_dict, key).shape
+
class CLIPType(Enum):
STABLE_DIFFUSION = 1
STABLE_CASCADE = 2
@@ -1058,16 +1057,16 @@ def detect_te_model(sd):
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
return TEModel.JINA_CLIP_2
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
- weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
- if weight.shape[-1] == 4096:
+ weight_shape = sd_shape(sd, "encoder.block.23.layer.1.DenseReluDense.wi_1.weight")
+ if weight_shape[-1] == 4096:
return TEModel.T5_XXL
- elif weight.shape[-1] == 2048:
+ elif weight_shape[-1] == 2048:
return TEModel.T5_XL
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
- weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
- if weight.shape[0] == 384:
+ weight_shape = sd_shape(sd, 'encoder.block.0.layer.0.SelfAttention.k.weight')
+ if weight_shape[0] == 384:
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
@@ -1077,19 +1076,19 @@ def detect_te_model(sd):
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
- weight = sd['model.layers.0.self_attn.k_proj.bias']
- if weight.shape[0] == 256:
+ weight_shape = sd_shape(sd, 'model.layers.0.self_attn.k_proj.bias')
+ if weight_shape[0] == 256:
return TEModel.QWEN25_3B
- if weight.shape[0] == 512:
+ if weight_shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
- weight = sd['model.layers.0.post_attention_layernorm.weight']
+ weight_shape = sd_shape(sd, 'model.layers.0.post_attention_layernorm.weight')
if 'model.layers.0.self_attn.q_norm.weight' in sd:
- if weight.shape[0] == 2560:
+ if weight_shape[0] == 2560:
return TEModel.QWEN3_4B
- elif weight.shape[0] == 2048:
+ elif weight_shape[0] == 2048:
return TEModel.QWEN3_2B
- if weight.shape[0] == 5120:
+ if weight_shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
else:
@@ -1418,19 +1417,29 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
scaled_fp8_list.append(k[:-len("scaled_fp8")])
if len(scaled_fp8_list) > 0:
- out_sd = {}
- for k in sd:
- skip = False
+ if comfy.utils.is_stream_state_dict(sd):
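+ # Streaming checkpoint: hide keys under the scaled-fp8 prefixes with a lazy filter view,
+ # then overlay the converted quant views instead of building a full dict.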
+ def _keep_key(k, prefixes=tuple(scaled_fp8_list)):
+ return not any(k.startswith(pref) for pref in prefixes)
+ out_sd = comfy.safetensors_stream.FilterViewStateDict(sd, _keep_key, mutate_base=False)
+ merged = out_sd
for pref in scaled_fp8_list:
- skip = skip or k.startswith(pref)
- if not skip:
- out_sd[k] = sd[k]
+ quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
+ merged = comfy.safetensors_stream.MergedStateDict(merged, quant_sd)
+ sd = merged
+ else:
+ out_sd = {}
+ for k in sd:
+ skip = False
+ for pref in scaled_fp8_list:
+ skip = skip or k.startswith(pref)
+ if not skip:
+ out_sd[k] = sd[k]
- for pref in scaled_fp8_list:
- quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
- for k in quant_sd:
- out_sd[k] = quant_sd[k]
- sd = out_sd
+ for pref in scaled_fp8_list:
+ quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
+ for k in quant_sd:
+ out_sd[k] = quant_sd[k]
+ sd = out_sd
clip_target = model_config.clip_target(state_dict=sd)
if clip_target is not None:
@@ -1508,12 +1517,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
- new_sd = {}
- for k in diffusers_keys:
- if k in sd:
- new_sd[diffusers_keys[k]] = sd.pop(k)
- else:
- logging.warning("{} {}".format(diffusers_keys[k], k))
+ if comfy.utils.is_stream_state_dict(sd):
+ new_sd = comfy.safetensors_stream.MappedStateDict(sd, diffusers_keys)
+ else:
+ new_sd = {}
+ for k in diffusers_keys:
+ if k in sd:
+ new_sd[diffusers_keys[k]] = sd.pop(k)
+ else:
+ logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
@@ -1538,7 +1550,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "")
- model = model.to(offload_device)
+ model = comfy.disk_weights.module_to(model, offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index c512ca5d0..9fb8a8f5c 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return self(tokens)
def load_sd(self, sd):
- return self.transformer.load_state_dict(sd, strict=False)
+ return comfy.utils.load_state_dict(self.transformer, sd, strict=False)
def parse_parentheses(string):
result = []
@@ -430,8 +430,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
try:
if embed_path.lower().endswith(".safetensors"):
- import safetensors.torch
- embed = safetensors.torch.load_file(embed_path, device="cpu")
+ embed = comfy.utils.load_torch_file(embed_path, safe_load=True)
else:
try:
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py
index ce36f1a84..7cf040588 100644
--- a/comfy/taesd/taesd.py
+++ b/comfy/taesd/taesd.py
@@ -56,9 +56,9 @@ class TAESD(nn.Module):
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
if encoder_path is not None:
- self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
+ comfy.utils.load_state_dict(self.taesd_encoder, comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
if decoder_path is not None:
- self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
+ comfy.utils.load_state_dict(self.taesd_decoder, comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
@staticmethod
def scale_latents(x):
diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py
index 776e25e97..6d6e319dc 100644
--- a/comfy/text_encoders/lt.py
+++ b/comfy/text_encoders/lt.py
@@ -119,7 +119,7 @@ class LTXAVTEModel(torch.nn.Module):
if len(sdo) == 0:
sdo = sd
- return self.load_state_dict(sdo, strict=False)
+ return comfy.utils.load_state_dict(self, sdo, strict=False)
def memory_estimation_function(self, token_weight_pairs, device=None):
constant = 6.0
diff --git a/comfy/utils.py b/comfy/utils.py
index ffa98c9b1..dda041625 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -26,10 +26,13 @@ import numpy as np
from PIL import Image
import logging
import itertools
+from types import SimpleNamespace
from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args
import json
+from . import safetensors_stream
+import comfy.disk_weights
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@@ -61,15 +64,9 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
- with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
- sd = {}
- for k in f.keys():
- tensor = f.get_tensor(k)
- if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
- tensor = tensor.to(device=device, copy=True)
- sd[k] = tensor
- if return_metadata:
- metadata = f.metadata()
+ sd = safetensors_stream.StreamStateDict.from_file(ckpt, device=device)
+ if return_metadata:
+ metadata = sd.metadata()
except Exception as e:
if len(e.args) > 0:
message = e.args[0]
@@ -110,16 +107,16 @@ def calculate_parameters(sd, prefix=""):
params = 0
for k in sd.keys():
if k.startswith(prefix):
- w = sd[k]
- params += w.nelement()
+ meta = state_dict_meta(sd, k)
+ params += meta.numel
return params
def weight_dtype(sd, prefix=""):
dtypes = {}
for k in sd.keys():
if k.startswith(prefix):
- w = sd[k]
- dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
+ meta = state_dict_meta(sd, k)
+ dtypes[meta.dtype] = dtypes.get(meta.dtype, 0) + meta.numel
if len(dtypes) == 0:
return None
@@ -133,6 +130,13 @@ def state_dict_key_replace(state_dict, keys_to_replace):
return state_dict
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
+ if is_stream_state_dict(state_dict):
+ return safetensors_stream.RenameViewStateDict(
+ state_dict,
+ replace_prefix,
+ filter_keys=filter_keys,
+ mutate_base=not filter_keys,
+ )
if filter_keys:
out = {}
else:
@@ -145,6 +149,79 @@ def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
return out
+def is_stream_state_dict(state_dict) -> bool:
+ return getattr(state_dict, "is_stream_state_dict", False)
+
+
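+# Return dtype/shape/numel/nbytes for a key; uses the mapping's meta() when available
+# so streaming state dicts are not forced to read tensor data.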
+def state_dict_meta(state_dict, key):
+ if hasattr(state_dict, "meta"):
+ return state_dict.meta(key)
+ w = state_dict[key]
+ numel = w.numel()
+ return SimpleNamespace(
+ dtype=w.dtype,
+ shape=tuple(w.shape),
+ numel=numel,
+ nbytes=numel * w.element_size(),
+ )
+
+
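+# Dispatch helper: streaming state dicts go through the disk-weight path (lazy or
+# hook-based), plain dicts fall back to torch's Module.load_state_dict.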
+def load_state_dict(model, state_dict, strict=False, assign=False):
+ if is_stream_state_dict(state_dict):
+ if comfy.disk_weights.disk_weights_enabled():
+ return comfy.disk_weights.lazy_load_state_dict(model, state_dict, strict=strict)
+ comfy.disk_weights.register_module_weights(model, state_dict)
+ comfy.disk_weights.attach_disk_weight_hooks(model)
+ missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign)
+ return missing, unexpected
+ return model.load_state_dict(state_dict, strict=strict)
+
+
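+# Mirrors torch.nn.Module.load_state_dict's recursive descent, but hands each submodule a
+# filtered lazy view so _load_from_state_dict only reads the tensors it actually assigns.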
+def stream_load_state_dict(model, state_dict, strict=False, assign=False):
+ if is_stream_state_dict(state_dict) and hasattr(state_dict, "copy"):
+ state_dict = state_dict.copy()
+ missing_keys = []
+ unexpected_keys = []
+ error_msgs = []
+ metadata = getattr(state_dict, "_metadata", None)
+
+ def load(module, local_state_dict, prefix=""):
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+ if assign:
+ local_metadata["assign_to_params_buffers"] = assign
+ module._load_from_state_dict(
+ local_state_dict,
+ prefix,
+ local_metadata,
+ True,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ )
+ for name, child in module._modules.items():
+ if child is not None:
+ child_prefix = f"{prefix}{name}."
+ child_state_dict = safetensors_stream.FilterViewStateDict(
+ local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
+ )
+ load(child, child_state_dict, child_prefix)
+ incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
+ for hook in module._load_state_dict_post_hooks.values():
+ out = hook(module, incompatible)
+ if out is not None:
+ raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
+
+ load(model, state_dict)
+ if strict:
+ if len(unexpected_keys) > 0:
+ error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
+ if len(missing_keys) > 0:
+ error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
+ if len(error_msgs) > 0:
+ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
+ return missing_keys, unexpected_keys
+
+
def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}positional_embedding": "{}embeddings.position_embedding.weight",
@@ -825,7 +902,10 @@ def copy_to_param(obj, attr, value):
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
- prev.data.copy_(value)
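+ # Parameters still on the meta device have no storage to copy into; replace them with a real Parameter instead.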
+ if prev.device.type == "meta":
+ setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=prev.requires_grad))
+ else:
+ prev.data.copy_(value)
def get_attr(obj, attr: str):
"""Retrieves a nested attribute from an object using dot notation.
@@ -1217,46 +1297,82 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
if scaled_fp8_key in state_dict:
- scaled_fp8_weight = state_dict[scaled_fp8_key]
- scaled_fp8_dtype = scaled_fp8_weight.dtype
+ if is_stream_state_dict(state_dict):
+ scaled_meta = state_dict_meta(state_dict, scaled_fp8_key)
+ scaled_fp8_dtype = scaled_meta.dtype
+ scaled_fp8_weight_nelements = scaled_meta.numel
+ else:
+ scaled_fp8_weight = state_dict[scaled_fp8_key]
+ scaled_fp8_dtype = scaled_fp8_weight.dtype
+ scaled_fp8_weight_nelements = scaled_fp8_weight.nelement()
if scaled_fp8_dtype == torch.float32:
scaled_fp8_dtype = torch.float8_e4m3fn
- if scaled_fp8_weight.nelement() == 2:
+ if scaled_fp8_weight_nelements == 2:
full_precision_matrix_mult = True
else:
full_precision_matrix_mult = False
- out_sd = {}
layers = {}
- for k in list(state_dict.keys()):
- if k == scaled_fp8_key:
- continue
- if not k.startswith(model_prefix):
- out_sd[k] = state_dict[k]
- continue
- k_out = k
- w = state_dict.pop(k)
- layer = None
- if k_out.endswith(".scale_weight"):
- layer = k_out[:-len(".scale_weight")]
- k_out = "{}.weight_scale".format(layer)
-
- if layer is not None:
- layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
- if full_precision_matrix_mult:
- layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
- layers[layer] = layer_conf
-
- if k_out.endswith(".scale_input"):
- layer = k_out[:-len(".scale_input")]
- k_out = "{}.input_scale".format(layer)
- if w.item() == 1.0:
+ if is_stream_state_dict(state_dict):
+ key_map = {}
+ for k in list(state_dict.keys()):
+ if k == scaled_fp8_key:
continue
+ if not k.startswith(model_prefix):
+ key_map[k] = k
+ continue
+ k_out = k
+ layer = None
+ if k_out.endswith(".scale_weight"):
+ layer = k_out[:-len(".scale_weight")]
+ k_out = "{}.weight_scale".format(layer)
- out_sd[k_out] = w
+ if layer is not None:
+ layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
+ if full_precision_matrix_mult:
+ layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
+ layers[layer] = layer_conf
- state_dict = out_sd
+ if k_out.endswith(".scale_input"):
+ layer = k_out[:-len(".scale_input")]
+ k_out = "{}.input_scale".format(layer)
+ scale_val = state_dict[k]
+ if scale_val.item() == 1.0:
+ continue
+
+ key_map[k] = k_out
+ state_dict = safetensors_stream.MappedStateDict(state_dict, key_map)
+ else:
+ out_sd = {}
+ for k in list(state_dict.keys()):
+ if k == scaled_fp8_key:
+ continue
+ if not k.startswith(model_prefix):
+ out_sd[k] = state_dict[k]
+ continue
+ k_out = k
+ w = state_dict.pop(k)
+ layer = None
+ if k_out.endswith(".scale_weight"):
+ layer = k_out[:-len(".scale_weight")]
+ k_out = "{}.weight_scale".format(layer)
+
+ if layer is not None:
+ layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
+ if full_precision_matrix_mult:
+ layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
+ layers[layer] = layer_conf
+
+ if k_out.endswith(".scale_input"):
+ layer = k_out[:-len(".scale_input")]
+ k_out = "{}.input_scale".format(layer)
+ if w.item() == 1.0:
+ continue
+
+ out_sd[k_out] = w
+
+ state_dict = out_sd
quant_metadata = {"layers": layers}
else:
quant_metadata = json.loads(metadata["_quantization_metadata"])
diff --git a/nodes.py b/nodes.py
index 5a9d42d4a..a96f93e47 100644
--- a/nodes.py
+++ b/nodes.py
@@ -17,7 +17,6 @@ from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
import numpy as np
-import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
@@ -526,7 +525,7 @@ class LoadLatent:
def load(self, latent):
latent_path = folder_paths.get_annotated_filepath(latent)
- latent = safetensors.torch.load_file(latent_path, device="cpu")
+ latent = comfy.utils.load_torch_file(latent_path, safe_load=True)
multiplier = 1.0
if "latent_format_version_0" not in latent:
multiplier = 1.0 / 0.18215
diff --git a/requirements.txt b/requirements.txt
index 7686a5f8a..11ec57e07 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,6 +11,7 @@ transformers>=4.50.3
tokenizers>=0.13.3
sentencepiece
safetensors>=0.4.2
+fastsafetensors @ https://github.com/foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip
aiohttp>=3.11.8
yarl>=1.18.0
pyyaml
diff --git a/tests-unit/utils/safetensors_stream_test.py b/tests-unit/utils/safetensors_stream_test.py
new file mode 100644
index 000000000..60d36142d
--- /dev/null
+++ b/tests-unit/utils/safetensors_stream_test.py
@@ -0,0 +1,182 @@
+import os
+
+import pytest
+import importlib
+import importlib.util
+
+torch = pytest.importorskip("torch")
+
+
+def _write_safetensors(tmp_path, tensors):
+ import safetensors.torch
+ path = os.path.join(tmp_path, "test.safetensors")
+ safetensors.torch.save_file(tensors, path)
+ return path
+
+
+def test_stream_state_dict_meta_is_lazy(tmp_path, monkeypatch):
+ if torch is None:
+ pytest.skip("torch not installed")
+ import comfy.utils
+ path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)})
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
+ calls = []
+
+ original = sd._file.read_tensor
+
+ def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
+ calls.append(meta)
+ return original(meta, device, dtype, allow_gds, pin_if_cpu)
+
+ monkeypatch.setattr(sd._file, "read_tensor", wrapped)
+ meta = sd.meta("a")
+ assert meta.shape == (2, 3)
+ assert meta.dtype == torch.float32
+ assert meta.numel == 6
+ assert calls == []
+
+
+def test_stream_state_dict_getitem_loads_single_tensor(tmp_path, monkeypatch):
+ if torch is None:
+ pytest.skip("torch not installed")
+ import comfy.utils
+ path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)})
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
+ calls = []
+
+ original = sd._file.read_tensor
+
+ def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
+ calls.append(meta)
+ return original(meta, device, dtype, allow_gds, pin_if_cpu)
+
+ monkeypatch.setattr(sd._file, "read_tensor", wrapped)
+ _ = sd["a"]
+ assert len(calls) == 1
+ assert calls[0].shape == (2, 3)
+
+
+def test_stream_views_do_not_materialize(tmp_path, monkeypatch):
+ if torch is None:
+ pytest.skip("torch not installed")
+ import comfy.utils
+ path = _write_safetensors(tmp_path, {"prefix.a": torch.zeros((2, 3)), "other": torch.ones((4,))})
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
+ calls = []
+
+ original = sd._file.read_tensor
+
+ def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
+ calls.append(meta)
+ return original(meta, device, dtype, allow_gds, pin_if_cpu)
+
+ monkeypatch.setattr(sd._file, "read_tensor", wrapped)
+ view = comfy.utils.state_dict_prefix_replace(sd, {"prefix.": ""}, filter_keys=True)
+ _ = list(view.keys())
+ assert calls == []
+
+
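+# Opening a ~16MB checkpoint as a stream state dict should grow RSS by less than the tensor payload (header-only parse).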
+def test_stream_load_rss_small(tmp_path):
+ if torch is None:
+ pytest.skip("torch not installed")
+ import comfy.utils
+ psutil = pytest.importorskip("psutil")
+ process = psutil.Process()
+ size_elems = 4_000_000 # ~16MB float32
+ tensor = torch.zeros((size_elems,), dtype=torch.float32)
+ path = _write_safetensors(tmp_path, {"big": tensor})
+ rss_before = process.memory_info().rss
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
+ rss_after = process.memory_info().rss
+ expected_size = tensor.numel() * tensor.element_size()
+ assert (rss_after - rss_before) < expected_size
+ _ = sd.meta("big")
+
+
+def test_gds_path_errors_without_support(tmp_path, monkeypatch):
+ if torch is None:
+ pytest.skip("torch not installed")
+ import comfy.utils
+ path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32)})
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
+ device = torch.device("cuda")
+
+ if importlib.util.find_spec("fastsafetensors") is None:
+ fst = None
+ else:
+ fst = importlib.import_module("fastsafetensors")
+
+ gds_available = False
+ if fst is not None and torch.cuda.is_available():
+ gds_supported = fst.cpp.is_gds_supported(torch.cuda.current_device())
+ gds_available = bool(fst.cpp.is_cufile_found()) and gds_supported == 1
+
+ if not gds_available:
+ with pytest.raises(RuntimeError, match="GPUDirect requested"):
+ sd.get_tensor("a", device=device, allow_gds=True)
+ else:
+ def fail_nogds(*args, **kwargs):
+ raise AssertionError("nogds path used during GDS request")
+
+ monkeypatch.setattr(sd._file, "_read_tensor_nogds", fail_nogds)
+ t = sd.get_tensor("a", device=device, allow_gds=True)
+ assert t.device.type == "cuda"
+
+
+def test_stream_load_without_disk_cache_keeps_cpu_weights(tmp_path):
+ if torch is None:
+ pytest.skip("torch not installed")
+ import comfy.utils
+ import comfy.disk_weights
+
+ prev_cache = comfy.disk_weights.CACHE.max_bytes
+ prev_gds = comfy.disk_weights.ALLOW_GDS
+ prev_pin = comfy.disk_weights.PIN_IF_CPU
+ prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
+ comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
+
+ try:
+ path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
+ model = torch.nn.Linear(4, 4, bias=True)
+ comfy.utils.load_state_dict(model, sd, strict=False)
+ assert model.weight.device.type == "cpu"
+ assert model.weight.device.type != "meta"
+ finally:
+ comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
+
+
+def test_lazy_disk_weights_loads_on_demand(tmp_path, monkeypatch):
+ if importlib.util.find_spec("fastsafetensors") is None:
+ pytest.skip("fastsafetensors not installed")
+ import comfy.utils
+ import comfy.disk_weights
+
+ prev_cache = comfy.disk_weights.CACHE.max_bytes
+ prev_gds = comfy.disk_weights.ALLOW_GDS
+ prev_pin = comfy.disk_weights.PIN_IF_CPU
+ prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
+ comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
+
+ try:
+ path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
+ model = torch.nn.Linear(4, 4, bias=True)
+ calls = []
+
+ original = sd._file.read_tensor
+
+ def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
+ calls.append(meta)
+ return original(meta, device, dtype, allow_gds, pin_if_cpu)
+
+ monkeypatch.setattr(sd._file, "read_tensor", wrapped)
+ comfy.utils.load_state_dict(model, sd, strict=True)
+ assert model.weight.device.type == "meta"
+ assert calls == []
+
+ comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"))
+ assert model.weight.device.type == "cpu"
+ assert len(calls) == 2
+ finally:
+ comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)