From f925f8fa774196e2271b6a253964af255c0146ee Mon Sep 17 00:00:00 2001
From: ifilipis <40601736+ifilipis@users.noreply.github.com>
Date: Thu, 8 Jan 2026 13:02:08 +0200
Subject: [PATCH] Add streaming safetensors loading with disk weight cache
---
DESIGN.md | 57 ++
README.md | 8 +
comfy/audio_encoders/audio_encoders.py | 2 +-
comfy/cli_args.py | 3 +
comfy/clip_vision.py | 2 +-
comfy/controlnet.py | 12 +-
comfy/disk_weights.py | 275 +++++++
comfy/gligen.py | 5 +-
comfy/ldm/hunyuan_video/upsampler.py | 3 +-
comfy/ldm/lightricks/vae/audio_vae.py | 5 +-
comfy/ldm/mmaudio/vae/vae.py | 4 +-
comfy/model_base.py | 14 +-
comfy/model_detection.py | 165 ++--
comfy/model_management.py | 13 +
comfy/model_patcher.py | 4 +-
comfy/safetensors_stream.py | 799 ++++++++++++++++++++
comfy/sd.py | 129 ++--
comfy/sd1_clip.py | 5 +-
comfy/taesd/taesd.py | 4 +-
comfy/text_encoders/lt.py | 2 +-
comfy/utils.py | 195 ++++-
nodes.py | 3 +-
requirements.txt | 1 +
tests-unit/utils/safetensors_stream_test.py | 146 ++++
24 files changed, 1640 insertions(+), 216 deletions(-)
create mode 100644 DESIGN.md
create mode 100644 comfy/disk_weights.py
create mode 100644 comfy/safetensors_stream.py
create mode 100644 tests-unit/utils/safetensors_stream_test.py
diff --git a/DESIGN.md b/DESIGN.md
new file mode 100644
index 000000000..8e74b256d
--- /dev/null
+++ b/DESIGN.md
@@ -0,0 +1,57 @@
+# Disk tier safetensors streaming design audit (ComfyUI)
+
+## Mandatory research audit (verified call sites)
+
+### ComfyUI load path + eager materialization sites
+- `comfy/utils.py:load_torch_file` currently uses `safetensors.safe_open` and iterates all keys to build a full `sd` dict (eager tensor materialization). It also returns metadata only after reading all tensors.【F:comfy/utils.py†L58-L93】
+- `comfy/utils.py:calculate_parameters` and `weight_dtype` iterate `sd.keys()` and then access `sd[k]` to compute `nelement()`/`dtype` (loads tensors).【F:comfy/utils.py†L109-L128】
+- `comfy/utils.py:state_dict_prefix_replace` mutates dicts by `pop`+assignment (materializes if used on a streaming mapping).【F:comfy/utils.py†L135-L144】
+- `comfy/model_base.py:BaseModel.load_model_weights` builds `to_load = {}` by iterating keys and popping tensors, then passes a fully materialized dict to `load_state_dict` (RAM spike).【F:comfy/model_base.py†L301-L318】
+- `comfy/model_detection.py` reads `state_dict[key].shape` in many branches for detection (must be metadata-only). Example: `calculate_transformer_depth` and numerous `detect_unet_config` branches read shapes directly from `state_dict` values.【F:comfy/model_detection.py†L21-L200】
+- `comfy/sd.py` loads checkpoints, then slices, renames, and computes parameters/dtypes by reading tensors (e.g., `calculate_parameters`, `weight_dtype`, `process_*_state_dict`, and special scaled-FP8 conversion that builds new dicts).【F:comfy/sd.py†L1304-L1519】
+- Direct safetensors load outside `load_torch_file`: `comfy/sd1_clip.py:load_embed` and `nodes.py:LoadLatent.load` use `safetensors.torch.load_file`, bypassing the core loader.【F:comfy/sd1_clip.py†L432-L434】【F:nodes.py†L521-L529】
+
+### FastSageTensors (fastsafetensors) capability audit
+- Header parsing and metadata:
+ - `fastsafetensors/common.py:SafeTensorsMetadata` parses the header and builds per-tensor `TensorFrame` with `dtype`, `shape`, and `data_offsets` (no tensor allocation).【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L63-L187】
+ - `TensorFrame` stores dtype/shape/offsets and supports slicing metadata.【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L238-L338】
+- GDS + no-GDS low-level readers:
+ - `fastsafetensors/cpp.pyi` exposes `gds_file_reader`, `gds_file_handle`, `nogds_file_reader`, `cpu_malloc`, `gpu_malloc`, and alignment helpers such as `get_alignment_size()`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L1-L43】
+ - GDS availability checks are in `fastsafetensors/cpp.pyi`: `is_gds_supported`, `is_cufile_found`, `cufile_version`, and `init_gds`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L36-L43】
+- DLPack wrapping:
+ - `fastsafetensors/dlpack.py` provides `from_cuda_buffer()` which creates DLPack capsules for both CPU and GPU buffers via a device descriptor and is used for `torch.from_dlpack`.【F:../third_party/fastsafetensors-main/fastsafetensors/dlpack.py†L232-L239】
+- Torch framework interop:
+ - `fastsafetensors/frameworks/_torch.py:TorchOp` provides `alloc_tensor_memory`/`free_tensor_memory`, dtype mapping, and uses `torch.from_dlpack` for wrapping raw pointers into tensors.【F:../third_party/fastsafetensors-main/fastsafetensors/frameworks/_torch.py†L131-L205】
+
+### VRAM/RAM offload logic (for extension)
+- `comfy/model_management.py` handles VRAM/RAM offload via `free_memory` and keeps tracking of loaded/offloaded memory (needs integration for RAM disk tier).【F:comfy/model_management.py†L584-L612】
+- `comfy/model_patcher.py` implements module-by-module offload/low-vram weight casting (`comfy_cast_weights`) and partial unload/load (needs to integrate disk tier for RAM eviction).【F:comfy/model_patcher.py†L663-L955】
+
+## Strategy summary (no coding performed yet)
+
+### Streaming safetensors mapping (no full dict materialization)
+- Introduce a new module `comfy/safetensors_stream.py` with:
+ - `TensorMeta` and `SafeTensorIndex` (metadata-only parsing with `fastsafetensors.SafeTensorsMetadata`).
+ - `StreamStateDict` as a mapping backed by `SafeTensorIndex`, exposing metadata-only `keys()`/`__iter__` and loading tensors on demand.
+ - Lightweight mapping views: `PrefixViewStateDict`, `FilterViewStateDict`, `RenameViewStateDict` for lazy prefix/filter/rename without eager loading.
+
+### Range reads and tiering
+- Disk→RAM: use `fastsafetensors.cpp.nogds_file_reader` for range reads and wrap with DLPack.
+- Disk→GPU (GDS): use `gds_file_reader` + `gds_file_handle` to read the aligned range directly into GPU memory. If GDS is requested but not supported (e.g., `is_gds_supported==0` or libcufile missing), raise a hard error with instructions to disable GDS.
+- Disk→RAM→GPU: read only the tensor range into (optionally pinned) CPU memory, copy to GPU, then release CPU buffer unless RAM cache policy keeps it.
+
+### Disk tier integration
+- Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`.
+- Add an LRU cache for RAM-resident weights loaded from disk with configurable max bytes. Eviction replaces RAM tensors with meta tensors and keeps `DiskRef` for reload.
+- Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`.
+
+### Pipeline refactors
+- Update `load_torch_file` to return `StreamStateDict` for `.safetensors`/`.sft` and return metadata without loading.
+- Update helpers (`calculate_parameters`, `weight_dtype`, `state_dict_prefix_replace`) to be metadata-aware and lazy.
+- Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead.
+- Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads).
+- Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader.
+
+### Tests and docs
+- Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and GDS failure path.
+- Document new flags for RAM cache size and GPUDirect enablement and how to disable GDS when unsupported.
diff --git a/README.md b/README.md
index 6d09758c0..550482e5c 100644
--- a/README.md
+++ b/README.md
@@ -349,6 +349,14 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
| `--enable-manager` | Enable ComfyUI-Manager |
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
+| `--weights-ram-cache-gb` | Enable a disk tier for model weights and keep up to N GB in RAM. Set to `0` to disable RAM caching while still allowing disk streaming. |
+| `--weights-gds` | Enable GPUDirect Storage (GDS) for disk→GPU weight loads. Requires libcufile and GDS support. |
+
+### Disk tier for model weights
+
+When `--weights-ram-cache-gb` is set, ComfyUI streams safetensors weights from disk and keeps a bounded RAM cache. If the cache limit is exceeded, weights are evicted back to disk and reloaded on demand.
+
+If `--weights-gds` is enabled, ComfyUI attempts disk→GPU reads via GPUDirect Storage. If GDS is not available (missing libcufile or unsupported platform), the load will fail with a clear error. Disable GDS by omitting `--weights-gds` to use disk→RAM→GPU staging instead.
# Running
diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py
index 46ef21c95..bf3abf834 100644
--- a/comfy/audio_encoders/audio_encoders.py
+++ b/comfy/audio_encoders/audio_encoders.py
@@ -29,7 +29,7 @@ class AudioEncoderModel():
self.model_sample_rate = 16000
def load_sd(self, sd):
- return self.model.load_state_dict(sd, strict=False)
+ return comfy.utils.load_state_dict(self.model, sd, strict=False)
def get_sd(self):
return self.model.state_dict()
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index dae9a895d..d75b9fe99 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -114,6 +114,9 @@ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU cachi
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
+parser.add_argument("--weights-ram-cache-gb", type=float, default=None, help="Enable a disk tier for model weights by keeping up to N GB in RAM. Set to 0 to disable RAM caching while keeping disk tier enabled.")
+parser.add_argument("--weights-gds", action="store_true", help="Enable GPUDirect Storage (GDS) for disk->GPU weight loads. Requires libcufile and GDS support.")
+
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index d5fc53497..837eba013 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -48,7 +48,7 @@ class ClipVisionModel():
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
- return self.model.load_state_dict(sd, strict=False)
+ return comfy.utils.load_state_dict(self.model, sd, strict=False)
def get_sd(self):
return self.model.state_dict()
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 0b5e30f52..890076938 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -439,7 +439,7 @@ def controlnet_config(sd, model_options={}):
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd):
- missing, unexpected = control_model.load_state_dict(sd, strict=False)
+ missing, unexpected = comfy.utils.load_state_dict(control_model, sd, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
@@ -473,9 +473,9 @@ def load_controlnet_mmdit(sd, model_options={}):
class ControlNetSD35(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
if self.control_model.double_y_emb:
- missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
+ missing, unexpected = comfy.utils.load_state_dict(self.control_model.orig_y_embedder, model.diffusion_model.y_embedder.state_dict(), strict=False)
else:
- missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
+ missing, unexpected = comfy.utils.load_state_dict(self.control_model.x_embedder, model.diffusion_model.x_embedder.state_dict(), strict=False)
super().pre_run(model, percent_to_timestep_function)
def copy(self):
@@ -748,9 +748,9 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
pass
w = WeightsLoader()
w.control_model = control_model
- missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
+ missing, unexpected = comfy.utils.load_state_dict(w, controlnet_data, strict=False)
else:
- missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
+ missing, unexpected = comfy.utils.load_state_dict(control_model, controlnet_data, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
@@ -874,7 +874,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
else:
return None
- missing, unexpected = model_ad.load_state_dict(t2i_data)
+ missing, unexpected = comfy.utils.load_state_dict(model_ad, t2i_data, strict=True)
if len(missing) > 0:
logging.warning("t2i missing {}".format(missing))
diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py
new file mode 100644
index 000000000..6610733d6
--- /dev/null
+++ b/comfy/disk_weights.py
@@ -0,0 +1,275 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Comfy
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
+from __future__ import annotations
+
+import collections
+import weakref
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+import torch
+
+
+ALLOW_GDS = False
+PIN_IF_CPU = False
+DISK_WEIGHTS_ENABLED = False
+
+
+@dataclass
+class DiskTensorRef:
+ state_dict: object
+ key: str
+ meta: object
+ requires_grad: bool
+ is_buffer: bool
+
+ def load(self, device: torch.device, allow_gds: bool, pin_if_cpu: bool) -> torch.Tensor:
+ dtype = getattr(self.meta, "dtype", None)
+ if hasattr(self.state_dict, "get_tensor"):
+ return self.state_dict.get_tensor(
+ self.key,
+ device=device,
+ dtype=dtype,
+ allow_gds=allow_gds,
+ pin_if_cpu=pin_if_cpu,
+ )
+ tensor = self.state_dict[self.key]
+ if device is not None and tensor.device != device:
+ tensor = tensor.to(device=device)
+ if dtype is not None and tensor.dtype != dtype:
+ tensor = tensor.to(dtype=dtype)
+ return tensor
+
+
+class DiskWeightRegistry:
+ def __init__(self):
+ self._registry = weakref.WeakKeyDictionary()
+
+ def register(self, module: torch.nn.Module, name: str, ref: DiskTensorRef):
+ module_refs = self._registry.setdefault(module, {})
+ module_refs[name] = ref
+
+ def get(self, module: torch.nn.Module) -> Optional[Dict[str, DiskTensorRef]]:
+ return self._registry.get(module)
+
+ def has(self, module: torch.nn.Module) -> bool:
+ return module in self._registry
+
+
+@dataclass
+class CacheEntry:
+ module_ref: weakref.ReferenceType
+ name: str
+ size_bytes: int
+ is_buffer: bool
+
+
+class DiskWeightCache:
+ def __init__(self, max_bytes: int = 0):
+ self.max_bytes = max_bytes
+ self.current_bytes = 0
+ self._entries: "collections.OrderedDict[tuple[int, str], CacheEntry]" = collections.OrderedDict()
+
+ def set_limit(self, max_bytes: int):
+ self.max_bytes = max_bytes
+ self._evict_if_needed()
+
+ def _entry_key(self, module: torch.nn.Module, name: str) -> tuple[int, str]:
+ return (id(module), name)
+
+ def record(self, module: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool):
+ if tensor.device.type != "cpu":
+ return
+ size_bytes = tensor.numel() * tensor.element_size()
+ key = self._entry_key(module, name)
+ if key in self._entries:
+ entry = self._entries.pop(key)
+ self.current_bytes -= entry.size_bytes
+ module_ref = weakref.ref(module, self._drop_module_entries)
+ self._entries[key] = CacheEntry(module_ref=module_ref, name=name, size_bytes=size_bytes, is_buffer=is_buffer)
+ self.current_bytes += size_bytes
+ self._evict_if_needed()
+
+ def touch(self, module: torch.nn.Module, name: str):
+ key = self._entry_key(module, name)
+ if key in self._entries:
+ entry = self._entries.pop(key)
+ self._entries[key] = entry
+
+ def evict_bytes(self, bytes_to_free: int):
+ freed = 0
+ while self._entries and freed < bytes_to_free:
+ _, entry = self._entries.popitem(last=False)
+ freed += entry.size_bytes
+ self.current_bytes -= entry.size_bytes
+ module = entry.module_ref()
+ if module is not None:
+ _evict_module_weight(module, entry.name, entry.is_buffer)
+ return freed
+
+ def _drop_module_entries(self, module_ref: weakref.ReferenceType):
+ to_remove = []
+ for key, entry in self._entries.items():
+ if entry.module_ref is module_ref:
+ to_remove.append(key)
+ for key in to_remove:
+ entry = self._entries.pop(key)
+ self.current_bytes -= entry.size_bytes
+
+ def _evict_if_needed(self):
+ while self._entries and self.current_bytes > self.max_bytes:
+ _, entry = self._entries.popitem(last=False)
+ self.current_bytes -= entry.size_bytes
+ module = entry.module_ref()
+ if module is not None:
+ _evict_module_weight(module, entry.name, entry.is_buffer)
+
+
+REGISTRY = DiskWeightRegistry()
+CACHE = DiskWeightCache(0)
+
+
+def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
+ global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED
+ ALLOW_GDS = allow_gds
+ PIN_IF_CPU = pin_if_cpu
+ DISK_WEIGHTS_ENABLED = enabled
+ CACHE.set_limit(cache_bytes if enabled else 0)
+ if not enabled:
+ CACHE._entries.clear()
+ CACHE.current_bytes = 0
+
+
+def disk_weights_enabled() -> bool:
+ return DISK_WEIGHTS_ENABLED
+
+
+def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = ""):
+ if not disk_weights_enabled():
+ return
+ if not hasattr(state_dict, "meta") or not hasattr(state_dict, "get_tensor"):
+ return
+ for name, param in module.named_parameters(recurse=True):
+ key = f"{prefix}{name}" if prefix else name
+ if key in state_dict:
+ meta = state_dict.meta(key)
+ ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False)
+ REGISTRY.register(module, name, ref)
+ if param.device.type == "cpu":
+ CACHE.record(module, name, param, is_buffer=False)
+ for name, buf in module.named_buffers(recurse=True):
+ key = f"{prefix}{name}" if prefix else name
+ if key in state_dict and buf is not None:
+ meta = state_dict.meta(key)
+ ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True)
+ REGISTRY.register(module, name, ref)
+ if buf.device.type == "cpu":
+ CACHE.record(module, name, buf, is_buffer=True)
+
+
+def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
+ ref = REGISTRY.get(module)
+ if not ref or name not in ref:
+ return
+ disk_ref = ref[name]
+ shape = getattr(disk_ref.meta, "shape", None)
+ dtype = getattr(disk_ref.meta, "dtype", None)
+ if shape is None or dtype is None:
+ return
+ meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
+ if is_buffer:
+ module._buffers[name] = meta_tensor
+ else:
+ module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+
+
+def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
+ def check(obj):
+ if torch.is_tensor(obj):
+ return obj.device
+ if isinstance(obj, (list, tuple)):
+ for item in obj:
+ dev = check(item)
+ if dev is not None:
+ return dev
+ if isinstance(obj, dict):
+ for item in obj.values():
+ dev = check(item)
+ if dev is not None:
+ return dev
+ return None
+
+ dev = check(args)
+ if dev is not None:
+ return dev
+ return check(kwargs)
+
+
+def ensure_module_materialized(module: torch.nn.Module, target_device: torch.device):
+ refs = REGISTRY.get(module)
+ if not refs:
+ return
+ for name, disk_ref in refs.items():
+ if name in module._parameters:
+ current = module._parameters[name]
+ is_buffer = False
+ elif name in module._buffers:
+ current = module._buffers[name]
+ is_buffer = True
+ else:
+ continue
+ if current is None:
+ continue
+ if current.device.type != "meta":
+ if current.device.type == "cpu":
+ CACHE.touch(module, name)
+ continue
+ tensor = disk_ref.load(target_device, ALLOW_GDS, PIN_IF_CPU)
+ if is_buffer:
+ module._buffers[name] = tensor
+ else:
+ module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
+ if tensor.device.type == "cpu":
+ CACHE.record(module, name, tensor, is_buffer=is_buffer)
+
+
+def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
+ if not REGISTRY.has(module):
+ return
+ if getattr(module, "comfy_cast_weights", False):
+ target_device = torch.device("cpu")
+ else:
+ target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
+ ensure_module_materialized(module, target_device)
+
+
+def attach_disk_weight_hooks(model: torch.nn.Module):
+ if not disk_weights_enabled():
+ return
+ for module in model.modules():
+ if getattr(module, "_disk_weight_hook_attached", False):
+ continue
+ module.register_forward_pre_hook(disk_weight_pre_hook)
+ module._disk_weight_hook_attached = True
+
+
+def evict_ram_cache(bytes_to_free: int):
+ if bytes_to_free <= 0:
+ return 0
+ return CACHE.evict_bytes(bytes_to_free)
diff --git a/comfy/gligen.py b/comfy/gligen.py
index 1d7b6c2f4..c2cf7c6db 100644
--- a/comfy/gligen.py
+++ b/comfy/gligen.py
@@ -3,6 +3,7 @@ import torch
from torch import nn
from .ldm.modules.attention import CrossAttention, FeedForward
import comfy.ops
+import comfy.utils
ops = comfy.ops.manual_cast
@@ -282,7 +283,7 @@ def load_gligen(sd):
gated = GatedSelfAttentionDense(
query_dim, key_dim, n_heads, d_head)
- gated.load_state_dict(n_sd, strict=False)
+ comfy.utils.load_state_dict(gated, n_sd, strict=False)
output_list.append(gated)
if "position_net.null_positive_feature" in sd_k:
@@ -293,7 +294,7 @@ def load_gligen(sd):
pass
w = WeightsLoader()
w.position_net = PositionNet(in_dim, out_dim)
- w.load_state_dict(sd, strict=False)
+ comfy.utils.load_state_dict(w, sd, strict=False)
gligen = Gligen(output_list, w.position_net, key_dim)
return gligen
diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py
index d9e76922f..0c8b9240f 100644
--- a/comfy/ldm/hunyuan_video/upsampler.py
+++ b/comfy/ldm/hunyuan_video/upsampler.py
@@ -1,4 +1,5 @@
import torch
+import comfy.utils
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
@@ -112,7 +113,7 @@ class HunyuanVideo15SRModel():
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
- return self.model.load_state_dict(sd, strict=True)
+ return comfy.utils.load_state_dict(self.model, sd, strict=True)
def get_sd(self):
return self.model.state_dict()
diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py
index a9111d3bd..c61883ba3 100644
--- a/comfy/ldm/lightricks/vae/audio_vae.py
+++ b/comfy/ldm/lightricks/vae/audio_vae.py
@@ -2,6 +2,7 @@ import json
from dataclasses import dataclass
import math
import torch
+import comfy.utils
import torchaudio
import comfy.model_management
@@ -153,8 +154,8 @@ class AudioVAE(torch.nn.Module):
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
- self.autoencoder.load_state_dict(vae_sd, strict=False)
- self.vocoder.load_state_dict(vocoder_sd, strict=False)
+ comfy.utils.load_state_dict(self.autoencoder, vae_sd, strict=False)
+ comfy.utils.load_state_dict(self.vocoder, vocoder_sd, strict=False)
autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer(
diff --git a/comfy/ldm/mmaudio/vae/vae.py b/comfy/ldm/mmaudio/vae/vae.py
index 62f24606c..1009d2d76 100644
--- a/comfy/ldm/mmaudio/vae/vae.py
+++ b/comfy/ldm/mmaudio/vae/vae.py
@@ -2,6 +2,7 @@ import logging
from typing import Optional
import torch
+import comfy.utils
import torch.nn as nn
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
@@ -152,7 +153,7 @@ class VAE(nn.Module):
return dec, posterior
def load_weights(self, src_dict) -> None:
- self.load_state_dict(src_dict, strict=True)
+ comfy.utils.load_state_dict(self, src_dict, strict=True)
@property
def device(self) -> torch.device:
@@ -355,4 +356,3 @@ def get_my_vae(name: str, **kwargs) -> VAE:
if name == '44k':
return VAE_44k(**kwargs)
raise ValueError(f'Unknown model: {name}')
-
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 49efd700b..3bb155f2c 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -299,20 +299,14 @@ class BaseModel(torch.nn.Module):
return out
def load_model_weights(self, sd, unet_prefix=""):
- to_load = {}
- keys = list(sd.keys())
- for k in keys:
- if k.startswith(unet_prefix):
- to_load[k[len(unet_prefix):]] = sd.pop(k)
-
+ to_load = utils.state_dict_prefix_replace(sd, {unet_prefix: ""}, filter_keys=True)
to_load = self.model_config.process_unet_state_dict(to_load)
- m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
+ m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))
if len(u) > 0:
logging.warning("unet unexpected: {}".format(u))
- del to_load
return self
def process_latent_in(self, latent):
@@ -751,8 +745,8 @@ class StableAudio1(BaseModel):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
- self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights)
- self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)
+ utils.load_state_dict(self.seconds_start_embedder, seconds_start_embedder_weights, strict=True)
+ utils.load_state_dict(self.seconds_total_embedder, seconds_total_embedder_weights, strict=True)
def extra_conds(self, **kwargs):
out = {}
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 0853b3aec..0b3a05659 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -19,6 +19,9 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
+def sd_shape(state_dict, key):
+ return comfy.utils.state_dict_meta(state_dict, key).shape
+
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
@@ -27,8 +30,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
if len(transformer_keys) > 0:
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
- context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
- use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
+ context_dim = sd_shape(state_dict, '{}0.attn2.to_k.weight'.format(transformer_prefix))[1]
+ use_linear_in_transformer = len(sd_shape(state_dict, '{}1.proj_in.weight'.format(prefix))) == 2
time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict
time_stack_cross = '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
@@ -39,27 +42,27 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
unet_config = {}
- unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
- patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
+ unet_config["in_channels"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[1]
+ patch_size = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[2]
unet_config["patch_size"] = patch_size
final_layer = '{}final_layer.linear.weight'.format(key_prefix)
if final_layer in state_dict:
- unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size)
+ unet_config["out_channels"] = sd_shape(state_dict, final_layer)[0] // (patch_size * patch_size)
- unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
+ unet_config["depth"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0] // 64
unet_config["input_size"] = None
y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix)
if y_key in state_dict_keys:
- unet_config["adm_in_channels"] = state_dict[y_key].shape[1]
+ unet_config["adm_in_channels"] = sd_shape(state_dict, y_key)[1]
context_key = '{}context_embedder.weight'.format(key_prefix)
if context_key in state_dict_keys:
- in_features = state_dict[context_key].shape[1]
- out_features = state_dict[context_key].shape[0]
+ in_features = sd_shape(state_dict, context_key)[1]
+ out_features = sd_shape(state_dict, context_key)[0]
unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}}
num_patches_key = '{}pos_embed'.format(key_prefix)
if num_patches_key in state_dict_keys:
- num_patches = state_dict[num_patches_key].shape[1]
+ num_patches = sd_shape(state_dict, num_patches_key)[1]
unet_config["num_patches"] = num_patches
unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches))
@@ -83,23 +86,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
if text_mapper_name in state_dict_keys:
unet_config['stable_cascade_stage'] = 'c'
- w = state_dict[text_mapper_name]
- if w.shape[0] == 1536: #stage c lite
+ w_shape = sd_shape(state_dict, text_mapper_name)
+ if w_shape[0] == 1536: #stage c lite
unet_config['c_cond'] = 1536
unet_config['c_hidden'] = [1536, 1536]
unet_config['nhead'] = [24, 24]
unet_config['blocks'] = [[4, 12], [12, 4]]
- elif w.shape[0] == 2048: #stage c full
+ elif w_shape[0] == 2048: #stage c full
unet_config['c_cond'] = 2048
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
unet_config['stable_cascade_stage'] = 'b'
- w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)]
- if w.shape[-1] == 640:
+ w_shape = sd_shape(state_dict, '{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix))
+ if w_shape[-1] == 640:
unet_config['c_hidden'] = [320, 640, 1280, 1280]
unet_config['nhead'] = [-1, -1, 20, 20]
unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
- elif w.shape[-1] == 576: #stage b lite
+ elif w_shape[-1] == 576: #stage b lite
unet_config['c_hidden'] = [320, 576, 1152, 1152]
unet_config['nhead'] = [-1, 9, 18, 18]
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
@@ -113,8 +116,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
unet_config = {}
- unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
- unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
+ unet_config["max_seq"] = sd_shape(state_dict, '{}positional_encoding'.format(key_prefix))[1]
+ unet_config["cond_seq_dim"] = sd_shape(state_dict, '{}cond_seq_linear.weight'.format(key_prefix))[1]
double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.')
single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.')
unet_config["n_double_layers"] = double_layers
@@ -125,10 +128,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config = {}
unet_config["image_model"] = "hydit"
unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
- unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
+ unet_config["hidden_size"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0]
if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2
unet_config["mlp_ratio"] = 4.3637
- if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
+ if sd_shape(state_dict, '{}extra_embedder.0.weight'.format(key_prefix))[1] == 3968:
unet_config["size_cond"] = True
unet_config["use_style_cond"] = True
unet_config["image_model"] = "hydit1"
@@ -136,12 +139,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
- in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
- out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
+ in_w_shape = sd_shape(state_dict, '{}img_in.proj.weight'.format(key_prefix))
+ out_w_shape = sd_shape(state_dict, '{}final_layer.linear.weight'.format(key_prefix))
dit_config["image_model"] = "hunyuan_video"
- dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
- dit_config["patch_size"] = list(in_w.shape[2:])
- dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
+ dit_config["in_channels"] = in_w_shape[1] #SkyReels img2video has 32 input channels
+ dit_config["patch_size"] = list(in_w_shape[2:])
+ dit_config["out_channels"] = out_w_shape[0] // math.prod(dit_config["patch_size"])
if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
dit_config["vec_in_dim"] = 768
else:
@@ -157,10 +160,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["meanflow"] = False
- dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
- dit_config["hidden_size"] = in_w.shape[0]
+ dit_config["context_in_dim"] = sd_shape(state_dict, '{}txt_in.input_embedder.weight'.format(key_prefix))[1]
+ dit_config["hidden_size"] = in_w_shape[0]
dit_config["mlp_ratio"] = 4.0
- dit_config["num_heads"] = in_w.shape[0] // 128
+ dit_config["num_heads"] = in_w_shape[0] // 128
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["theta"] = 256
@@ -179,7 +182,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["use_cond_type_embedding"] = False
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
- dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
+ dit_config["vision_in_dim"] = sd_shape(state_dict, '{}vision_in.proj.0.weight'.format(key_prefix))[0]
dit_config["meanflow_sum"] = True
else:
dit_config["vision_in_dim"] = None
@@ -221,19 +224,19 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["patch_size"] = patch_size
in_key = "{}img_in.weight".format(key_prefix)
if in_key in state_dict_keys:
- w = state_dict[in_key]
- dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
- dit_config["hidden_size"] = w.shape[0]
+ w_shape = sd_shape(state_dict, in_key)
+ dit_config["in_channels"] = w_shape[1] // (patch_size * patch_size)
+ dit_config["hidden_size"] = w_shape[0]
txt_in_key = "{}txt_in.weight".format(key_prefix)
if txt_in_key in state_dict_keys:
- w = state_dict[txt_in_key]
- dit_config["context_in_dim"] = w.shape[1]
- dit_config["hidden_size"] = w.shape[0]
+ w_shape = sd_shape(state_dict, txt_in_key)
+ dit_config["context_in_dim"] = w_shape[1]
+ dit_config["hidden_size"] = w_shape[0]
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys:
- dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
+ dit_config["vec_in_dim"] = sd_shape(state_dict, vec_in_key)[1]
else:
dit_config["vec_in_dim"] = None
@@ -307,7 +310,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config = {}
dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
- shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
+ shape = sd_shape(state_dict, '{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix))
dit_config["attention_head_dim"] = shape[0] // 32
dit_config["cross_attention_dim"] = shape[1]
if metadata is not None and "config" in metadata:
@@ -350,11 +353,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
y_key = "{}y_embedder.y_embedding".format(key_prefix)
if y_key in state_dict_keys:
- dit_config["model_max_length"] = state_dict[y_key].shape[0]
+ dit_config["model_max_length"] = sd_shape(state_dict, y_key)[0]
pe_key = "{}pos_embed".format(key_prefix)
if pe_key in state_dict_keys:
- dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
+ dit_config["input_size"] = int(math.sqrt(sd_shape(state_dict, pe_key)[1])) * patch_size
dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
@@ -373,11 +376,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
concat_padding_mask = True
- dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
+ dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
- dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
+ dit_config["model_channels"] = sd_shape(state_dict, '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix))[0]
dit_config["block_config"] = "FA-CA-MLP"
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["pos_emb_cls"] = "rope3d"
@@ -416,9 +419,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
- w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
- dit_config["dim"] = w.shape[0]
- dit_config["cap_feat_dim"] = w.shape[1]
+ w_shape = sd_shape(state_dict, '{}cap_embedder.1.weight'.format(key_prefix))
+ dit_config["dim"] = w_shape[0]
+ dit_config["cap_feat_dim"] = w_shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["qk_norm"] = True
@@ -429,9 +432,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
- ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
- if ctd_weight is not None: # NewBie
- dit_config["clip_text_dim"] = ctd_weight.shape[0]
+ ctd_key = '{}clip_text_pooled_proj.0.weight'.format(key_prefix)
+ if ctd_key in state_dict_keys: # NewBie
+ dit_config["clip_text_dim"] = sd_shape(state_dict, ctd_key)[0]
# NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
@@ -450,12 +453,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {}
dit_config["image_model"] = "wan2.1"
- dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
- out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4
+ dim = sd_shape(state_dict, '{}head.modulation'.format(key_prefix))[-1]
+ out_dim = sd_shape(state_dict, '{}head.head.weight'.format(key_prefix))[0] // 4
dit_config["dim"] = dim
dit_config["out_dim"] = out_dim
dit_config["num_heads"] = dim // 128
- dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
+ dit_config["ffn_dim"] = sd_shape(state_dict, '{}blocks.0.ffn.0.weight'.format(key_prefix))[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
dit_config["patch_size"] = (1, 2, 2)
dit_config["freq_dim"] = 256
@@ -463,10 +466,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["qk_norm"] = True
dit_config["cross_attn_norm"] = True
dit_config["eps"] = 1e-6
- dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
+ dit_config["in_dim"] = sd_shape(state_dict, '{}patch_embedding.weight'.format(key_prefix))[1]
if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "vace"
- dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
+ dit_config["vace_in_dim"] = sd_shape(state_dict, '{}vace_patch_embedding.weight'.format(key_prefix))[1]
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
@@ -484,22 +487,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "i2v"
else:
dit_config["model_type"] = "t2v"
- flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
- if flf_weight is not None:
- dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
+ flf_key = '{}img_emb.emb_pos'.format(key_prefix)
+ if flf_key in state_dict_keys:
+ dit_config["flf_pos_embed_token_number"] = sd_shape(state_dict, flf_key)[1]
- ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
- if ref_conv_weight is not None:
- dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
+ ref_conv_key = '{}ref_conv.weight'.format(key_prefix)
+ if ref_conv_key in state_dict_keys:
+ dit_config["in_dim_ref_conv"] = sd_shape(state_dict, ref_conv_key)[1]
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
- in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
+ in_shape = sd_shape(state_dict, '{}latent_in.weight'.format(key_prefix))
dit_config = {}
dit_config["image_model"] = "hunyuan3d2"
dit_config["in_channels"] = in_shape[1]
- dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
+ dit_config["context_in_dim"] = sd_shape(state_dict, '{}cond_in.weight'.format(key_prefix))[1]
dit_config["hidden_size"] = in_shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 16
@@ -513,9 +516,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config = {}
dit_config["image_model"] = "hunyuan3d2_1"
- dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
+ dit_config["in_channels"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[1]
dit_config["context_dim"] = 1024
- dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
+ dit_config["hidden_size"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 16
dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
@@ -549,11 +552,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
concat_padding_mask = True
- dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
+ dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
- dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
+ dit_config["model_channels"] = sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[0]
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["crossattn_emb_channels"] = 1024
dit_config["pos_emb_cls"] = "rope3d"
@@ -617,7 +620,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"
- dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
+ dit_config["in_channels"] = sd_shape(state_dict, '{}img_in.weight'.format(key_prefix))[1]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
dit_config["default_ref_method"] = "index_timestep_zero"
@@ -628,7 +631,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
dit_config = {}
- model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
+ model_dim = sd_shape(state_dict, '{}visual_embeddings.in_layer.bias'.format(key_prefix))[0]
dit_config["model_dim"] = model_dim
if model_dim in [4096, 2560]: # pro video and lite image
dit_config["axes_dims"] = (32, 48, 48)
@@ -636,10 +639,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
elif model_dim == 1792: # lite video
dit_config["axes_dims"] = (16, 24, 24)
- dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
+ dit_config["time_dim"] = sd_shape(state_dict, '{}time_embeddings.in_layer.bias'.format(key_prefix))[0]
dit_config["image_model"] = "kandinsky5"
- dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
- dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
+ dit_config["ff_dim"] = sd_shape(state_dict, '{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix))[0]
+ dit_config["visual_embed_dim"] = sd_shape(state_dict, '{}visual_embeddings.in_layer.weight'.format(key_prefix))[1]
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
@@ -657,16 +660,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
y_input = '{}label_emb.0.0.weight'.format(key_prefix)
if y_input in state_dict_keys:
unet_config["num_classes"] = "sequential"
- unet_config["adm_in_channels"] = state_dict[y_input].shape[1]
+ unet_config["adm_in_channels"] = sd_shape(state_dict, y_input)[1]
else:
unet_config["adm_in_channels"] = None
- model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
- in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
+ model_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[0]
+ in_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[1]
out_key = '{}out.2.weight'.format(key_prefix)
if out_key in state_dict:
- out_channels = state_dict[out_key].shape[0]
+ out_channels = sd_shape(state_dict, out_key)[0]
else:
out_channels = 4
@@ -713,7 +716,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
if res_block_prefix in block_keys:
last_res_blocks += 1
- last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
+ last_channel_mult = sd_shape(state_dict, "{}0.out_layers.3.weight".format(prefix))[0] // model_channels
out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
if out is not None:
@@ -867,7 +870,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count)
if transformer_count > 0:
- match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
+ match["context_dim"] = sd_shape(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab))[1]
attn_res *= 2
if attn_blocks == 0:
@@ -876,13 +879,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
match["transformer_depth"] = transformer_depth
- match["model_channels"] = state_dict["conv_in.weight"].shape[0]
- match["in_channels"] = state_dict["conv_in.weight"].shape[1]
+ match["model_channels"] = sd_shape(state_dict, "conv_in.weight")[0]
+ match["in_channels"] = sd_shape(state_dict, "conv_in.weight")[1]
match["adm_in_channels"] = None
if "class_embedding.linear_1.weight" in state_dict:
- match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
+ match["adm_in_channels"] = sd_shape(state_dict, "class_embedding.linear_1.weight")[1]
elif "add_embedding.linear_1.weight" in state_dict:
- match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
+ match["adm_in_channels"] = sd_shape(state_dict, "add_embedding.linear_1.weight")[1]
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
@@ -1023,11 +1026,11 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
elif 'x_embedder.weight' in state_dict: #Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
- hidden_size = state_dict["x_embedder.bias"].shape[0]
+ hidden_size = sd_shape(state_dict, "x_embedder.bias")[0]
sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
- depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
+ depth = sd_shape(state_dict, "pos_embed.proj.weight")[0] // 64
sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
else:
return None
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 928282092..f389ad857 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -27,6 +27,7 @@ import platform
import weakref
import gc
import os
+import comfy.disk_weights
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -583,6 +584,8 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
+ if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
+ comfy.disk_weights.evict_ram_cache(memory_required)
unloaded_model = []
can_unload = []
unloaded_models = []
@@ -1124,6 +1127,16 @@ if not args.disable_pinned_memory:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
+WEIGHTS_RAM_CACHE_BYTES = 0
+WEIGHTS_GDS_ENABLED = bool(args.weights_gds)
+if args.weights_ram_cache_gb is not None:
+ WEIGHTS_RAM_CACHE_BYTES = int(max(0.0, args.weights_ram_cache_gb) * (1024 ** 3))
+ comfy.disk_weights.configure(
+ WEIGHTS_RAM_CACHE_BYTES,
+ allow_gds=WEIGHTS_GDS_ENABLED,
+ pin_if_cpu=not args.disable_pinned_memory,
+ )
+
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def discard_cuda_async_error():
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 93d26c690..0e38629bc 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -34,6 +34,7 @@ import comfy.lora
import comfy.model_management
import comfy.patcher_extension
import comfy.utils
+import comfy.disk_weights
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@@ -269,6 +270,8 @@ class ModelPatcher:
if not hasattr(self.model, 'model_offload_buffer_memory'):
self.model.model_offload_buffer_memory = 0
+ comfy.disk_weights.attach_disk_weight_hooks(self.model)
+
def model_size(self):
if self.size > 0:
return self.size
@@ -1356,4 +1359,3 @@ class ModelPatcher:
def __del__(self):
self.unpin_all_weights()
self.detach(unpatch_all=False)
-
diff --git a/comfy/safetensors_stream.py b/comfy/safetensors_stream.py
new file mode 100644
index 000000000..475bf6bf0
--- /dev/null
+++ b/comfy/safetensors_stream.py
@@ -0,0 +1,799 @@
+"""
+ This file is part of ComfyUI.
+ Copyright (C) 2024 Comfy
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see .
+"""
+
+from __future__ import annotations
+
+import collections
+import importlib
+import importlib.util
+import os
+import threading
+from dataclasses import dataclass
+from types import SimpleNamespace
+from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Sequence, Tuple
+
+import torch
+
+
+_FST_MODULE = None
+_FST_LOCK = threading.Lock()
+_FST_LOADED = False
+_GDS_INITIALIZED = False
+_MISSING = object()
+
+
+def _require_fastsafetensors():
+ global _FST_MODULE
+ with _FST_LOCK:
+ if _FST_MODULE is None:
+ if importlib.util.find_spec("fastsafetensors") is None:
+ raise ImportError(
+ "fastsafetensors is required for safetensors streaming. "
+ "Install it with: pip install 'fastsafetensors @ https://github.com/"
+ "foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip'"
+ )
+ _FST_MODULE = importlib.import_module("fastsafetensors")
+ return _FST_MODULE
+
+
+def _init_fastsafetensors_lib():
+ global _FST_LOADED
+ fst = _require_fastsafetensors()
+ if not _FST_LOADED:
+ fst.cpp.load_library_functions()
+ _FST_LOADED = True
+ return fst
+
+
+def _init_gds():
+ global _GDS_INITIALIZED
+ fst = _init_fastsafetensors_lib()
+ if not _GDS_INITIALIZED:
+ if fst.cpp.init_gds() != 0:
+ raise RuntimeError("fastsafetensors init_gds() failed")
+ _GDS_INITIALIZED = True
+
+
+@dataclass(frozen=True)
+class TensorMeta:
+ dtype: torch.dtype
+ shape: Tuple[int, ...]
+ numel: int
+ nbytes: int
+ data_offsets: Tuple[int, int]
+ filename: str
+ fst_dtype: object
+ strides: Tuple[int, ...]
+
+
+class SafeTensorIndex:
+ def __init__(self, filename: str):
+ fst = _init_fastsafetensors_lib()
+ framework = fst.frameworks.get_framework_op("pytorch")
+ metadata = fst.common.SafeTensorsMetadata.from_file(filename, framework)
+ self._filename = filename
+ self._metadata = metadata
+ self._framework = framework
+ from fastsafetensors.frameworks import _torch as fst_torch
+ self._dtype_map = fst_torch.dtype_convert
+ self._tensor_meta: Dict[str, TensorMeta] = {}
+ for key, frame in metadata.tensors.items():
+ torch_dtype = self._dtype_map.get(frame.dtype, None)
+ if torch_dtype is None:
+ raise ValueError(f"Unsupported safetensors dtype {frame.dtype} in {filename}")
+ numel = 1
+ for s in frame.shape:
+ numel *= s
+ nbytes = numel * framework.get_dtype_size(frame.dtype)
+ self._tensor_meta[key] = TensorMeta(
+ dtype=torch_dtype,
+ shape=tuple(frame.shape),
+ numel=numel,
+ nbytes=nbytes,
+ data_offsets=(frame.data_offsets[0], frame.data_offsets[1]),
+ filename=filename,
+ fst_dtype=frame.dtype,
+ strides=tuple(frame.strides),
+ )
+
+ def keys(self) -> Iterable[str]:
+ return self._tensor_meta.keys()
+
+ def has(self, key: str) -> bool:
+ return key in self._tensor_meta
+
+ def meta(self, key: str) -> TensorMeta:
+ return self._tensor_meta[key]
+
+ def metadata(self):
+ return self._metadata.metadata
+
+ @property
+ def header_length(self) -> int:
+ return self._metadata.header_length
+
+ @property
+ def size_bytes(self) -> int:
+ return self._metadata.size_bytes
+
+
+class _SafeTensorFile:
+ def __init__(self, filename: str, index: SafeTensorIndex):
+ self.filename = filename
+ self.index = index
+ self._fd: Optional[int] = None
+ self._gds_handle = None
+ self._gds_reader = None
+ self._nogds_reader = None
+ self._refcount = 1
+
+ def acquire(self) -> "_SafeTensorFile":
+ self._refcount += 1
+ return self
+
+ def release(self):
+ self._refcount -= 1
+ if self._refcount <= 0:
+ self.close()
+
+ def close(self):
+ if self._fd is not None:
+ os.close(self._fd)
+ self._fd = None
+ self._gds_handle = None
+
+ def _ensure_fd(self) -> int:
+ if self._fd is None:
+ self._fd = os.open(self.filename, os.O_RDONLY, 0o644)
+ return self._fd
+
+ def _ensure_nogds_reader(self, use_cuda: bool):
+ fst = _init_fastsafetensors_lib()
+ if self._nogds_reader is None:
+ self._nogds_reader = fst.cpp.nogds_file_reader(
+ False, 16 * 1024, 16, use_cuda
+ )
+ return self._nogds_reader
+
+ def _ensure_gds_reader(self, use_cuda: bool):
+ fst = _init_fastsafetensors_lib()
+ if self._gds_reader is None:
+ self._gds_reader = fst.cpp.gds_file_reader(16, use_cuda)
+ return self._gds_reader
+
+ def _ensure_gds_handle(self, use_cuda: bool):
+ if self._gds_handle is None:
+ fst = _init_fastsafetensors_lib()
+ framework = fst.frameworks.get_framework_op("pytorch")
+ o_direct = _get_gds_o_direct(framework)
+ self._gds_handle = fst.cpp.gds_file_handle(self.filename, o_direct, use_cuda)
+ return self._gds_handle
+
+ def read_tensor(
+ self,
+ meta: TensorMeta,
+ device: torch.device,
+ dtype: Optional[torch.dtype],
+ allow_gds: bool,
+ pin_if_cpu: bool,
+ ) -> torch.Tensor:
+ fst = _init_fastsafetensors_lib()
+ framework = fst.frameworks.get_framework_op("pytorch")
+ device_is_cuda = device.type == "cuda"
+ if device_is_cuda and allow_gds:
+ _ensure_gds_ready(device)
+ tensor = self._read_tensor_gds(
+ fst, framework, meta, device, dtype
+ )
+ return tensor
+
+ cpu_tensor = self._read_tensor_nogds(
+ fst, framework, meta, torch.device("cpu"), dtype
+ )
+ if device_is_cuda:
+ if pin_if_cpu:
+ cpu_tensor = cpu_tensor.pin_memory()
+ gpu_tensor = torch.empty_like(cpu_tensor, device=device)
+ gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu)
+ return gpu_tensor
+ return cpu_tensor
+
+ def _aligned_range(self, abs_start: int, length: int) -> Tuple[int, int, int]:
+ fst = _init_fastsafetensors_lib()
+ align = fst.cpp.get_alignment_size()
+ aligned_offset = (abs_start // align) * align
+ head = abs_start - aligned_offset
+ aligned_length = length + head
+ tail = aligned_length % align
+ if tail:
+ aligned_length += align - tail
+ return aligned_offset, aligned_length, head
+
+ def _read_tensor_nogds(
+ self,
+ fst,
+ framework,
+ meta: TensorMeta,
+ device: torch.device,
+ dtype: Optional[torch.dtype],
+ ) -> torch.Tensor:
+ fd = self._ensure_fd()
+ reader = self._ensure_nogds_reader(use_cuda=False)
+ abs_start = self.index.header_length + meta.data_offsets[0]
+ length = meta.data_offsets[1] - meta.data_offsets[0]
+ aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
+ ptr_align = framework.get_device_ptr_align()
+ buffer_length = aligned_length + ptr_align
+ buf_ptr = fst.cpp.cpu_malloc(buffer_length)
+ gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False)
+ ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
+ req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off)
+ if reader.wait_read(req) < 0:
+ fst.cpp.cpu_free(buf_ptr)
+ raise RuntimeError("nogds_file_reader read failed")
+ owner = _BufferOwner(lambda: fst.cpp.cpu_free(buf_ptr))
+ tensor = _dlpack_tensor_from_buffer(
+ fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
+ )
+ if dtype is not None and dtype != tensor.dtype:
+ _validate_dtype_conversion(tensor.dtype, dtype)
+ tensor = tensor.to(dtype=dtype)
+ return tensor
+
+ def _read_tensor_gds(
+ self,
+ fst,
+ framework,
+ meta: TensorMeta,
+ device: torch.device,
+ dtype: Optional[torch.dtype],
+ ) -> torch.Tensor:
+ reader = self._ensure_gds_reader(use_cuda=True)
+ handle = self._ensure_gds_handle(use_cuda=True)
+ abs_start = self.index.header_length + meta.data_offsets[0]
+ length = meta.data_offsets[1] - meta.data_offsets[0]
+ aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
+ ptr_align = framework.get_device_ptr_align()
+ buffer_length = aligned_length + ptr_align
+ fst_device = _fst_device_from_torch(fst, device)
+ gbuf = framework.alloc_tensor_memory(buffer_length, fst_device)
+ ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align
+ file_length = self.index.size_bytes
+ req = reader.submit_read(
+ handle, gbuf, aligned_offset, aligned_length, ptr_off, file_length
+ )
+ if reader.wait_read(req) < 0:
+ framework.free_tensor_memory(gbuf, fst_device)
+ raise RuntimeError("gds_file_reader read failed")
+ owner = _BufferOwner(lambda: framework.free_tensor_memory(gbuf, fst_device))
+ tensor = _dlpack_tensor_from_buffer(
+ fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
+ )
+ if dtype is not None and dtype != tensor.dtype:
+ _validate_dtype_conversion(tensor.dtype, dtype)
+ tensor = tensor.to(dtype=dtype)
+ return tensor
+
+
+def _fst_device_from_torch(fst, device: torch.device):
+ if device.type == "cuda" and device.index is not None:
+ return fst.st_types.Device.from_str(f"cuda:{device.index}")
+ return fst.st_types.Device.from_str(device.type)
+
+
+class _BufferOwner:
+ def __init__(self, free_fn):
+ self._free_fn = free_fn
+
+ def __del__(self):
+ try:
+ self._free_fn()
+ except Exception:
+ pass
+
+
+def _dlpack_tensor_from_buffer(
+ fst,
+ framework,
+ ptr: int,
+ meta: TensorMeta,
+ device: torch.device,
+ owner: Optional[_BufferOwner],
+) -> torch.Tensor:
+ disk_dtype = framework.as_workaround_dtype(meta.fst_dtype)
+ dev = _fst_device_from_torch(fst, device)
+ dl_tensor = fst.dlpack.from_cuda_buffer(ptr, list(meta.shape), list(meta.strides), disk_dtype, dev)
+ torch_tensor = framework.from_dlpack(dl_tensor, dev, disk_dtype).real_tensor
+ if disk_dtype != meta.fst_dtype:
+ torch_tensor = torch_tensor.view(meta.dtype)
+ if owner is not None:
+ torch_tensor._comfy_disk_buffer_owner = owner
+ return torch_tensor
+
+
+def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype):
+ if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size():
+ raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})")
+
+
+def _get_gds_o_direct(framework) -> bool:
+ cuda_ver = framework.get_cuda_ver()
+ if cuda_ver and cuda_ver != "0.0":
+ ver_parts = cuda_ver.split("-", 1)
+ if len(ver_parts) == 2:
+ cudavers = list(map(int, ver_parts[1].split(".")))
+ if ver_parts[0] == "cuda":
+ return not (cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2))
+ return True
+ return True
+
+
+def _ensure_gds_ready(device: torch.device):
+ fst = _init_fastsafetensors_lib()
+ if not fst.common.is_gpu_found():
+ raise RuntimeError(
+ "GPUDirect requested but GPU runtime library is missing. "
+ "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
+ )
+ gds_supported = fst.cpp.is_gds_supported(device.index if device.index is not None else 0)
+ if gds_supported < 0:
+ raise RuntimeError(
+ "GPUDirect requested but is_gds_supported() failed. "
+ "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
+ )
+ if not fst.cpp.is_cufile_found():
+ raise RuntimeError(
+ "GPUDirect requested but libcufile is missing. "
+ "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
+ )
+ if gds_supported == 0:
+ raise RuntimeError(
+ "GPUDirect requested but GDS is unsupported on this platform. "
+ "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU."
+ )
+ _init_gds()
+
+
+class StreamStateDict(collections.abc.MutableMapping):
+ is_stream_state_dict = True
+
+ def __init__(
+ self,
+ index: SafeTensorIndex,
+ file: _SafeTensorFile,
+ device: torch.device,
+ allow_gds: bool = False,
+ ):
+ self._index = index
+ self._file = file
+ self._device = device
+ self._allow_gds = allow_gds
+ self._overrides: Dict[str, torch.Tensor] = {}
+ self._deleted: set[str] = set()
+
+ @classmethod
+ def from_file(cls, filename: str, device: torch.device, allow_gds: bool = False) -> "StreamStateDict":
+ index = SafeTensorIndex(filename)
+ file = _SafeTensorFile(filename, index)
+ return cls(index, file, device, allow_gds=allow_gds)
+
+ def close(self):
+ if self._file is not None:
+ self._file.release()
+ self._file = None
+
+ def __del__(self):
+ try:
+ self.close()
+ except Exception:
+ pass
+
+ def meta(self, key: str) -> TensorMeta:
+ if key in self._overrides:
+ t = self._overrides[key]
+ numel = t.numel()
+ return TensorMeta(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=numel,
+ nbytes=numel * t.element_size(),
+ data_offsets=(0, numel * t.element_size()),
+ filename="",
+ fst_dtype=None,
+ strides=tuple(t.stride()),
+ )
+ if key in self._deleted:
+ raise KeyError(key)
+ return self._index.meta(key)
+
+ def get_tensor(
+ self,
+ key: str,
+ *,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ allow_gds: Optional[bool] = None,
+ pin_if_cpu: bool = False,
+ ) -> torch.Tensor:
+ if key in self._overrides:
+ t = self._overrides[key]
+ if device is not None and t.device != device:
+ t = t.to(device=device)
+ if dtype is not None and t.dtype != dtype:
+ _validate_dtype_conversion(t.dtype, dtype)
+ t = t.to(dtype=dtype)
+ return t
+ if key in self._deleted:
+ raise KeyError(key)
+ if device is None:
+ device = self._device
+ if allow_gds is None:
+ allow_gds = self._allow_gds
+ meta = self._index.meta(key)
+ return self._file.read_tensor(meta, device, dtype, allow_gds, pin_if_cpu)
+
+ def __getitem__(self, key: str) -> torch.Tensor:
+ return self.get_tensor(key)
+
+ def __setitem__(self, key: str, value: torch.Tensor) -> None:
+ self._overrides[key] = value
+ self._deleted.discard(key)
+
+ def __delitem__(self, key: str) -> None:
+ if key in self._overrides:
+ del self._overrides[key]
+ return
+ if key in self._deleted:
+ raise KeyError(key)
+ if self._index.has(key):
+ self._deleted.add(key)
+ return
+ raise KeyError(key)
+
+ def __iter__(self) -> Iterator[str]:
+ for k in self._index.keys():
+ if k in self._deleted:
+ continue
+ if k in self._overrides:
+ continue
+ yield k
+ for k in self._overrides.keys():
+ yield k
+
+ def __len__(self) -> int:
+ base = len(self._index.keys())
+ return base - len(self._deleted) + len(self._overrides)
+
+ def __contains__(self, key: object) -> bool:
+ if not isinstance(key, str):
+ return False
+ if key in self._deleted:
+ return False
+ if key in self._overrides:
+ return True
+ return self._index.has(key)
+
+ def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
+ if key in self._overrides:
+ return self._overrides.pop(key)
+ if key in self._deleted:
+ if default is _MISSING:
+ raise KeyError(key)
+ return default
+ if self._index.has(key):
+ self._deleted.add(key)
+ return self.get_tensor(key)
+ if default is _MISSING:
+ raise KeyError(key)
+ return default
+
+ def copy(self) -> "StreamStateDict":
+ new = StreamStateDict(self._index, self._file.acquire(), self._device, allow_gds=self._allow_gds)
+ new._overrides = dict(self._overrides)
+ new._deleted = set(self._deleted)
+ return new
+
+ def metadata(self):
+ return self._index.metadata()
+
+
+class _BaseViewStateDict(MutableMapping):
+ is_stream_state_dict = True
+
+ def __init__(self, base: MutableMapping, mutate_base: bool = False):
+ self._base = base
+ self._mutate_base = mutate_base
+ self._overrides: Dict[str, torch.Tensor] = {}
+ self._deleted: set[str] = set()
+
+ def _resolve_base_key(self, key: str) -> Optional[str]:
+ return key
+
+ def _iter_base_keys(self) -> Iterable[str]:
+ return self._base.keys()
+
+ def get_tensor(
+ self,
+ key: str,
+ *,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ allow_gds: Optional[bool] = None,
+ pin_if_cpu: bool = False,
+ ) -> torch.Tensor:
+ if key in self._overrides:
+ t = self._overrides[key]
+ if device is not None and t.device != device:
+ t = t.to(device=device)
+ if dtype is not None and t.dtype != dtype:
+ _validate_dtype_conversion(t.dtype, dtype)
+ t = t.to(dtype=dtype)
+ return t
+ base_key = self._resolve_base_key(key)
+ if base_key is None or key in self._deleted:
+ raise KeyError(key)
+ if hasattr(self._base, "get_tensor"):
+ return self._base.get_tensor(
+ base_key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
+ )
+ t = self._base[base_key]
+ if device is not None and t.device != device:
+ t = t.to(device=device)
+ if dtype is not None and t.dtype != dtype:
+ _validate_dtype_conversion(t.dtype, dtype)
+ t = t.to(dtype=dtype)
+ return t
+
+ def meta(self, key: str):
+ if key in self._overrides:
+ t = self._overrides[key]
+ numel = t.numel()
+ return SimpleNamespace(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=numel,
+ nbytes=numel * t.element_size(),
+ )
+ base_key = self._resolve_base_key(key)
+ if base_key is None or key in self._deleted:
+ raise KeyError(key)
+ if hasattr(self._base, "meta"):
+ return self._base.meta(base_key)
+ t = self._base[base_key]
+ numel = t.numel()
+ return SimpleNamespace(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=numel,
+ nbytes=numel * t.element_size(),
+ )
+
+ def __getitem__(self, key: str) -> torch.Tensor:
+ return self.get_tensor(key)
+
+ def __setitem__(self, key: str, value: torch.Tensor) -> None:
+ base_key = self._resolve_base_key(key)
+ if self._mutate_base and base_key is not None and base_key in self._base:
+ self._base[base_key] = value
+ else:
+ self._overrides[key] = value
+ self._deleted.discard(key)
+
+ def __delitem__(self, key: str) -> None:
+ if key in self._overrides:
+ del self._overrides[key]
+ return
+ base_key = self._resolve_base_key(key)
+ if base_key is None or key in self._deleted:
+ raise KeyError(key)
+ if self._mutate_base and base_key in self._base:
+ del self._base[base_key]
+ else:
+ self._deleted.add(key)
+
+ def __iter__(self) -> Iterator[str]:
+ for k in self._iter_base_keys():
+ if k in self._deleted:
+ continue
+ yield k
+ for k in self._overrides.keys():
+ yield k
+
+ def __len__(self) -> int:
+ base_keys = list(self._iter_base_keys())
+ return len(base_keys) - len(self._deleted) + len(self._overrides)
+
+ def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
+ if key in self._overrides:
+ return self._overrides.pop(key)
+ base_key = self._resolve_base_key(key)
+ if base_key is None or key in self._deleted:
+ if default is _MISSING:
+ raise KeyError(key)
+ return default
+ if self._mutate_base:
+ try:
+ return self._base.pop(base_key)
+ except KeyError:
+ if default is _MISSING:
+ raise
+ return default
+ self._deleted.add(key)
+ return self.get_tensor(key)
+
+
+class FilterViewStateDict(_BaseViewStateDict):
+ def __init__(self, base: MutableMapping, predicate, mutate_base: bool = False):
+ super().__init__(base, mutate_base=mutate_base)
+ self._predicate = predicate
+
+ def _resolve_base_key(self, key: str) -> Optional[str]:
+ if self._predicate(key):
+ return key
+ return None
+
+ def _iter_base_keys(self) -> Iterable[str]:
+ for k in self._base.keys():
+ if self._predicate(k):
+ yield k
+
+
+class PrefixViewStateDict(_BaseViewStateDict):
+ def __init__(self, base: MutableMapping, source_prefix: str, target_prefix: str = "", mutate_base: bool = False):
+ super().__init__(base, mutate_base=mutate_base)
+ self._source_prefix = source_prefix
+ self._target_prefix = target_prefix
+ self._mapping: Dict[str, str] = {}
+ self._reverse: Dict[str, str] = {}
+ for k in base.keys():
+ if not k.startswith(source_prefix):
+ continue
+ view_key = f"{target_prefix}{k[len(source_prefix):]}"
+ self._mapping[k] = view_key
+ self._reverse[view_key] = k
+
+ def _resolve_base_key(self, key: str) -> Optional[str]:
+ return self._reverse.get(key)
+
+ def _iter_base_keys(self) -> Iterable[str]:
+ return self._reverse.keys()
+
+
+class RenameViewStateDict(_BaseViewStateDict):
+ def __init__(
+ self,
+ base: MutableMapping,
+ replace_prefix: Mapping[str, str],
+ filter_keys: bool = False,
+ mutate_base: bool = False,
+ ):
+ super().__init__(base, mutate_base=mutate_base)
+ self._filter_keys = filter_keys
+ self._replace = list(replace_prefix.items())
+ self._mapping: Dict[str, str] = {}
+ self._reverse: Dict[str, str] = {}
+ for k in base.keys():
+ view_key = self._replace_key(k)
+ if view_key is None:
+ continue
+ self._mapping[k] = view_key
+ self._reverse[view_key] = k
+
+ def _replace_key(self, key: str) -> Optional[str]:
+ for rp, dst in self._replace:
+ if key.startswith(rp):
+ return f"{dst}{key[len(rp):]}"
+ if self._filter_keys:
+ return None
+ return key
+
+ def _resolve_base_key(self, key: str) -> Optional[str]:
+ return self._reverse.get(key)
+
+ def _iter_base_keys(self) -> Iterable[str]:
+ return self._reverse.keys()
+
+
+class MergedStateDict(MutableMapping):
+ is_stream_state_dict = True
+
+ def __init__(self, *mappings: MutableMapping):
+ self._mappings = list(mappings)
+ self._overrides: Dict[str, torch.Tensor] = {}
+ self._deleted: set[str] = set()
+
+ def __getitem__(self, key: str) -> torch.Tensor:
+ if key in self._overrides:
+ return self._overrides[key]
+ if key in self._deleted:
+ raise KeyError(key)
+ for mapping in reversed(self._mappings):
+ if key in mapping:
+ if hasattr(mapping, "get_tensor"):
+ return mapping.get_tensor(key)
+ return mapping[key]
+ raise KeyError(key)
+
+ def __setitem__(self, key: str, value: torch.Tensor) -> None:
+ self._overrides[key] = value
+ self._deleted.discard(key)
+
+ def __delitem__(self, key: str) -> None:
+ if key in self._overrides:
+ del self._overrides[key]
+ return
+ if key in self._deleted:
+ raise KeyError(key)
+ if any(key in mapping for mapping in self._mappings):
+ self._deleted.add(key)
+ return
+ raise KeyError(key)
+
+ def __iter__(self) -> Iterator[str]:
+ seen = set()
+ for mapping in self._mappings:
+ for key in mapping.keys():
+ if key in self._deleted or key in seen:
+ continue
+ seen.add(key)
+ yield key
+ for key in self._overrides.keys():
+ if key not in seen:
+ yield key
+
+ def __len__(self) -> int:
+ return len(list(self.__iter__()))
+
+ def meta(self, key: str):
+ if key in self._overrides:
+ t = self._overrides[key]
+ numel = t.numel()
+ return SimpleNamespace(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=numel,
+ nbytes=numel * t.element_size(),
+ )
+ if key in self._deleted:
+ raise KeyError(key)
+ for mapping in reversed(self._mappings):
+ if key in mapping:
+ if hasattr(mapping, "meta"):
+ return mapping.meta(key)
+ t = mapping[key]
+ numel = t.numel()
+ return SimpleNamespace(
+ dtype=t.dtype,
+ shape=tuple(t.shape),
+ numel=numel,
+ nbytes=numel * t.element_size(),
+ )
+ raise KeyError(key)
+
+
+class MappedStateDict(_BaseViewStateDict):
+ def __init__(self, base: MutableMapping, key_map: Mapping[str, str], mutate_base: bool = False):
+ super().__init__(base, mutate_base=mutate_base)
+ self._base_to_view = dict(key_map)
+ self._view_to_base = {v: k for k, v in key_map.items()}
+
+ def _resolve_base_key(self, key: str) -> Optional[str]:
+ return self._view_to_base.get(key)
+
+ def _iter_base_keys(self) -> Iterable[str]:
+ return self._view_to_base.keys()
diff --git a/comfy/sd.py b/comfy/sd.py
index 32157e18b..7b65af39d 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -25,6 +25,7 @@ import math
import os
import comfy.utils
+import comfy.safetensors_stream
from . import clip_vision
from . import gligen
@@ -288,7 +289,7 @@ class CLIP:
def load_sd(self, sd, full_model=False):
if full_model:
- return self.cond_stage_model.load_state_dict(sd, strict=False)
+ return comfy.utils.load_state_dict(self.cond_stage_model, sd, strict=False)
else:
return self.cond_stage_model.load_sd(sd)
@@ -346,7 +347,7 @@ class VAE:
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
- self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
+ self.latent_channels = sd_shape(sd, "taesd_decoder.1.weight")[1]
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
self.first_stage_model = StageA()
@@ -361,25 +362,19 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
- new_sd = {}
- for k in sd:
- new_sd["encoder.{}".format(k)] = sd[k]
- sd = new_sd
+ sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."})
elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
self.first_stage_model = StageC_coder()
self.latent_channels = 16
- new_sd = {}
- for k in sd:
- new_sd["previewer.{}".format(k)] = sd[k]
- sd = new_sd
+ sd = comfy.utils.state_dict_prefix_replace(sd, {"": "previewer."})
elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.conv_in.weight" in sd:
- if sd['decoder.conv_in.weight'].shape[1] == 64:
+ if sd_shape(sd, 'decoder.conv_in.weight')[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
- self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+ self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@@ -389,9 +384,9 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
- elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
+ elif sd_shape(sd, 'decoder.conv_in.weight')[1] == 32 and len(sd_shape(sd, 'decoder.conv_in.weight')) == 5:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
- self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+ self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
@@ -414,7 +409,7 @@ class VAE:
self.downscale_ratio = 4
self.upscale_ratio = 4
- self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+ self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1]
if 'decoder.post_quant_conv.weight' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
@@ -427,7 +422,7 @@ class VAE:
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
if 'post_quant_conv.weight' in sd:
- self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
+ self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1])
else:
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
@@ -462,11 +457,11 @@ class VAE:
self.downscale_index_formula = (6, 8, 8)
self.working_dtypes = [torch.float16, torch.float32]
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
- tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
+ tensor_conv1_shape = sd_shape(sd, "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight")
version = 0
- if tensor_conv1.shape[0] == 512:
+ if tensor_conv1_shape[0] == 512:
version = 0
- elif tensor_conv1.shape[0] == 1024:
+ elif tensor_conv1_shape[0] == 1024:
version = 1
if "encoder.down_blocks.1.conv.conv.bias" in sd:
version = 2
@@ -483,9 +478,9 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
- elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
+ elif "decoder.conv_in.conv.weight" in sd and sd_shape(sd, 'decoder.conv_in.conv.weight')[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
- ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
+ ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1]
self.latent_channels = 32
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
@@ -509,8 +504,8 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
- self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
- self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
+ self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1]
+ self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1])
#This is likely to significantly over-estimate with single image or low frame counts as the
#implementation is able to completely skip caching. Rework if used as an image only VAE
self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
@@ -543,14 +538,14 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
else: # Wan 2.1 VAE
- dim = sd["decoder.head.0.gamma"].shape[0]
+ dim = sd_shape(sd, "decoder.head.0.gamma")[0]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
- self.output_channels = sd["encoder.conv1.weight"].shape[1]
+ self.output_channels = sd_shape(sd, "encoder.conv1.weight")[1]
self.pad_channel_value = 1.0
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
@@ -626,7 +621,7 @@ class VAE:
self.working_dtypes = [torch.float32]
self.crop_input = False
elif "decoder.22.bias" in sd: # taehv, taew and lighttae
- self.latent_channels = sd["decoder.1.weight"].shape[1]
+ self.latent_channels = sd_shape(sd, "decoder.1.weight")[1]
self.latent_dim = 3
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
@@ -637,12 +632,12 @@ class VAE:
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.process_output = lambda image: image
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
- elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
+ elif self.latent_channels == 32 and sd_shape(sd, "decoder.22.bias")[0] == 12: # lighttae_hv15
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
- if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
+ if comfy.utils.state_dict_meta(sd, "decoder.1.weight").dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
latent_format=comfy.latent_formats.HunyuanVideo
else:
latent_format=None # lighttaew2_1 doesn't need scaling
@@ -662,7 +657,7 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
- m, u = self.first_stage_model.load_state_dict(sd, strict=False)
+ m, u = comfy.utils.load_state_dict(self.first_stage_model, sd, strict=False)
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
@@ -983,9 +978,12 @@ def load_style_model(ckpt_path):
model = comfy.ldm.flux.redux.ReduxImageEncoder()
else:
raise Exception("invalid style model {}".format(ckpt_path))
- model.load_state_dict(model_data)
+ comfy.utils.load_state_dict(model, model_data, strict=True)
return StyleModel(model)
+def sd_shape(state_dict, key):
+ return comfy.utils.state_dict_meta(state_dict, key).shape
+
class CLIPType(Enum):
STABLE_DIFFUSION = 1
STABLE_CASCADE = 2
@@ -1055,16 +1053,16 @@ def detect_te_model(sd):
if "model.encoder.layers.0.mixer.Wqkv.weight" in sd:
return TEModel.JINA_CLIP_2
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
- weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
- if weight.shape[-1] == 4096:
+ weight_shape = sd_shape(sd, "encoder.block.23.layer.1.DenseReluDense.wi_1.weight")
+ if weight_shape[-1] == 4096:
return TEModel.T5_XXL
- elif weight.shape[-1] == 2048:
+ elif weight_shape[-1] == 2048:
return TEModel.T5_XL
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
- weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
- if weight.shape[0] == 384:
+ weight_shape = sd_shape(sd, 'encoder.block.0.layer.0.SelfAttention.k.weight')
+ if weight_shape[0] == 384:
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
@@ -1074,19 +1072,19 @@ def detect_te_model(sd):
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
- weight = sd['model.layers.0.self_attn.k_proj.bias']
- if weight.shape[0] == 256:
+ weight_shape = sd_shape(sd, 'model.layers.0.self_attn.k_proj.bias')
+ if weight_shape[0] == 256:
return TEModel.QWEN25_3B
- if weight.shape[0] == 512:
+ if weight_shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
- weight = sd['model.layers.0.post_attention_layernorm.weight']
+ weight_shape = sd_shape(sd, 'model.layers.0.post_attention_layernorm.weight')
if 'model.layers.0.self_attn.q_norm.weight' in sd:
- if weight.shape[0] == 2560:
+ if weight_shape[0] == 2560:
return TEModel.QWEN3_4B
- elif weight.shape[0] == 2048:
+ elif weight_shape[0] == 2048:
return TEModel.QWEN3_2B
- if weight.shape[0] == 5120:
+ if weight_shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
else:
@@ -1415,19 +1413,29 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
scaled_fp8_list.append(k[:-len("scaled_fp8")])
if len(scaled_fp8_list) > 0:
- out_sd = {}
- for k in sd:
- skip = False
+ if comfy.utils.is_stream_state_dict(sd):
+ def _keep_key(k, prefixes=tuple(scaled_fp8_list)):
+ return not any(k.startswith(pref) for pref in prefixes)
+ out_sd = comfy.safetensors_stream.FilterViewStateDict(sd, _keep_key, mutate_base=False)
+ merged = out_sd
for pref in scaled_fp8_list:
- skip = skip or k.startswith(pref)
- if not skip:
- out_sd[k] = sd[k]
+ quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
+ merged = comfy.safetensors_stream.MergedStateDict(merged, quant_sd)
+ sd = merged
+ else:
+ out_sd = {}
+ for k in sd:
+ skip = False
+ for pref in scaled_fp8_list:
+ skip = skip or k.startswith(pref)
+ if not skip:
+ out_sd[k] = sd[k]
- for pref in scaled_fp8_list:
- quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
- for k in quant_sd:
- out_sd[k] = quant_sd[k]
- sd = out_sd
+ for pref in scaled_fp8_list:
+ quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
+ for k in quant_sd:
+ out_sd[k] = quant_sd[k]
+ sd = out_sd
clip_target = model_config.clip_target(state_dict=sd)
if clip_target is not None:
@@ -1505,12 +1513,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)
- new_sd = {}
- for k in diffusers_keys:
- if k in sd:
- new_sd[diffusers_keys[k]] = sd.pop(k)
- else:
- logging.warning("{} {}".format(diffusers_keys[k], k))
+ if comfy.utils.is_stream_state_dict(sd):
+ new_sd = comfy.safetensors_stream.MappedStateDict(sd, diffusers_keys)
+ else:
+ new_sd = {}
+ for k in diffusers_keys:
+ if k in sd:
+ new_sd[diffusers_keys[k]] = sd.pop(k)
+ else:
+ logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index c512ca5d0..9fb8a8f5c 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return self(tokens)
def load_sd(self, sd):
- return self.transformer.load_state_dict(sd, strict=False)
+ return comfy.utils.load_state_dict(self.transformer, sd, strict=False)
def parse_parentheses(string):
result = []
@@ -430,8 +430,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
try:
if embed_path.lower().endswith(".safetensors"):
- import safetensors.torch
- embed = safetensors.torch.load_file(embed_path, device="cpu")
+ embed = comfy.utils.load_torch_file(embed_path, safe_load=True)
else:
try:
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py
index ce36f1a84..7cf040588 100644
--- a/comfy/taesd/taesd.py
+++ b/comfy/taesd/taesd.py
@@ -56,9 +56,9 @@ class TAESD(nn.Module):
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
if encoder_path is not None:
- self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
+ comfy.utils.load_state_dict(self.taesd_encoder, comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
if decoder_path is not None:
- self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
+ comfy.utils.load_state_dict(self.taesd_decoder, comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
@staticmethod
def scale_latents(x):
diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py
index 130ebaeae..c998061b4 100644
--- a/comfy/text_encoders/lt.py
+++ b/comfy/text_encoders/lt.py
@@ -116,7 +116,7 @@ class LTXAVTEModel(torch.nn.Module):
if len(sdo) == 0:
sdo = sd
- return self.load_state_dict(sdo, strict=False)
+ return comfy.utils.load_state_dict(self, sdo, strict=False)
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
diff --git a/comfy/utils.py b/comfy/utils.py
index ffa98c9b1..7281dcd7e 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -26,10 +26,13 @@ import numpy as np
from PIL import Image
import logging
import itertools
+from types import SimpleNamespace
from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args
import json
+from . import safetensors_stream
+import comfy.disk_weights
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@@ -61,15 +64,9 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
- with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
- sd = {}
- for k in f.keys():
- tensor = f.get_tensor(k)
- if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
- tensor = tensor.to(device=device, copy=True)
- sd[k] = tensor
- if return_metadata:
- metadata = f.metadata()
+ sd = safetensors_stream.StreamStateDict.from_file(ckpt, device=device)
+ if return_metadata:
+ metadata = sd.metadata()
except Exception as e:
if len(e.args) > 0:
message = e.args[0]
@@ -110,16 +107,16 @@ def calculate_parameters(sd, prefix=""):
params = 0
for k in sd.keys():
if k.startswith(prefix):
- w = sd[k]
- params += w.nelement()
+ meta = state_dict_meta(sd, k)
+ params += meta.numel
return params
def weight_dtype(sd, prefix=""):
dtypes = {}
for k in sd.keys():
if k.startswith(prefix):
- w = sd[k]
- dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
+ meta = state_dict_meta(sd, k)
+ dtypes[meta.dtype] = dtypes.get(meta.dtype, 0) + meta.numel
if len(dtypes) == 0:
return None
@@ -133,6 +130,13 @@ def state_dict_key_replace(state_dict, keys_to_replace):
return state_dict
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
+ if is_stream_state_dict(state_dict):
+ return safetensors_stream.RenameViewStateDict(
+ state_dict,
+ replace_prefix,
+ filter_keys=filter_keys,
+ mutate_base=not filter_keys,
+ )
if filter_keys:
out = {}
else:
@@ -145,6 +149,77 @@ def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
return out
+def is_stream_state_dict(state_dict) -> bool:
+ return getattr(state_dict, "is_stream_state_dict", False)
+
+
+def state_dict_meta(state_dict, key):
+ if hasattr(state_dict, "meta"):
+ return state_dict.meta(key)
+ w = state_dict[key]
+ numel = w.numel()
+ return SimpleNamespace(
+ dtype=w.dtype,
+ shape=tuple(w.shape),
+ numel=numel,
+ nbytes=numel * w.element_size(),
+ )
+
+
+def load_state_dict(model, state_dict, strict=False, assign=False):
+ if is_stream_state_dict(state_dict):
+ comfy.disk_weights.register_module_weights(model, state_dict)
+ comfy.disk_weights.attach_disk_weight_hooks(model)
+ missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign)
+ return missing, unexpected
+ return model.load_state_dict(state_dict, strict=strict)
+
+
+def stream_load_state_dict(model, state_dict, strict=False, assign=False):
+ if is_stream_state_dict(state_dict) and hasattr(state_dict, "copy"):
+ state_dict = state_dict.copy()
+ missing_keys = []
+ unexpected_keys = []
+ error_msgs = []
+ metadata = getattr(state_dict, "_metadata", None)
+
+ def load(module, local_state_dict, prefix=""):
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+ if assign:
+ local_metadata["assign_to_params_buffers"] = assign
+ module._load_from_state_dict(
+ local_state_dict,
+ prefix,
+ local_metadata,
+ True,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ )
+ for name, child in module._modules.items():
+ if child is not None:
+ child_prefix = f"{prefix}{name}."
+ child_state_dict = safetensors_stream.FilterViewStateDict(
+ local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
+ )
+ load(child, child_state_dict, child_prefix)
+ incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
+ for hook in module._load_state_dict_post_hooks.values():
+ out = hook(module, incompatible)
+ if out is not None:
+ raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
+
+ load(model, state_dict)
+ if strict:
+ if len(unexpected_keys) > 0:
+ error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
+ if len(missing_keys) > 0:
+ error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
+ if len(error_msgs) > 0:
+ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
+ return missing_keys, unexpected_keys
+
+
def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}positional_embedding": "{}embeddings.position_embedding.weight",
@@ -1217,46 +1292,82 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
if scaled_fp8_key in state_dict:
- scaled_fp8_weight = state_dict[scaled_fp8_key]
- scaled_fp8_dtype = scaled_fp8_weight.dtype
+ if is_stream_state_dict(state_dict):
+ scaled_meta = state_dict_meta(state_dict, scaled_fp8_key)
+ scaled_fp8_dtype = scaled_meta.dtype
+ scaled_fp8_weight_nelements = scaled_meta.numel
+ else:
+ scaled_fp8_weight = state_dict[scaled_fp8_key]
+ scaled_fp8_dtype = scaled_fp8_weight.dtype
+ scaled_fp8_weight_nelements = scaled_fp8_weight.nelement()
if scaled_fp8_dtype == torch.float32:
scaled_fp8_dtype = torch.float8_e4m3fn
- if scaled_fp8_weight.nelement() == 2:
+ if scaled_fp8_weight_nelements == 2:
full_precision_matrix_mult = True
else:
full_precision_matrix_mult = False
- out_sd = {}
layers = {}
- for k in list(state_dict.keys()):
- if k == scaled_fp8_key:
- continue
- if not k.startswith(model_prefix):
- out_sd[k] = state_dict[k]
- continue
- k_out = k
- w = state_dict.pop(k)
- layer = None
- if k_out.endswith(".scale_weight"):
- layer = k_out[:-len(".scale_weight")]
- k_out = "{}.weight_scale".format(layer)
-
- if layer is not None:
- layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
- if full_precision_matrix_mult:
- layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
- layers[layer] = layer_conf
-
- if k_out.endswith(".scale_input"):
- layer = k_out[:-len(".scale_input")]
- k_out = "{}.input_scale".format(layer)
- if w.item() == 1.0:
+ if is_stream_state_dict(state_dict):
+ key_map = {}
+ for k in list(state_dict.keys()):
+ if k == scaled_fp8_key:
continue
+ if not k.startswith(model_prefix):
+ key_map[k] = k
+ continue
+ k_out = k
+ layer = None
+ if k_out.endswith(".scale_weight"):
+ layer = k_out[:-len(".scale_weight")]
+ k_out = "{}.weight_scale".format(layer)
- out_sd[k_out] = w
+ if layer is not None:
+ layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
+ if full_precision_matrix_mult:
+ layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
+ layers[layer] = layer_conf
- state_dict = out_sd
+ if k_out.endswith(".scale_input"):
+ layer = k_out[:-len(".scale_input")]
+ k_out = "{}.input_scale".format(layer)
+ scale_val = state_dict[k]
+ if scale_val.item() == 1.0:
+ continue
+
+ key_map[k] = k_out
+ state_dict = safetensors_stream.MappedStateDict(state_dict, key_map)
+ else:
+ out_sd = {}
+ for k in list(state_dict.keys()):
+ if k == scaled_fp8_key:
+ continue
+ if not k.startswith(model_prefix):
+ out_sd[k] = state_dict[k]
+ continue
+ k_out = k
+ w = state_dict.pop(k)
+ layer = None
+ if k_out.endswith(".scale_weight"):
+ layer = k_out[:-len(".scale_weight")]
+ k_out = "{}.weight_scale".format(layer)
+
+ if layer is not None:
+ layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
+ if full_precision_matrix_mult:
+ layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
+ layers[layer] = layer_conf
+
+ if k_out.endswith(".scale_input"):
+ layer = k_out[:-len(".scale_input")]
+ k_out = "{}.input_scale".format(layer)
+ if w.item() == 1.0:
+ continue
+
+ out_sd[k_out] = w
+
+ state_dict = out_sd
quant_metadata = {"layers": layers}
else:
quant_metadata = json.loads(metadata["_quantization_metadata"])
diff --git a/nodes.py b/nodes.py
index 56b74ebe3..bd4650276 100644
--- a/nodes.py
+++ b/nodes.py
@@ -17,7 +17,6 @@ from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
import numpy as np
-import safetensors.torch
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
@@ -525,7 +524,7 @@ class LoadLatent:
def load(self, latent):
latent_path = folder_paths.get_annotated_filepath(latent)
- latent = safetensors.torch.load_file(latent_path, device="cpu")
+ latent = comfy.utils.load_torch_file(latent_path, safe_load=True)
multiplier = 1.0
if "latent_format_version_0" not in latent:
multiplier = 1.0 / 0.18215
diff --git a/requirements.txt b/requirements.txt
index bc8346bcf..9d3e0aa53 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,6 +11,7 @@ transformers>=4.50.3
tokenizers>=0.13.3
sentencepiece
safetensors>=0.4.2
+fastsafetensors @ https://github.com/foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip
aiohttp>=3.11.8
yarl>=1.18.0
pyyaml
diff --git a/tests-unit/utils/safetensors_stream_test.py b/tests-unit/utils/safetensors_stream_test.py
new file mode 100644
index 000000000..e9bb72245
--- /dev/null
+++ b/tests-unit/utils/safetensors_stream_test.py
@@ -0,0 +1,146 @@
+import os
+
+import pytest
+import importlib
+import importlib.util
+
+torch = pytest.importorskip("torch")
+
+
+def _write_safetensors(tmp_path, tensors):
+ import safetensors.torch
+ path = os.path.join(tmp_path, "test.safetensors")
+ safetensors.torch.save_file(tensors, path)
+ return path
+
+
+def test_stream_state_dict_meta_is_lazy(tmp_path, monkeypatch):
+ if torch is None:
+ pytest.skip("torch not installed")
+ import comfy.utils
+ path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)})
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
+ calls = []
+
+ original = sd._file.read_tensor
+
+ def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
+ calls.append(meta)
+ return original(meta, device, dtype, allow_gds, pin_if_cpu)
+
+ monkeypatch.setattr(sd._file, "read_tensor", wrapped)
+ meta = sd.meta("a")
+ assert meta.shape == (2, 3)
+ assert meta.dtype == torch.float32
+ assert meta.numel == 6
+ assert calls == []
+
+
+def test_stream_state_dict_getitem_loads_single_tensor(tmp_path, monkeypatch):
+ if torch is None:
+ pytest.skip("torch not installed")
+ import comfy.utils
+ path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)})
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
+ calls = []
+
+ original = sd._file.read_tensor
+
+ def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
+ calls.append(meta)
+ return original(meta, device, dtype, allow_gds, pin_if_cpu)
+
+ monkeypatch.setattr(sd._file, "read_tensor", wrapped)
+ _ = sd["a"]
+ assert len(calls) == 1
+ assert calls[0].shape == (2, 3)
+
+
+def test_stream_views_do_not_materialize(tmp_path, monkeypatch):
+ if torch is None:
+ pytest.skip("torch not installed")
+ import comfy.utils
+ path = _write_safetensors(tmp_path, {"prefix.a": torch.zeros((2, 3)), "other": torch.ones((4,))})
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
+ calls = []
+
+ original = sd._file.read_tensor
+
+ def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
+ calls.append(meta)
+ return original(meta, device, dtype, allow_gds, pin_if_cpu)
+
+ monkeypatch.setattr(sd._file, "read_tensor", wrapped)
+ view = comfy.utils.state_dict_prefix_replace(sd, {"prefix.": ""}, filter_keys=True)
+ _ = list(view.keys())
+ assert calls == []
+
+
+def test_stream_load_rss_small(tmp_path):
+ if torch is None:
+ pytest.skip("torch not installed")
+ import comfy.utils
+ psutil = pytest.importorskip("psutil")
+ process = psutil.Process()
+ size_elems = 4_000_000 # ~16MB float32
+ tensor = torch.zeros((size_elems,), dtype=torch.float32)
+ path = _write_safetensors(tmp_path, {"big": tensor})
+ rss_before = process.memory_info().rss
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
+ rss_after = process.memory_info().rss
+ expected_size = tensor.numel() * tensor.element_size()
+ assert (rss_after - rss_before) < expected_size
+ _ = sd.meta("big")
+
+
+def test_gds_path_errors_without_support(tmp_path, monkeypatch):
+ if torch is None:
+ pytest.skip("torch not installed")
+ import comfy.utils
+ path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32)})
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
+ device = torch.device("cuda")
+
+ if importlib.util.find_spec("fastsafetensors") is None:
+ fst = None
+ else:
+ fst = importlib.import_module("fastsafetensors")
+
+ gds_available = False
+ if fst is not None and torch.cuda.is_available():
+ gds_supported = fst.cpp.is_gds_supported(torch.cuda.current_device())
+ gds_available = bool(fst.cpp.is_cufile_found()) and gds_supported == 1
+
+ if not gds_available:
+ with pytest.raises(RuntimeError, match="GPUDirect requested"):
+ sd.get_tensor("a", device=device, allow_gds=True)
+ else:
+ def fail_nogds(*args, **kwargs):
+ raise AssertionError("nogds path used during GDS request")
+
+ monkeypatch.setattr(sd._file, "_read_tensor_nogds", fail_nogds)
+ t = sd.get_tensor("a", device=device, allow_gds=True)
+ assert t.device.type == "cuda"
+
+
+def test_stream_load_without_disk_cache_keeps_cpu_weights(tmp_path):
+ if torch is None:
+ pytest.skip("torch not installed")
+ import comfy.utils
+ import comfy.disk_weights
+
+ prev_cache = comfy.disk_weights.CACHE.max_bytes
+ prev_gds = comfy.disk_weights.ALLOW_GDS
+ prev_pin = comfy.disk_weights.PIN_IF_CPU
+ prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
+ comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
+
+ try:
+ path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
+ sd = comfy.utils.load_torch_file(path, safe_load=True)
+ model = torch.nn.Linear(4, 4, bias=True)
+ comfy.utils.load_state_dict(model, sd, strict=False)
+ assert model.weight.device.type == "cpu"
+ assert model.weight.device.type != "meta"
+ finally:
+ comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)