From f925f8fa774196e2271b6a253964af255c0146ee Mon Sep 17 00:00:00 2001 From: ifilipis <40601736+ifilipis@users.noreply.github.com> Date: Thu, 8 Jan 2026 13:02:08 +0200 Subject: [PATCH] Add streaming safetensors loading with disk weight cache --- DESIGN.md | 57 ++ README.md | 8 + comfy/audio_encoders/audio_encoders.py | 2 +- comfy/cli_args.py | 3 + comfy/clip_vision.py | 2 +- comfy/controlnet.py | 12 +- comfy/disk_weights.py | 275 +++++++ comfy/gligen.py | 5 +- comfy/ldm/hunyuan_video/upsampler.py | 3 +- comfy/ldm/lightricks/vae/audio_vae.py | 5 +- comfy/ldm/mmaudio/vae/vae.py | 4 +- comfy/model_base.py | 14 +- comfy/model_detection.py | 165 ++-- comfy/model_management.py | 13 + comfy/model_patcher.py | 4 +- comfy/safetensors_stream.py | 799 ++++++++++++++++++++ comfy/sd.py | 129 ++-- comfy/sd1_clip.py | 5 +- comfy/taesd/taesd.py | 4 +- comfy/text_encoders/lt.py | 2 +- comfy/utils.py | 195 ++++- nodes.py | 3 +- requirements.txt | 1 + tests-unit/utils/safetensors_stream_test.py | 146 ++++ 24 files changed, 1640 insertions(+), 216 deletions(-) create mode 100644 DESIGN.md create mode 100644 comfy/disk_weights.py create mode 100644 comfy/safetensors_stream.py create mode 100644 tests-unit/utils/safetensors_stream_test.py diff --git a/DESIGN.md b/DESIGN.md new file mode 100644 index 000000000..8e74b256d --- /dev/null +++ b/DESIGN.md @@ -0,0 +1,57 @@ +# Disk tier safetensors streaming design audit (ComfyUI) + +## Mandatory research audit (verified call sites) + +### ComfyUI load path + eager materialization sites +- `comfy/utils.py:load_torch_file` currently uses `safetensors.safe_open` and iterates all keys to build a full `sd` dict (eager tensor materialization). It also returns metadata only after reading all tensors.【F:comfy/utils.py†L58-L93】 +- `comfy/utils.py:calculate_parameters` and `weight_dtype` iterate `sd.keys()` and then access `sd[k]` to compute `nelement()`/`dtype` (loads tensors).【F:comfy/utils.py†L109-L128】 +- `comfy/utils.py:state_dict_prefix_replace` mutates dicts by `pop`+assignment (materializes if used on a streaming mapping).【F:comfy/utils.py†L135-L144】 +- `comfy/model_base.py:BaseModel.load_model_weights` builds `to_load = {}` by iterating keys and popping tensors, then passes a fully materialized dict to `load_state_dict` (RAM spike).【F:comfy/model_base.py†L301-L318】 +- `comfy/model_detection.py` reads `state_dict[key].shape` in many branches for detection (must be metadata-only). 
Example: `calculate_transformer_depth` and numerous `detect_unet_config` branches read shapes directly from `state_dict` values.【F:comfy/model_detection.py†L21-L200】
+- `comfy/sd.py` loads checkpoints, then slices, renames, and computes parameters/dtypes by reading tensors (e.g., `calculate_parameters`, `weight_dtype`, `process_*_state_dict`, and the special scaled-FP8 conversion that builds new dicts).【F:comfy/sd.py†L1304-L1519】
+- Direct safetensors load outside `load_torch_file`: `comfy/sd1_clip.py:load_embed` and `nodes.py:LoadLatent.load` use `safetensors.torch.load_file`, bypassing the core loader.【F:comfy/sd1_clip.py†L432-L434】【F:nodes.py†L521-L529】
+
+### fastsafetensors capability audit
+- Header parsing and metadata:
+  - `fastsafetensors/common.py:SafeTensorsMetadata` parses the header and builds per-tensor `TensorFrame` with `dtype`, `shape`, and `data_offsets` (no tensor allocation).【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L63-L187】
+  - `TensorFrame` stores dtype/shape/offsets and supports slicing metadata.【F:../third_party/fastsafetensors-main/fastsafetensors/common.py†L238-L338】
+- GDS + no-GDS low-level readers:
+  - `fastsafetensors/cpp.pyi` exposes `gds_file_reader`, `gds_file_handle`, `nogds_file_reader`, `cpu_malloc`, `gpu_malloc`, and alignment helpers such as `get_alignment_size()`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L1-L43】
+  - GDS availability checks are in `fastsafetensors/cpp.pyi`: `is_gds_supported`, `is_cufile_found`, `cufile_version`, and `init_gds`.【F:../third_party/fastsafetensors-main/fastsafetensors/cpp.pyi†L36-L43】
+- DLPack wrapping:
+  - `fastsafetensors/dlpack.py` provides `from_cuda_buffer()`, which creates DLPack capsules for both CPU and GPU buffers via a device descriptor and is used with `torch.from_dlpack`.【F:../third_party/fastsafetensors-main/fastsafetensors/dlpack.py†L232-L239】
+- Torch framework interop:
+  - `fastsafetensors/frameworks/_torch.py:TorchOp` provides `alloc_tensor_memory`/`free_tensor_memory`, dtype mapping, and uses `torch.from_dlpack` for wrapping raw pointers into tensors.【F:../third_party/fastsafetensors-main/fastsafetensors/frameworks/_torch.py†L131-L205】
+
+### VRAM/RAM offload logic (for extension)
+- `comfy/model_management.py` handles VRAM/RAM offload via `free_memory` and keeps track of loaded/offloaded memory (needs integration for the RAM disk tier).【F:comfy/model_management.py†L584-L612】
+- `comfy/model_patcher.py` implements module-by-module offload/low-vram weight casting (`comfy_cast_weights`) and partial unload/load (needs to integrate the disk tier for RAM eviction).【F:comfy/model_patcher.py†L663-L955】
+
+## Strategy summary
+
+### Streaming safetensors mapping (no full dict materialization)
+- Introduce a new module `comfy/safetensors_stream.py` with:
+  - `TensorMeta` and `SafeTensorIndex` (metadata-only parsing with `fastsafetensors.SafeTensorsMetadata`).
+  - `StreamStateDict` as a mapping backed by `SafeTensorIndex`, exposing metadata-only `keys()`/`__iter__` and loading tensors on demand.
+  - Lightweight mapping views: `PrefixViewStateDict`, `FilterViewStateDict`, `RenameViewStateDict` for lazy prefix/filter/rename without eager loading.
+
+### Range reads and tiering
+- Disk→RAM: use `fastsafetensors.cpp.nogds_file_reader` for range reads and wrap the buffers with DLPack.
+- Disk→GPU (GDS): use `gds_file_reader` + `gds_file_handle` to read the aligned range directly into GPU memory.
If GDS is requested but not supported (e.g., `is_gds_supported==0` or libcufile missing), raise a hard error with instructions to disable GDS. +- Disk→RAM→GPU: read only the tensor range into (optionally pinned) CPU memory, copy to GPU, then release CPU buffer unless RAM cache policy keeps it. + +### Disk tier integration +- Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`. +- Add an LRU cache for RAM-resident weights loaded from disk with configurable max bytes. Eviction replaces RAM tensors with meta tensors and keeps `DiskRef` for reload. +- Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`. + +### Pipeline refactors +- Update `load_torch_file` to return `StreamStateDict` for `.safetensors`/`.sft` and return metadata without loading. +- Update helpers (`calculate_parameters`, `weight_dtype`, `state_dict_prefix_replace`) to be metadata-aware and lazy. +- Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead. +- Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads). +- Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader. + +### Tests and docs +- Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and GDS failure path. +- Document new flags for RAM cache size and GPUDirect enablement and how to disable GDS when unsupported. diff --git a/README.md b/README.md index 6d09758c0..550482e5c 100644 --- a/README.md +++ b/README.md @@ -349,6 +349,14 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step | `--enable-manager` | Enable ComfyUI-Manager | | `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) | | `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) | +| `--weights-ram-cache-gb` | Enable a disk tier for model weights and keep up to N GB in RAM. Set to `0` to disable RAM caching while still allowing disk streaming. | +| `--weights-gds` | Enable GPUDirect Storage (GDS) for disk→GPU weight loads. Requires libcufile and GDS support. | + +### Disk tier for model weights + +When `--weights-ram-cache-gb` is set, ComfyUI streams safetensors weights from disk and keeps a bounded RAM cache. If the cache limit is exceeded, weights are evicted back to disk and reloaded on demand. + +If `--weights-gds` is enabled, ComfyUI attempts disk→GPU reads via GPUDirect Storage. If GDS is not available (missing libcufile or unsupported platform), the load will fail with a clear error. Disable GDS by omitting `--weights-gds` to use disk→RAM→GPU staging instead. 
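+
+For example, to enable the disk tier with an 8 GB RAM cache and the default disk→RAM→GPU staging:
+
+```
+python main.py --weights-ram-cache-gb 8
+```
+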
# Running diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 46ef21c95..bf3abf834 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -29,7 +29,7 @@ class AudioEncoderModel(): self.model_sample_rate = 16000 def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return comfy.utils.load_state_dict(self.model, sd, strict=False) def get_sd(self): return self.model.state_dict() diff --git a/comfy/cli_args.py b/comfy/cli_args.py index dae9a895d..d75b9fe99 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -114,6 +114,9 @@ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU cachi cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB") +parser.add_argument("--weights-ram-cache-gb", type=float, default=None, help="Enable a disk tier for model weights by keeping up to N GB in RAM. Set to 0 to disable RAM caching while keeping disk tier enabled.") +parser.add_argument("--weights-gds", action="store_true", help="Enable GPUDirect Storage (GDS) for disk->GPU weight loads. Requires libcufile and GDS support.") + attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . 
Ignored when xformers is used.") diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index d5fc53497..837eba013 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -48,7 +48,7 @@ class ClipVisionModel(): self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return comfy.utils.load_state_dict(self.model, sd, strict=False) def get_sd(self): return self.model.state_dict() diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0b5e30f52..890076938 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -439,7 +439,7 @@ def controlnet_config(sd, model_options={}): return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device def controlnet_load_state_dict(control_model, sd): - missing, unexpected = control_model.load_state_dict(sd, strict=False) + missing, unexpected = comfy.utils.load_state_dict(control_model, sd, strict=False) if len(missing) > 0: logging.warning("missing controlnet keys: {}".format(missing)) @@ -473,9 +473,9 @@ def load_controlnet_mmdit(sd, model_options={}): class ControlNetSD35(ControlNet): def pre_run(self, model, percent_to_timestep_function): if self.control_model.double_y_emb: - missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False) + missing, unexpected = comfy.utils.load_state_dict(self.control_model.orig_y_embedder, model.diffusion_model.y_embedder.state_dict(), strict=False) else: - missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False) + missing, unexpected = comfy.utils.load_state_dict(self.control_model.x_embedder, model.diffusion_model.x_embedder.state_dict(), strict=False) super().pre_run(model, percent_to_timestep_function) def copy(self): @@ -748,9 +748,9 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): pass w = WeightsLoader() w.control_model = control_model - missing, unexpected = w.load_state_dict(controlnet_data, strict=False) + missing, unexpected = comfy.utils.load_state_dict(w, controlnet_data, strict=False) else: - missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) + missing, unexpected = comfy.utils.load_state_dict(control_model, controlnet_data, strict=False) if len(missing) > 0: logging.warning("missing controlnet keys: {}".format(missing)) @@ -874,7 +874,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options else: return None - missing, unexpected = model_ad.load_state_dict(t2i_data) + missing, unexpected = comfy.utils.load_state_dict(model_ad, t2i_data, strict=True) if len(missing) > 0: logging.warning("t2i missing {}".format(missing)) diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py new file mode 100644 index 000000000..6610733d6 --- /dev/null +++ b/comfy/disk_weights.py @@ -0,0 +1,275 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +from __future__ import annotations + +import collections +import weakref +from dataclasses import dataclass +from typing import Dict, Optional + +import torch + + +ALLOW_GDS = False +PIN_IF_CPU = False +DISK_WEIGHTS_ENABLED = False + + +@dataclass +class DiskTensorRef: + state_dict: object + key: str + meta: object + requires_grad: bool + is_buffer: bool + + def load(self, device: torch.device, allow_gds: bool, pin_if_cpu: bool) -> torch.Tensor: + dtype = getattr(self.meta, "dtype", None) + if hasattr(self.state_dict, "get_tensor"): + return self.state_dict.get_tensor( + self.key, + device=device, + dtype=dtype, + allow_gds=allow_gds, + pin_if_cpu=pin_if_cpu, + ) + tensor = self.state_dict[self.key] + if device is not None and tensor.device != device: + tensor = tensor.to(device=device) + if dtype is not None and tensor.dtype != dtype: + tensor = tensor.to(dtype=dtype) + return tensor + + +class DiskWeightRegistry: + def __init__(self): + self._registry = weakref.WeakKeyDictionary() + + def register(self, module: torch.nn.Module, name: str, ref: DiskTensorRef): + module_refs = self._registry.setdefault(module, {}) + module_refs[name] = ref + + def get(self, module: torch.nn.Module) -> Optional[Dict[str, DiskTensorRef]]: + return self._registry.get(module) + + def has(self, module: torch.nn.Module) -> bool: + return module in self._registry + + +@dataclass +class CacheEntry: + module_ref: weakref.ReferenceType + name: str + size_bytes: int + is_buffer: bool + + +class DiskWeightCache: + def __init__(self, max_bytes: int = 0): + self.max_bytes = max_bytes + self.current_bytes = 0 + self._entries: "collections.OrderedDict[tuple[int, str], CacheEntry]" = collections.OrderedDict() + + def set_limit(self, max_bytes: int): + self.max_bytes = max_bytes + self._evict_if_needed() + + def _entry_key(self, module: torch.nn.Module, name: str) -> tuple[int, str]: + return (id(module), name) + + def record(self, module: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool): + if tensor.device.type != "cpu": + return + size_bytes = tensor.numel() * tensor.element_size() + key = self._entry_key(module, name) + if key in self._entries: + entry = self._entries.pop(key) + self.current_bytes -= entry.size_bytes + module_ref = weakref.ref(module, self._drop_module_entries) + self._entries[key] = CacheEntry(module_ref=module_ref, name=name, size_bytes=size_bytes, is_buffer=is_buffer) + self.current_bytes += size_bytes + self._evict_if_needed() + + def touch(self, module: torch.nn.Module, name: str): + key = self._entry_key(module, name) + if key in self._entries: + entry = self._entries.pop(key) + self._entries[key] = entry + + def evict_bytes(self, bytes_to_free: int): + freed = 0 + while self._entries and freed < bytes_to_free: + _, entry = self._entries.popitem(last=False) + freed += entry.size_bytes + self.current_bytes -= entry.size_bytes + module = entry.module_ref() + if module is not None: + _evict_module_weight(module, entry.name, entry.is_buffer) + return freed + + def _drop_module_entries(self, module_ref: weakref.ReferenceType): + to_remove = [] + for key, entry in self._entries.items(): + if entry.module_ref is module_ref: + to_remove.append(key) + for key in to_remove: + entry = self._entries.pop(key) + self.current_bytes -= entry.size_bytes + + def _evict_if_needed(self): + while self._entries and 
self.current_bytes > self.max_bytes: + _, entry = self._entries.popitem(last=False) + self.current_bytes -= entry.size_bytes + module = entry.module_ref() + if module is not None: + _evict_module_weight(module, entry.name, entry.is_buffer) + + +REGISTRY = DiskWeightRegistry() +CACHE = DiskWeightCache(0) + + +def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True): + global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED + ALLOW_GDS = allow_gds + PIN_IF_CPU = pin_if_cpu + DISK_WEIGHTS_ENABLED = enabled + CACHE.set_limit(cache_bytes if enabled else 0) + if not enabled: + CACHE._entries.clear() + CACHE.current_bytes = 0 + + +def disk_weights_enabled() -> bool: + return DISK_WEIGHTS_ENABLED + + +def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = ""): + if not disk_weights_enabled(): + return + if not hasattr(state_dict, "meta") or not hasattr(state_dict, "get_tensor"): + return + for name, param in module.named_parameters(recurse=True): + key = f"{prefix}{name}" if prefix else name + if key in state_dict: + meta = state_dict.meta(key) + ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False) + REGISTRY.register(module, name, ref) + if param.device.type == "cpu": + CACHE.record(module, name, param, is_buffer=False) + for name, buf in module.named_buffers(recurse=True): + key = f"{prefix}{name}" if prefix else name + if key in state_dict and buf is not None: + meta = state_dict.meta(key) + ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True) + REGISTRY.register(module, name, ref) + if buf.device.type == "cpu": + CACHE.record(module, name, buf, is_buffer=True) + + +def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool): + ref = REGISTRY.get(module) + if not ref or name not in ref: + return + disk_ref = ref[name] + shape = getattr(disk_ref.meta, "shape", None) + dtype = getattr(disk_ref.meta, "dtype", None) + if shape is None or dtype is None: + return + meta_tensor = torch.empty(shape, dtype=dtype, device="meta") + if is_buffer: + module._buffers[name] = meta_tensor + else: + module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad) + + +def _find_tensor_device(args, kwargs) -> Optional[torch.device]: + def check(obj): + if torch.is_tensor(obj): + return obj.device + if isinstance(obj, (list, tuple)): + for item in obj: + dev = check(item) + if dev is not None: + return dev + if isinstance(obj, dict): + for item in obj.values(): + dev = check(item) + if dev is not None: + return dev + return None + + dev = check(args) + if dev is not None: + return dev + return check(kwargs) + + +def ensure_module_materialized(module: torch.nn.Module, target_device: torch.device): + refs = REGISTRY.get(module) + if not refs: + return + for name, disk_ref in refs.items(): + if name in module._parameters: + current = module._parameters[name] + is_buffer = False + elif name in module._buffers: + current = module._buffers[name] + is_buffer = True + else: + continue + if current is None: + continue + if current.device.type != "meta": + if current.device.type == "cpu": + CACHE.touch(module, name) + continue + tensor = disk_ref.load(target_device, ALLOW_GDS, PIN_IF_CPU) + if is_buffer: + module._buffers[name] = tensor + else: + module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad) + if tensor.device.type == "cpu": + CACHE.record(module, name, tensor, 
is_buffer=is_buffer)
+
+
+def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
+    if not REGISTRY.has(module):
+        return
+    if getattr(module, "comfy_cast_weights", False):
+        target_device = torch.device("cpu")
+    else:
+        target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
+    ensure_module_materialized(module, target_device)
+
+
+def attach_disk_weight_hooks(model: torch.nn.Module):
+    if not disk_weights_enabled():
+        return
+    for module in model.modules():
+        if getattr(module, "_disk_weight_hook_attached", False):
+            continue
+        # The hook takes (module, args, kwargs), so it must be registered with
+        # with_kwargs=True; otherwise PyTorch calls it with (module, args) only.
+        module.register_forward_pre_hook(disk_weight_pre_hook, with_kwargs=True)
+        module._disk_weight_hook_attached = True
+
+
+def evict_ram_cache(bytes_to_free: int):
+    if bytes_to_free <= 0:
+        return 0
+    return CACHE.evict_bytes(bytes_to_free)
diff --git a/comfy/gligen.py b/comfy/gligen.py
index 1d7b6c2f4..c2cf7c6db 100644
--- a/comfy/gligen.py
+++ b/comfy/gligen.py
@@ -3,6 +3,7 @@ import torch
 from torch import nn
 from .ldm.modules.attention import CrossAttention, FeedForward
 import comfy.ops
+import comfy.utils
 
 ops = comfy.ops.manual_cast
 
@@ -282,7 +283,7 @@ def load_gligen(sd):
             gated = GatedSelfAttentionDense(
                 query_dim, key_dim, n_heads, d_head)
-            gated.load_state_dict(n_sd, strict=False)
+            comfy.utils.load_state_dict(gated, n_sd, strict=False)
             output_list.append(gated)
 
         if "position_net.null_positive_feature" in sd_k:
@@ -293,7 +294,7 @@ def load_gligen(sd):
             pass
         w = WeightsLoader()
         w.position_net = PositionNet(in_dim, out_dim)
-        w.load_state_dict(sd, strict=False)
+        comfy.utils.load_state_dict(w, sd, strict=False)
 
     gligen = Gligen(output_list, w.position_net, key_dim)
     return gligen
diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py
index d9e76922f..0c8b9240f 100644
--- a/comfy/ldm/hunyuan_video/upsampler.py
+++ b/comfy/ldm/hunyuan_video/upsampler.py
@@ -1,4 +1,5 @@
 import torch
+import comfy.utils
 import torch.nn as nn
 import torch.nn.functional as F
 from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
@@ -112,7 +113,7 @@ class HunyuanVideo15SRModel():
         self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
 
     def load_sd(self, sd):
-        return self.model.load_state_dict(sd, strict=True)
+        return comfy.utils.load_state_dict(self.model, sd, strict=True)
 
     def get_sd(self):
         return self.model.state_dict()
diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py
index a9111d3bd..c61883ba3 100644
--- a/comfy/ldm/lightricks/vae/audio_vae.py
+++ b/comfy/ldm/lightricks/vae/audio_vae.py
@@ -2,6 +2,7 @@ import json
 from dataclasses import dataclass
 import math
 import torch
+import comfy.utils
 import torchaudio
 
 import comfy.model_management
@@ -153,8 +154,8 @@ class AudioVAE(torch.nn.Module):
         self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
         self.vocoder = Vocoder(config=component_config.vocoder)
 
-        self.autoencoder.load_state_dict(vae_sd, strict=False)
-        self.vocoder.load_state_dict(vocoder_sd, strict=False)
+        comfy.utils.load_state_dict(self.autoencoder, vae_sd, strict=False)
+        comfy.utils.load_state_dict(self.vocoder, vocoder_sd, strict=False)
 
         autoencoder_config = self.autoencoder.get_config()
         self.normalizer = AudioLatentNormalizer(
diff --git a/comfy/ldm/mmaudio/vae/vae.py b/comfy/ldm/mmaudio/vae/vae.py
index 62f24606c..1009d2d76 100644
--- a/comfy/ldm/mmaudio/vae/vae.py
+++ b/comfy/ldm/mmaudio/vae/vae.py
@@ -2,6 +2,7 @@ import logging
 from typing import Optional
 
 import torch
+import 
comfy.utils import torch.nn as nn from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D, @@ -152,7 +153,7 @@ class VAE(nn.Module): return dec, posterior def load_weights(self, src_dict) -> None: - self.load_state_dict(src_dict, strict=True) + comfy.utils.load_state_dict(self, src_dict, strict=True) @property def device(self) -> torch.device: @@ -355,4 +356,3 @@ def get_my_vae(name: str, **kwargs) -> VAE: if name == '44k': return VAE_44k(**kwargs) raise ValueError(f'Unknown model: {name}') - diff --git a/comfy/model_base.py b/comfy/model_base.py index 49efd700b..3bb155f2c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -299,20 +299,14 @@ class BaseModel(torch.nn.Module): return out def load_model_weights(self, sd, unet_prefix=""): - to_load = {} - keys = list(sd.keys()) - for k in keys: - if k.startswith(unet_prefix): - to_load[k[len(unet_prefix):]] = sd.pop(k) - + to_load = utils.state_dict_prefix_replace(sd, {unet_prefix: ""}, filter_keys=True) to_load = self.model_config.process_unet_state_dict(to_load) - m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False) if len(m) > 0: logging.warning("unet missing: {}".format(m)) if len(u) > 0: logging.warning("unet unexpected: {}".format(u)) - del to_load return self def process_latent_in(self, latent): @@ -751,8 +745,8 @@ class StableAudio1(BaseModel): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer) self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512) self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512) - self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights) - self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights) + utils.load_state_dict(self.seconds_start_embedder, seconds_start_embedder_weights, strict=True) + utils.load_state_dict(self.seconds_total_embedder, seconds_total_embedder_weights, strict=True) def extra_conds(self, **kwargs): out = {} diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 0853b3aec..0b3a05659 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -19,6 +19,9 @@ def count_blocks(state_dict_keys, prefix_string): count += 1 return count +def sd_shape(state_dict, key): + return comfy.utils.state_dict_meta(state_dict, key).shape + def calculate_transformer_depth(prefix, state_dict_keys, state_dict): context_dim = None use_linear_in_transformer = False @@ -27,8 +30,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict): transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys))) if len(transformer_keys) > 0: last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}') - context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1] - use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2 + context_dim = sd_shape(state_dict, '{}0.attn2.to_k.weight'.format(transformer_prefix))[1] + use_linear_in_transformer = len(sd_shape(state_dict, '{}1.proj_in.weight'.format(prefix))) == 2 time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict time_stack_cross = '{}1.time_stack.0.attn2.to_q.weight'.format(prefix) in state_dict or 
'{}1.time_mix_blocks.0.attn2.to_q.weight'.format(prefix) in state_dict return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross @@ -39,27 +42,27 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model unet_config = {} - unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1] - patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2] + unet_config["in_channels"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[1] + patch_size = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[2] unet_config["patch_size"] = patch_size final_layer = '{}final_layer.linear.weight'.format(key_prefix) if final_layer in state_dict: - unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size) + unet_config["out_channels"] = sd_shape(state_dict, final_layer)[0] // (patch_size * patch_size) - unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64 + unet_config["depth"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0] // 64 unet_config["input_size"] = None y_key = '{}y_embedder.mlp.0.weight'.format(key_prefix) if y_key in state_dict_keys: - unet_config["adm_in_channels"] = state_dict[y_key].shape[1] + unet_config["adm_in_channels"] = sd_shape(state_dict, y_key)[1] context_key = '{}context_embedder.weight'.format(key_prefix) if context_key in state_dict_keys: - in_features = state_dict[context_key].shape[1] - out_features = state_dict[context_key].shape[0] + in_features = sd_shape(state_dict, context_key)[1] + out_features = sd_shape(state_dict, context_key)[0] unet_config["context_embedder_config"] = {"target": "torch.nn.Linear", "params": {"in_features": in_features, "out_features": out_features}} num_patches_key = '{}pos_embed'.format(key_prefix) if num_patches_key in state_dict_keys: - num_patches = state_dict[num_patches_key].shape[1] + num_patches = sd_shape(state_dict, num_patches_key)[1] unet_config["num_patches"] = num_patches unet_config["pos_embed_max_size"] = round(math.sqrt(num_patches)) @@ -83,23 +86,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix) if text_mapper_name in state_dict_keys: unet_config['stable_cascade_stage'] = 'c' - w = state_dict[text_mapper_name] - if w.shape[0] == 1536: #stage c lite + w_shape = sd_shape(state_dict, text_mapper_name) + if w_shape[0] == 1536: #stage c lite unet_config['c_cond'] = 1536 unet_config['c_hidden'] = [1536, 1536] unet_config['nhead'] = [24, 24] unet_config['blocks'] = [[4, 12], [12, 4]] - elif w.shape[0] == 2048: #stage c full + elif w_shape[0] == 2048: #stage c full unet_config['c_cond'] = 2048 elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys: unet_config['stable_cascade_stage'] = 'b' - w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)] - if w.shape[-1] == 640: + w_shape = sd_shape(state_dict, '{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)) + if w_shape[-1] == 640: unet_config['c_hidden'] = [320, 640, 1280, 1280] unet_config['nhead'] = [-1, -1, 20, 20] unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]] unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]] - elif w.shape[-1] == 576: #stage b lite + elif w_shape[-1] == 576: #stage b lite unet_config['c_hidden'] = 
[320, 576, 1152, 1152] unet_config['nhead'] = [-1, 9, 18, 18] unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]] @@ -113,8 +116,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit unet_config = {} - unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1] - unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1] + unet_config["max_seq"] = sd_shape(state_dict, '{}positional_encoding'.format(key_prefix))[1] + unet_config["cond_seq_dim"] = sd_shape(state_dict, '{}cond_seq_linear.weight'.format(key_prefix))[1] double_layers = count_blocks(state_dict_keys, '{}double_layers.'.format(key_prefix) + '{}.') single_layers = count_blocks(state_dict_keys, '{}single_layers.'.format(key_prefix) + '{}.') unet_config["n_double_layers"] = double_layers @@ -125,10 +128,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config = {} unet_config["image_model"] = "hydit" unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') - unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] + unet_config["hidden_size"] = sd_shape(state_dict, '{}x_embedder.proj.weight'.format(key_prefix))[0] if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2 unet_config["mlp_ratio"] = 4.3637 - if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968: + if sd_shape(state_dict, '{}extra_embedder.0.weight'.format(key_prefix))[1] == 3968: unet_config["size_cond"] = True unet_config["use_style_cond"] = True unet_config["image_model"] = "hydit1" @@ -136,12 +139,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video dit_config = {} - in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)] - out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)] + in_w_shape = sd_shape(state_dict, '{}img_in.proj.weight'.format(key_prefix)) + out_w_shape = sd_shape(state_dict, '{}final_layer.linear.weight'.format(key_prefix)) dit_config["image_model"] = "hunyuan_video" - dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels - dit_config["patch_size"] = list(in_w.shape[2:]) - dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"]) + dit_config["in_channels"] = in_w_shape[1] #SkyReels img2video has 32 input channels + dit_config["patch_size"] = list(in_w_shape[2:]) + dit_config["out_channels"] = out_w_shape[0] // math.prod(dit_config["patch_size"]) if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys): dit_config["vec_in_dim"] = 768 else: @@ -157,10 +160,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): else: dit_config["meanflow"] = False - dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1] - dit_config["hidden_size"] = in_w.shape[0] + dit_config["context_in_dim"] = sd_shape(state_dict, '{}txt_in.input_embedder.weight'.format(key_prefix))[1] + dit_config["hidden_size"] = in_w_shape[0] dit_config["mlp_ratio"] = 4.0 - dit_config["num_heads"] = in_w.shape[0] // 128 + dit_config["num_heads"] = in_w_shape[0] // 128 dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') 
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') dit_config["theta"] = 256 @@ -179,7 +182,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): else: dit_config["use_cond_type_embedding"] = False if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys: - dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0] + dit_config["vision_in_dim"] = sd_shape(state_dict, '{}vision_in.proj.0.weight'.format(key_prefix))[0] dit_config["meanflow_sum"] = True else: dit_config["vision_in_dim"] = None @@ -221,19 +224,19 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["patch_size"] = patch_size in_key = "{}img_in.weight".format(key_prefix) if in_key in state_dict_keys: - w = state_dict[in_key] - dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size) - dit_config["hidden_size"] = w.shape[0] + w_shape = sd_shape(state_dict, in_key) + dit_config["in_channels"] = w_shape[1] // (patch_size * patch_size) + dit_config["hidden_size"] = w_shape[0] txt_in_key = "{}txt_in.weight".format(key_prefix) if txt_in_key in state_dict_keys: - w = state_dict[txt_in_key] - dit_config["context_in_dim"] = w.shape[1] - dit_config["hidden_size"] = w.shape[0] + w_shape = sd_shape(state_dict, txt_in_key) + dit_config["context_in_dim"] = w_shape[1] + dit_config["hidden_size"] = w_shape[0] vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix) if vec_in_key in state_dict_keys: - dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1] + dit_config["vec_in_dim"] = sd_shape(state_dict, vec_in_key)[1] else: dit_config["vec_in_dim"] = None @@ -307,7 +310,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config = {} dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv" dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') - shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape + shape = sd_shape(state_dict, '{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)) dit_config["attention_head_dim"] = shape[0] // 32 dit_config["cross_attention_dim"] = shape[1] if metadata is not None and "config" in metadata: @@ -350,11 +353,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): y_key = "{}y_embedder.y_embedding".format(key_prefix) if y_key in state_dict_keys: - dit_config["model_max_length"] = state_dict[y_key].shape[0] + dit_config["model_max_length"] = sd_shape(state_dict, y_key)[0] pe_key = "{}pos_embed".format(key_prefix) if pe_key in state_dict_keys: - dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size + dit_config["input_size"] = int(math.sqrt(sd_shape(state_dict, pe_key)[1])) * patch_size dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix) @@ -373,11 +376,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["max_img_w"] = 240 dit_config["max_frames"] = 128 concat_padding_mask = True - dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask) + dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask) dit_config["out_channels"] = 16 dit_config["patch_spatial"] = 2 
dit_config["patch_temporal"] = 1 - dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0] + dit_config["model_channels"] = sd_shape(state_dict, '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix))[0] dit_config["block_config"] = "FA-CA-MLP" dit_config["concat_padding_mask"] = concat_padding_mask dit_config["pos_emb_cls"] = "rope3d" @@ -416,9 +419,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "lumina2" dit_config["patch_size"] = 2 dit_config["in_channels"] = 16 - w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)] - dit_config["dim"] = w.shape[0] - dit_config["cap_feat_dim"] = w.shape[1] + w_shape = sd_shape(state_dict, '{}cap_embedder.1.weight'.format(key_prefix)) + dit_config["dim"] = w_shape[0] + dit_config["cap_feat_dim"] = w_shape[1] dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.') dit_config["qk_norm"] = True @@ -429,9 +432,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["axes_lens"] = [300, 512, 512] dit_config["rope_theta"] = 10000.0 dit_config["ffn_dim_multiplier"] = 4.0 - ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None) - if ctd_weight is not None: # NewBie - dit_config["clip_text_dim"] = ctd_weight.shape[0] + ctd_key = '{}clip_text_pooled_proj.0.weight'.format(key_prefix) + if ctd_key in state_dict_keys: # NewBie + dit_config["clip_text_dim"] = sd_shape(state_dict, ctd_key)[0] # NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI elif dit_config["dim"] == 3840: # Z image dit_config["n_heads"] = 30 @@ -450,12 +453,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 dit_config = {} dit_config["image_model"] = "wan2.1" - dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1] - out_dim = state_dict['{}head.head.weight'.format(key_prefix)].shape[0] // 4 + dim = sd_shape(state_dict, '{}head.modulation'.format(key_prefix))[-1] + out_dim = sd_shape(state_dict, '{}head.head.weight'.format(key_prefix))[0] // 4 dit_config["dim"] = dim dit_config["out_dim"] = out_dim dit_config["num_heads"] = dim // 128 - dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0] + dit_config["ffn_dim"] = sd_shape(state_dict, '{}blocks.0.ffn.0.weight'.format(key_prefix))[0] dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') dit_config["patch_size"] = (1, 2, 2) dit_config["freq_dim"] = 256 @@ -463,10 +466,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["qk_norm"] = True dit_config["cross_attn_norm"] = True dit_config["eps"] = 1e-6 - dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1] + dit_config["in_dim"] = sd_shape(state_dict, '{}patch_embedding.weight'.format(key_prefix))[1] if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "vace" - dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1] + dit_config["vace_in_dim"] = sd_shape(state_dict, '{}vace_patch_embedding.weight'.format(key_prefix))[1] dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.') elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys: if 
'{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: @@ -484,22 +487,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "i2v" else: dit_config["model_type"] = "t2v" - flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) - if flf_weight is not None: - dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] + flf_key = '{}img_emb.emb_pos'.format(key_prefix) + if flf_key in state_dict_keys: + dit_config["flf_pos_embed_token_number"] = sd_shape(state_dict, flf_key)[1] - ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix)) - if ref_conv_weight is not None: - dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1] + ref_conv_key = '{}ref_conv.weight'.format(key_prefix) + if ref_conv_key in state_dict_keys: + dit_config["in_dim_ref_conv"] = sd_shape(state_dict, ref_conv_key)[1] return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D - in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape + in_shape = sd_shape(state_dict, '{}latent_in.weight'.format(key_prefix)) dit_config = {} dit_config["image_model"] = "hunyuan3d2" dit_config["in_channels"] = in_shape[1] - dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1] + dit_config["context_in_dim"] = sd_shape(state_dict, '{}cond_in.weight'.format(key_prefix))[1] dit_config["hidden_size"] = in_shape[0] dit_config["mlp_ratio"] = 4.0 dit_config["num_heads"] = 16 @@ -513,9 +516,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config = {} dit_config["image_model"] = "hunyuan3d2_1" - dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1] + dit_config["in_channels"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[1] dit_config["context_dim"] = 1024 - dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0] + dit_config["hidden_size"] = sd_shape(state_dict, f"{key_prefix}x_embedder.weight")[0] dit_config["mlp_ratio"] = 4.0 dit_config["num_heads"] = 16 dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}") @@ -549,11 +552,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["max_img_w"] = 240 dit_config["max_frames"] = 128 concat_padding_mask = True - dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask) + dit_config["in_channels"] = (sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[1] // 4) - int(concat_padding_mask) dit_config["out_channels"] = 16 dit_config["patch_spatial"] = 2 dit_config["patch_temporal"] = 1 - dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0] + dit_config["model_channels"] = sd_shape(state_dict, '{}x_embedder.proj.1.weight'.format(key_prefix))[0] dit_config["concat_padding_mask"] = concat_padding_mask dit_config["crossattn_emb_channels"] = 1024 dit_config["pos_emb_cls"] = "rope3d" @@ -617,7 +620,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image dit_config = {} dit_config["image_model"] = "qwen_image" - dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1] + dit_config["in_channels"] = sd_shape(state_dict, '{}img_in.weight'.format(key_prefix))[1] dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') if 
"{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511 dit_config["default_ref_method"] = "index_timestep_zero" @@ -628,7 +631,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 dit_config = {} - model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0] + model_dim = sd_shape(state_dict, '{}visual_embeddings.in_layer.bias'.format(key_prefix))[0] dit_config["model_dim"] = model_dim if model_dim in [4096, 2560]: # pro video and lite image dit_config["axes_dims"] = (32, 48, 48) @@ -636,10 +639,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0) elif model_dim == 1792: # lite video dit_config["axes_dims"] = (16, 24, 24) - dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0] + dit_config["time_dim"] = sd_shape(state_dict, '{}time_embeddings.in_layer.bias'.format(key_prefix))[0] dit_config["image_model"] = "kandinsky5" - dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0] - dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1] + dit_config["ff_dim"] = sd_shape(state_dict, '{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix))[0] + dit_config["visual_embed_dim"] = sd_shape(state_dict, '{}visual_embeddings.in_layer.weight'.format(key_prefix))[1] dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.') dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.') return dit_config @@ -657,16 +660,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): y_input = '{}label_emb.0.0.weight'.format(key_prefix) if y_input in state_dict_keys: unet_config["num_classes"] = "sequential" - unet_config["adm_in_channels"] = state_dict[y_input].shape[1] + unet_config["adm_in_channels"] = sd_shape(state_dict, y_input)[1] else: unet_config["adm_in_channels"] = None - model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] - in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] + model_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[0] + in_channels = sd_shape(state_dict, '{}input_blocks.0.0.weight'.format(key_prefix))[1] out_key = '{}out.2.weight'.format(key_prefix) if out_key in state_dict: - out_channels = state_dict[out_key].shape[0] + out_channels = sd_shape(state_dict, out_key)[0] else: out_channels = 4 @@ -713,7 +716,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): res_block_prefix = "{}0.in_layers.0.weight".format(prefix) if res_block_prefix in block_keys: last_res_blocks += 1 - last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels + last_channel_mult = sd_shape(state_dict, "{}0.out_layers.3.weight".format(prefix))[0] // model_channels out = calculate_transformer_depth(prefix, state_dict_keys, state_dict) if out is not None: @@ -867,7 +870,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}') transformer_depth.append(transformer_count) if 
transformer_count > 0: - match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1] + match["context_dim"] = sd_shape(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab))[1] attn_res *= 2 if attn_blocks == 0: @@ -876,13 +879,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): match["transformer_depth"] = transformer_depth - match["model_channels"] = state_dict["conv_in.weight"].shape[0] - match["in_channels"] = state_dict["conv_in.weight"].shape[1] + match["model_channels"] = sd_shape(state_dict, "conv_in.weight")[0] + match["in_channels"] = sd_shape(state_dict, "conv_in.weight")[1] match["adm_in_channels"] = None if "class_embedding.linear_1.weight" in state_dict: - match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1] + match["adm_in_channels"] = sd_shape(state_dict, "class_embedding.linear_1.weight")[1] elif "add_embedding.linear_1.weight" in state_dict: - match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1] + match["adm_in_channels"] = sd_shape(state_dict, "add_embedding.linear_1.weight")[1] SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, @@ -1023,11 +1026,11 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): elif 'x_embedder.weight' in state_dict: #Flux depth = count_blocks(state_dict, 'transformer_blocks.{}.') depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') - hidden_size = state_dict["x_embedder.bias"].shape[0] + hidden_size = sd_shape(state_dict, "x_embedder.bias")[0] sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix) elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict and 'pos_embed.proj.weight' in state_dict: #SD3 num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') - depth = state_dict["pos_embed.proj.weight"].shape[0] // 64 + depth = sd_shape(state_dict, "pos_embed.proj.weight")[0] // 64 sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix) else: return None diff --git a/comfy/model_management.py b/comfy/model_management.py index 928282092..f389ad857 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -27,6 +27,7 @@ import platform import weakref import gc import os +import comfy.disk_weights class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -583,6 +584,8 @@ def minimum_inference_memory(): def free_memory(memory_required, device, keep_loaded=[]): cleanup_models_gc() + if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled(): + comfy.disk_weights.evict_ram_cache(memory_required) unloaded_model = [] can_unload = [] unloaded_models = [] @@ -1124,6 +1127,16 @@ if not args.disable_pinned_memory: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) +WEIGHTS_RAM_CACHE_BYTES = 0 +WEIGHTS_GDS_ENABLED = bool(args.weights_gds) +if args.weights_ram_cache_gb is not None: + WEIGHTS_RAM_CACHE_BYTES = int(max(0.0, args.weights_ram_cache_gb) * (1024 ** 3)) + comfy.disk_weights.configure( + WEIGHTS_RAM_CACHE_BYTES, + 
allow_gds=WEIGHTS_GDS_ENABLED, + pin_if_cpu=not args.disable_pinned_memory, + ) + PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) def discard_cuda_async_error(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 93d26c690..0e38629bc 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -34,6 +34,7 @@ import comfy.lora import comfy.model_management import comfy.patcher_extension import comfy.utils +import comfy.disk_weights from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP @@ -269,6 +270,8 @@ class ModelPatcher: if not hasattr(self.model, 'model_offload_buffer_memory'): self.model.model_offload_buffer_memory = 0 + comfy.disk_weights.attach_disk_weight_hooks(self.model) + def model_size(self): if self.size > 0: return self.size @@ -1356,4 +1359,3 @@ class ModelPatcher: def __del__(self): self.unpin_all_weights() self.detach(unpatch_all=False) - diff --git a/comfy/safetensors_stream.py b/comfy/safetensors_stream.py new file mode 100644 index 000000000..475bf6bf0 --- /dev/null +++ b/comfy/safetensors_stream.py @@ -0,0 +1,799 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +from __future__ import annotations + +import collections +import importlib +import importlib.util +import os +import threading +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Dict, Iterable, Iterator, Mapping, MutableMapping, Optional, Sequence, Tuple + +import torch + + +_FST_MODULE = None +_FST_LOCK = threading.Lock() +_FST_LOADED = False +_GDS_INITIALIZED = False +_MISSING = object() + + +def _require_fastsafetensors(): + global _FST_MODULE + with _FST_LOCK: + if _FST_MODULE is None: + if importlib.util.find_spec("fastsafetensors") is None: + raise ImportError( + "fastsafetensors is required for safetensors streaming. " + "Install it with: pip install 'fastsafetensors @ https://github.com/" + "foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip'" + ) + _FST_MODULE = importlib.import_module("fastsafetensors") + return _FST_MODULE + + +def _init_fastsafetensors_lib(): + global _FST_LOADED + fst = _require_fastsafetensors() + if not _FST_LOADED: + fst.cpp.load_library_functions() + _FST_LOADED = True + return fst + + +def _init_gds(): + global _GDS_INITIALIZED + fst = _init_fastsafetensors_lib() + if not _GDS_INITIALIZED: + if fst.cpp.init_gds() != 0: + raise RuntimeError("fastsafetensors init_gds() failed") + _GDS_INITIALIZED = True + + +@dataclass(frozen=True) +class TensorMeta: + dtype: torch.dtype + shape: Tuple[int, ...] + numel: int + nbytes: int + data_offsets: Tuple[int, int] + filename: str + fst_dtype: object + strides: Tuple[int, ...] 
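+
+# Illustrative sketch (comments only): a TensorMeta locates one tensor inside
+# the file. The readers below compute the absolute byte range as
+#   abs_start = index.header_length + meta.data_offsets[0]
+#   length    = meta.data_offsets[1] - meta.data_offsets[0]
+# and then widen it to the alignment granularity reported by fastsafetensors
+# (see _aligned_range) before issuing the read.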
+ + +class SafeTensorIndex: + def __init__(self, filename: str): + fst = _init_fastsafetensors_lib() + framework = fst.frameworks.get_framework_op("pytorch") + metadata = fst.common.SafeTensorsMetadata.from_file(filename, framework) + self._filename = filename + self._metadata = metadata + self._framework = framework + from fastsafetensors.frameworks import _torch as fst_torch + self._dtype_map = fst_torch.dtype_convert + self._tensor_meta: Dict[str, TensorMeta] = {} + for key, frame in metadata.tensors.items(): + torch_dtype = self._dtype_map.get(frame.dtype, None) + if torch_dtype is None: + raise ValueError(f"Unsupported safetensors dtype {frame.dtype} in {filename}") + numel = 1 + for s in frame.shape: + numel *= s + nbytes = numel * framework.get_dtype_size(frame.dtype) + self._tensor_meta[key] = TensorMeta( + dtype=torch_dtype, + shape=tuple(frame.shape), + numel=numel, + nbytes=nbytes, + data_offsets=(frame.data_offsets[0], frame.data_offsets[1]), + filename=filename, + fst_dtype=frame.dtype, + strides=tuple(frame.strides), + ) + + def keys(self) -> Iterable[str]: + return self._tensor_meta.keys() + + def has(self, key: str) -> bool: + return key in self._tensor_meta + + def meta(self, key: str) -> TensorMeta: + return self._tensor_meta[key] + + def metadata(self): + return self._metadata.metadata + + @property + def header_length(self) -> int: + return self._metadata.header_length + + @property + def size_bytes(self) -> int: + return self._metadata.size_bytes + + +class _SafeTensorFile: + def __init__(self, filename: str, index: SafeTensorIndex): + self.filename = filename + self.index = index + self._fd: Optional[int] = None + self._gds_handle = None + self._gds_reader = None + self._nogds_reader = None + self._refcount = 1 + + def acquire(self) -> "_SafeTensorFile": + self._refcount += 1 + return self + + def release(self): + self._refcount -= 1 + if self._refcount <= 0: + self.close() + + def close(self): + if self._fd is not None: + os.close(self._fd) + self._fd = None + self._gds_handle = None + + def _ensure_fd(self) -> int: + if self._fd is None: + self._fd = os.open(self.filename, os.O_RDONLY, 0o644) + return self._fd + + def _ensure_nogds_reader(self, use_cuda: bool): + fst = _init_fastsafetensors_lib() + if self._nogds_reader is None: + self._nogds_reader = fst.cpp.nogds_file_reader( + False, 16 * 1024, 16, use_cuda + ) + return self._nogds_reader + + def _ensure_gds_reader(self, use_cuda: bool): + fst = _init_fastsafetensors_lib() + if self._gds_reader is None: + self._gds_reader = fst.cpp.gds_file_reader(16, use_cuda) + return self._gds_reader + + def _ensure_gds_handle(self, use_cuda: bool): + if self._gds_handle is None: + fst = _init_fastsafetensors_lib() + framework = fst.frameworks.get_framework_op("pytorch") + o_direct = _get_gds_o_direct(framework) + self._gds_handle = fst.cpp.gds_file_handle(self.filename, o_direct, use_cuda) + return self._gds_handle + + def read_tensor( + self, + meta: TensorMeta, + device: torch.device, + dtype: Optional[torch.dtype], + allow_gds: bool, + pin_if_cpu: bool, + ) -> torch.Tensor: + fst = _init_fastsafetensors_lib() + framework = fst.frameworks.get_framework_op("pytorch") + device_is_cuda = device.type == "cuda" + if device_is_cuda and allow_gds: + _ensure_gds_ready(device) + tensor = self._read_tensor_gds( + fst, framework, meta, device, dtype + ) + return tensor + + cpu_tensor = self._read_tensor_nogds( + fst, framework, meta, torch.device("cpu"), dtype + ) + if device_is_cuda: + if pin_if_cpu: + cpu_tensor = 
cpu_tensor.pin_memory() + gpu_tensor = torch.empty_like(cpu_tensor, device=device) + gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu) + return gpu_tensor + return cpu_tensor + + def _aligned_range(self, abs_start: int, length: int) -> Tuple[int, int, int]: + fst = _init_fastsafetensors_lib() + align = fst.cpp.get_alignment_size() + aligned_offset = (abs_start // align) * align + head = abs_start - aligned_offset + aligned_length = length + head + tail = aligned_length % align + if tail: + aligned_length += align - tail + return aligned_offset, aligned_length, head + + def _read_tensor_nogds( + self, + fst, + framework, + meta: TensorMeta, + device: torch.device, + dtype: Optional[torch.dtype], + ) -> torch.Tensor: + fd = self._ensure_fd() + reader = self._ensure_nogds_reader(use_cuda=False) + abs_start = self.index.header_length + meta.data_offsets[0] + length = meta.data_offsets[1] - meta.data_offsets[0] + aligned_offset, aligned_length, head = self._aligned_range(abs_start, length) + ptr_align = framework.get_device_ptr_align() + buffer_length = aligned_length + ptr_align + buf_ptr = fst.cpp.cpu_malloc(buffer_length) + gbuf = fst.cpp.gds_device_buffer(buf_ptr, buffer_length, False) + ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align + req = reader.submit_read(fd, gbuf, aligned_offset, aligned_length, ptr_off) + if reader.wait_read(req) < 0: + fst.cpp.cpu_free(buf_ptr) + raise RuntimeError("nogds_file_reader read failed") + owner = _BufferOwner(lambda: fst.cpp.cpu_free(buf_ptr)) + tensor = _dlpack_tensor_from_buffer( + fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner + ) + if dtype is not None and dtype != tensor.dtype: + _validate_dtype_conversion(tensor.dtype, dtype) + tensor = tensor.to(dtype=dtype) + return tensor + + def _read_tensor_gds( + self, + fst, + framework, + meta: TensorMeta, + device: torch.device, + dtype: Optional[torch.dtype], + ) -> torch.Tensor: + reader = self._ensure_gds_reader(use_cuda=True) + handle = self._ensure_gds_handle(use_cuda=True) + abs_start = self.index.header_length + meta.data_offsets[0] + length = meta.data_offsets[1] - meta.data_offsets[0] + aligned_offset, aligned_length, head = self._aligned_range(abs_start, length) + ptr_align = framework.get_device_ptr_align() + buffer_length = aligned_length + ptr_align + fst_device = _fst_device_from_torch(fst, device) + gbuf = framework.alloc_tensor_memory(buffer_length, fst_device) + ptr_off = (- (gbuf.get_base_address() + head)) % ptr_align + file_length = self.index.size_bytes + req = reader.submit_read( + handle, gbuf, aligned_offset, aligned_length, ptr_off, file_length + ) + if reader.wait_read(req) < 0: + framework.free_tensor_memory(gbuf, fst_device) + raise RuntimeError("gds_file_reader read failed") + owner = _BufferOwner(lambda: framework.free_tensor_memory(gbuf, fst_device)) + tensor = _dlpack_tensor_from_buffer( + fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner + ) + if dtype is not None and dtype != tensor.dtype: + _validate_dtype_conversion(tensor.dtype, dtype) + tensor = tensor.to(dtype=dtype) + return tensor + + +def _fst_device_from_torch(fst, device: torch.device): + if device.type == "cuda" and device.index is not None: + return fst.st_types.Device.from_str(f"cuda:{device.index}") + return fst.st_types.Device.from_str(device.type) + + +class _BufferOwner: + def __init__(self, free_fn): + self._free_fn = free_fn + + def __del__(self): + try: + self._free_fn() + except Exception: + pass + + +def 
_dlpack_tensor_from_buffer( + fst, + framework, + ptr: int, + meta: TensorMeta, + device: torch.device, + owner: Optional[_BufferOwner], +) -> torch.Tensor: + disk_dtype = framework.as_workaround_dtype(meta.fst_dtype) + dev = _fst_device_from_torch(fst, device) + dl_tensor = fst.dlpack.from_cuda_buffer(ptr, list(meta.shape), list(meta.strides), disk_dtype, dev) + torch_tensor = framework.from_dlpack(dl_tensor, dev, disk_dtype).real_tensor + if disk_dtype != meta.fst_dtype: + torch_tensor = torch_tensor.view(meta.dtype) + if owner is not None: + torch_tensor._comfy_disk_buffer_owner = owner + return torch_tensor + + +def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype): + if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size(): + raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})") + + +def _get_gds_o_direct(framework) -> bool: + cuda_ver = framework.get_cuda_ver() + if cuda_ver and cuda_ver != "0.0": + ver_parts = cuda_ver.split("-", 1) + if len(ver_parts) == 2: + cudavers = list(map(int, ver_parts[1].split("."))) + if ver_parts[0] == "cuda": + return not (cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2)) + return True + return True + + +def _ensure_gds_ready(device: torch.device): + fst = _init_fastsafetensors_lib() + if not fst.common.is_gpu_found(): + raise RuntimeError( + "GPUDirect requested but GPU runtime library is missing. " + "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU." + ) + gds_supported = fst.cpp.is_gds_supported(device.index if device.index is not None else 0) + if gds_supported < 0: + raise RuntimeError( + "GPUDirect requested but is_gds_supported() failed. " + "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU." + ) + if not fst.cpp.is_cufile_found(): + raise RuntimeError( + "GPUDirect requested but libcufile is missing. " + "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU." + ) + if gds_supported == 0: + raise RuntimeError( + "GPUDirect requested but GDS is unsupported on this platform. " + "Disable GPUDirect by omitting --weights-gds to use disk->RAM->GPU." 
+        )
+    _init_gds()
+
+
+class StreamStateDict(collections.abc.MutableMapping):
+    is_stream_state_dict = True
+
+    def __init__(
+        self,
+        index: SafeTensorIndex,
+        file: _SafeTensorFile,
+        device: torch.device,
+        allow_gds: bool = False,
+    ):
+        self._index = index
+        self._file = file
+        self._device = device
+        self._allow_gds = allow_gds
+        self._overrides: Dict[str, torch.Tensor] = {}
+        self._deleted: set[str] = set()
+
+    @classmethod
+    def from_file(cls, filename: str, device: torch.device, allow_gds: bool = False) -> "StreamStateDict":
+        index = SafeTensorIndex(filename)
+        file = _SafeTensorFile(filename, index)
+        return cls(index, file, device, allow_gds=allow_gds)
+
+    def close(self):
+        if self._file is not None:
+            self._file.release()
+            self._file = None
+
+    def __del__(self):
+        try:
+            self.close()
+        except Exception:
+            pass
+
+    def meta(self, key: str) -> TensorMeta:
+        if key in self._overrides:
+            t = self._overrides[key]
+            numel = t.numel()
+            return TensorMeta(
+                dtype=t.dtype,
+                shape=tuple(t.shape),
+                numel=numel,
+                nbytes=numel * t.element_size(),
+                data_offsets=(0, numel * t.element_size()),
+                filename="",
+                fst_dtype=None,
+                strides=tuple(t.stride()),
+            )
+        if key in self._deleted:
+            raise KeyError(key)
+        return self._index.meta(key)
+
+    def get_tensor(
+        self,
+        key: str,
+        *,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        allow_gds: Optional[bool] = None,
+        pin_if_cpu: bool = False,
+    ) -> torch.Tensor:
+        if key in self._overrides:
+            t = self._overrides[key]
+            if device is not None and t.device != device:
+                t = t.to(device=device)
+            if dtype is not None and t.dtype != dtype:
+                _validate_dtype_conversion(t.dtype, dtype)
+                t = t.to(dtype=dtype)
+            return t
+        if key in self._deleted:
+            raise KeyError(key)
+        if device is None:
+            device = self._device
+        if allow_gds is None:
+            allow_gds = self._allow_gds
+        meta = self._index.meta(key)
+        return self._file.read_tensor(meta, device, dtype, allow_gds, pin_if_cpu)
+
+    def __getitem__(self, key: str) -> torch.Tensor:
+        return self.get_tensor(key)
+
+    def __setitem__(self, key: str, value: torch.Tensor) -> None:
+        self._overrides[key] = value
+        self._deleted.discard(key)
+
+    def __delitem__(self, key: str) -> None:
+        if key in self._overrides:
+            del self._overrides[key]
+            return
+        if key in self._deleted:
+            raise KeyError(key)
+        if self._index.has(key):
+            self._deleted.add(key)
+            return
+        raise KeyError(key)
+
+    def __iter__(self) -> Iterator[str]:
+        for k in self._index.keys():
+            if k in self._deleted:
+                continue
+            if k in self._overrides:
+                continue
+            yield k
+        for k in self._overrides.keys():
+            yield k
+
+    def __len__(self) -> int:
+        # Overrides that shadow an on-disk key are yielded only once by
+        # __iter__, so subtract the overlap instead of counting it twice.
+        shadowed = sum(1 for k in self._overrides if self._index.has(k))
+        return len(self._index.keys()) - len(self._deleted) - shadowed + len(self._overrides)
+
+    def __contains__(self, key: object) -> bool:
+        if not isinstance(key, str):
+            return False
+        if key in self._deleted:
+            return False
+        if key in self._overrides:
+            return True
+        return self._index.has(key)
+
+    def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
+        if key in self._overrides:
+            return self._overrides.pop(key)
+        if key in self._deleted:
+            if default is _MISSING:
+                raise KeyError(key)
+            return default
+        if self._index.has(key):
+            # Read the tensor before tombstoning the key; get_tensor()
+            # raises KeyError once the key is marked deleted.
+            value = self.get_tensor(key)
+            self._deleted.add(key)
+            return value
+        if default is _MISSING:
+            raise KeyError(key)
+        return default
+
+    def copy(self) -> "StreamStateDict":
+        new = StreamStateDict(self._index, self._file.acquire(), self._device, allow_gds=self._allow_gds)
+        new._overrides = dict(self._overrides)
+        new._deleted = set(self._deleted)
+        return new
+
+    def metadata(self):
+        return self._index.metadata()
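+
+
+# Usage sketch (key and file names are hypothetical):
+#
+#   sd = StreamStateDict.from_file("model.safetensors", torch.device("cpu"))
+#   meta = sd.meta("diffusion_model.out.weight")       # header metadata only
+#   w = sd.get_tensor("diffusion_model.out.weight",
+#                     device=torch.device("cuda:0"), pin_if_cpu=True)
+#
+# Iterating keys() or calling meta() never touches tensor bytes; only
+# get_tensor()/__getitem__ stream the requested byte range from disk.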
+
+
+class _BaseViewStateDict(MutableMapping):
+    is_stream_state_dict = True
+
+    def __init__(self, base: MutableMapping, mutate_base: bool = False):
+        self._base = base
+        self._mutate_base = mutate_base
+        self._overrides: Dict[str, torch.Tensor] = {}
+        self._deleted: set[str] = set()
+
+    def _resolve_base_key(self, key: str) -> Optional[str]:
+        return key
+
+    def _iter_base_keys(self) -> Iterable[str]:
+        return self._base.keys()
+
+    def get_tensor(
+        self,
+        key: str,
+        *,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        allow_gds: Optional[bool] = None,
+        pin_if_cpu: bool = False,
+    ) -> torch.Tensor:
+        if key in self._overrides:
+            t = self._overrides[key]
+            if device is not None and t.device != device:
+                t = t.to(device=device)
+            if dtype is not None and t.dtype != dtype:
+                _validate_dtype_conversion(t.dtype, dtype)
+                t = t.to(dtype=dtype)
+            return t
+        base_key = self._resolve_base_key(key)
+        if base_key is None or key in self._deleted:
+            raise KeyError(key)
+        if hasattr(self._base, "get_tensor"):
+            return self._base.get_tensor(
+                base_key, device=device, dtype=dtype, allow_gds=allow_gds, pin_if_cpu=pin_if_cpu
+            )
+        t = self._base[base_key]
+        if device is not None and t.device != device:
+            t = t.to(device=device)
+        if dtype is not None and t.dtype != dtype:
+            _validate_dtype_conversion(t.dtype, dtype)
+            t = t.to(dtype=dtype)
+        return t
+
+    def meta(self, key: str):
+        if key in self._overrides:
+            t = self._overrides[key]
+            numel = t.numel()
+            return SimpleNamespace(
+                dtype=t.dtype,
+                shape=tuple(t.shape),
+                numel=numel,
+                nbytes=numel * t.element_size(),
+            )
+        base_key = self._resolve_base_key(key)
+        if base_key is None or key in self._deleted:
+            raise KeyError(key)
+        if hasattr(self._base, "meta"):
+            return self._base.meta(base_key)
+        t = self._base[base_key]
+        numel = t.numel()
+        return SimpleNamespace(
+            dtype=t.dtype,
+            shape=tuple(t.shape),
+            numel=numel,
+            nbytes=numel * t.element_size(),
+        )
+
+    def __getitem__(self, key: str) -> torch.Tensor:
+        return self.get_tensor(key)
+
+    def __setitem__(self, key: str, value: torch.Tensor) -> None:
+        base_key = self._resolve_base_key(key)
+        if self._mutate_base and base_key is not None and base_key in self._base:
+            self._base[base_key] = value
+        else:
+            self._overrides[key] = value
+        self._deleted.discard(key)
+
+    def __delitem__(self, key: str) -> None:
+        if key in self._overrides:
+            del self._overrides[key]
+            return
+        base_key = self._resolve_base_key(key)
+        if base_key is None or key in self._deleted:
+            raise KeyError(key)
+        if self._mutate_base and base_key in self._base:
+            del self._base[base_key]
+        else:
+            self._deleted.add(key)
+
+    def __iter__(self) -> Iterator[str]:
+        # Keys that are overridden are yielded from _overrides below;
+        # skipping them here avoids duplicate yields for shadowed base keys.
+        for k in self._iter_base_keys():
+            if k in self._deleted or k in self._overrides:
+                continue
+            yield k
+        for k in self._overrides.keys():
+            yield k
+
+    def __len__(self) -> int:
+        count = len(self._overrides)
+        for k in self._iter_base_keys():
+            if k not in self._deleted and k not in self._overrides:
+                count += 1
+        return count
+
+    def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
+        if key in self._overrides:
+            return self._overrides.pop(key)
+        base_key = self._resolve_base_key(key)
+        if base_key is None or key in self._deleted:
+            if default is _MISSING:
+                raise KeyError(key)
+            return default
+        if self._mutate_base:
+            try:
+                return self._base.pop(base_key)
+            except KeyError:
+                if default is _MISSING:
+                    raise
+                return default
+        # Read the tensor before tombstoning the key; get_tensor() raises
+        # KeyError once the key is marked deleted.
+        value = self.get_tensor(key)
+        self._deleted.add(key)
+        return value
+
+
+class 
FilterViewStateDict(_BaseViewStateDict): + def __init__(self, base: MutableMapping, predicate, mutate_base: bool = False): + super().__init__(base, mutate_base=mutate_base) + self._predicate = predicate + + def _resolve_base_key(self, key: str) -> Optional[str]: + if self._predicate(key): + return key + return None + + def _iter_base_keys(self) -> Iterable[str]: + for k in self._base.keys(): + if self._predicate(k): + yield k + + +class PrefixViewStateDict(_BaseViewStateDict): + def __init__(self, base: MutableMapping, source_prefix: str, target_prefix: str = "", mutate_base: bool = False): + super().__init__(base, mutate_base=mutate_base) + self._source_prefix = source_prefix + self._target_prefix = target_prefix + self._mapping: Dict[str, str] = {} + self._reverse: Dict[str, str] = {} + for k in base.keys(): + if not k.startswith(source_prefix): + continue + view_key = f"{target_prefix}{k[len(source_prefix):]}" + self._mapping[k] = view_key + self._reverse[view_key] = k + + def _resolve_base_key(self, key: str) -> Optional[str]: + return self._reverse.get(key) + + def _iter_base_keys(self) -> Iterable[str]: + return self._reverse.keys() + + +class RenameViewStateDict(_BaseViewStateDict): + def __init__( + self, + base: MutableMapping, + replace_prefix: Mapping[str, str], + filter_keys: bool = False, + mutate_base: bool = False, + ): + super().__init__(base, mutate_base=mutate_base) + self._filter_keys = filter_keys + self._replace = list(replace_prefix.items()) + self._mapping: Dict[str, str] = {} + self._reverse: Dict[str, str] = {} + for k in base.keys(): + view_key = self._replace_key(k) + if view_key is None: + continue + self._mapping[k] = view_key + self._reverse[view_key] = k + + def _replace_key(self, key: str) -> Optional[str]: + for rp, dst in self._replace: + if key.startswith(rp): + return f"{dst}{key[len(rp):]}" + if self._filter_keys: + return None + return key + + def _resolve_base_key(self, key: str) -> Optional[str]: + return self._reverse.get(key) + + def _iter_base_keys(self) -> Iterable[str]: + return self._reverse.keys() + + +class MergedStateDict(MutableMapping): + is_stream_state_dict = True + + def __init__(self, *mappings: MutableMapping): + self._mappings = list(mappings) + self._overrides: Dict[str, torch.Tensor] = {} + self._deleted: set[str] = set() + + def __getitem__(self, key: str) -> torch.Tensor: + if key in self._overrides: + return self._overrides[key] + if key in self._deleted: + raise KeyError(key) + for mapping in reversed(self._mappings): + if key in mapping: + if hasattr(mapping, "get_tensor"): + return mapping.get_tensor(key) + return mapping[key] + raise KeyError(key) + + def __setitem__(self, key: str, value: torch.Tensor) -> None: + self._overrides[key] = value + self._deleted.discard(key) + + def __delitem__(self, key: str) -> None: + if key in self._overrides: + del self._overrides[key] + return + if key in self._deleted: + raise KeyError(key) + if any(key in mapping for mapping in self._mappings): + self._deleted.add(key) + return + raise KeyError(key) + + def __iter__(self) -> Iterator[str]: + seen = set() + for mapping in self._mappings: + for key in mapping.keys(): + if key in self._deleted or key in seen: + continue + seen.add(key) + yield key + for key in self._overrides.keys(): + if key not in seen: + yield key + + def __len__(self) -> int: + return len(list(self.__iter__())) + + def meta(self, key: str): + if key in self._overrides: + t = self._overrides[key] + numel = t.numel() + return SimpleNamespace( + dtype=t.dtype, + 
shape=tuple(t.shape), + numel=numel, + nbytes=numel * t.element_size(), + ) + if key in self._deleted: + raise KeyError(key) + for mapping in reversed(self._mappings): + if key in mapping: + if hasattr(mapping, "meta"): + return mapping.meta(key) + t = mapping[key] + numel = t.numel() + return SimpleNamespace( + dtype=t.dtype, + shape=tuple(t.shape), + numel=numel, + nbytes=numel * t.element_size(), + ) + raise KeyError(key) + + +class MappedStateDict(_BaseViewStateDict): + def __init__(self, base: MutableMapping, key_map: Mapping[str, str], mutate_base: bool = False): + super().__init__(base, mutate_base=mutate_base) + self._base_to_view = dict(key_map) + self._view_to_base = {v: k for k, v in key_map.items()} + + def _resolve_base_key(self, key: str) -> Optional[str]: + return self._view_to_base.get(key) + + def _iter_base_keys(self) -> Iterable[str]: + return self._view_to_base.keys() diff --git a/comfy/sd.py b/comfy/sd.py index 32157e18b..7b65af39d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -25,6 +25,7 @@ import math import os import comfy.utils +import comfy.safetensors_stream from . import clip_vision from . import gligen @@ -288,7 +289,7 @@ class CLIP: def load_sd(self, sd, full_model=False): if full_model: - return self.cond_stage_model.load_state_dict(sd, strict=False) + return comfy.utils.load_state_dict(self.cond_stage_model, sd, strict=False) else: return self.cond_stage_model.load_sd(sd) @@ -346,7 +347,7 @@ class VAE: encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config}, decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) elif "taesd_decoder.1.weight" in sd: - self.latent_channels = sd["taesd_decoder.1.weight"].shape[1] + self.latent_channels = sd_shape(sd, "taesd_decoder.1.weight")[1] self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels) elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade self.first_stage_model = StageA() @@ -361,25 +362,19 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 - new_sd = {} - for k in sd: - new_sd["encoder.{}".format(k)] = sd[k] - sd = new_sd + sd = comfy.utils.state_dict_prefix_replace(sd, {"": "encoder."}) elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade self.first_stage_model = StageC_coder() self.latent_channels = 16 - new_sd = {} - for k in sd: - new_sd["previewer.{}".format(k)] = sd[k] - sd = new_sd + sd = comfy.utils.state_dict_prefix_replace(sd, {"": "previewer."}) elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 elif "decoder.conv_in.weight" in sd: - if sd['decoder.conv_in.weight'].shape[1] == 64: + if sd_shape(sd, 'decoder.conv_in.weight')[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1] self.downscale_ratio = 32 self.upscale_ratio = 32 self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] @@ -389,9 +384,9 @@ class VAE: self.memory_used_encode = lambda 
shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) - elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5: + elif sd_shape(sd, 'decoder.conv_in.weight')[1] == 32 and len(sd_shape(sd, 'decoder.conv_in.weight')) == 5: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False} - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1] self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) self.upscale_index_formula = (4, 16, 16) @@ -414,7 +409,7 @@ class VAE: self.downscale_ratio = 4 self.upscale_ratio = 4 - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.weight")[1] if 'decoder.post_quant_conv.weight' in sd: sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."}) @@ -427,7 +422,7 @@ class VAE: self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0 if 'post_quant_conv.weight' in sd: - self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1]) else: self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, @@ -462,11 +457,11 @@ class VAE: self.downscale_index_formula = (6, 8, 8) self.working_dtypes = [torch.float16, torch.float32] elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv - tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"] + tensor_conv1_shape = sd_shape(sd, "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight") version = 0 - if tensor_conv1.shape[0] == 512: + if tensor_conv1_shape[0] == 512: version = 0 - elif tensor_conv1.shape[0] == 1024: + elif tensor_conv1_shape[0] == 1024: version = 1 if "encoder.down_blocks.1.conv.conv.bias" in sd: version = 2 @@ -483,9 +478,9 @@ class VAE: self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.downscale_index_formula = (8, 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] - elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: + elif "decoder.conv_in.conv.weight" in sd and sd_shape(sd, 'decoder.conv_in.conv.weight')[1] == 32: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} - ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] + ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1] self.latent_channels = 32 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) 
self.upscale_index_formula = (4, 16, 16) @@ -509,8 +504,8 @@ class VAE: self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_index_formula = (4, 8, 8) self.latent_dim = 3 - self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] - self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) + self.latent_channels = ddconfig['z_channels'] = sd_shape(sd, "decoder.conv_in.conv.weight")[1] + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd_shape(sd, 'post_quant_conv.weight')[1]) #This is likely to significantly over-estimate with single image or low frame counts as the #implementation is able to completely skip caching. Rework if used as an image only VAE self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype) @@ -543,14 +538,14 @@ class VAE: self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype) else: # Wan 2.1 VAE - dim = sd["decoder.head.0.gamma"].shape[0] + dim = sd_shape(sd, "decoder.head.0.gamma")[0] self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) self.upscale_index_formula = (4, 8, 8) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_index_formula = (4, 8, 8) self.latent_dim = 3 self.latent_channels = 16 - self.output_channels = sd["encoder.conv1.weight"].shape[1] + self.output_channels = sd_shape(sd, "encoder.conv1.weight")[1] self.pad_channel_value = 1.0 ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) @@ -626,7 +621,7 @@ class VAE: self.working_dtypes = [torch.float32] self.crop_input = False elif "decoder.22.bias" in sd: # taehv, taew and lighttae - self.latent_channels = sd["decoder.1.weight"].shape[1] + self.latent_channels = sd_shape(sd, "decoder.1.weight")[1] self.latent_dim = 3 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) self.upscale_index_formula = (4, 16, 16) @@ -637,12 +632,12 @@ class VAE: self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) self.process_output = lambda image: image self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)) - elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15 + elif self.latent_channels == 32 and sd_shape(sd, "decoder.22.bias")[0] == 12: # lighttae_hv15 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15) self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) else: - if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise 
state dicts are identical + if comfy.utils.state_dict_meta(sd, "decoder.1.weight").dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical latent_format=comfy.latent_formats.HunyuanVideo else: latent_format=None # lighttaew2_1 doesn't need scaling @@ -662,7 +657,7 @@ class VAE: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() - m, u = self.first_stage_model.load_state_dict(sd, strict=False) + m, u = comfy.utils.load_state_dict(self.first_stage_model, sd, strict=False) if len(m) > 0: logging.warning("Missing VAE keys {}".format(m)) @@ -983,9 +978,12 @@ def load_style_model(ckpt_path): model = comfy.ldm.flux.redux.ReduxImageEncoder() else: raise Exception("invalid style model {}".format(ckpt_path)) - model.load_state_dict(model_data) + comfy.utils.load_state_dict(model, model_data, strict=True) return StyleModel(model) +def sd_shape(state_dict, key): + return comfy.utils.state_dict_meta(state_dict, key).shape + class CLIPType(Enum): STABLE_DIFFUSION = 1 STABLE_CASCADE = 2 @@ -1055,16 +1053,16 @@ def detect_te_model(sd): if "model.encoder.layers.0.mixer.Wqkv.weight" in sd: return TEModel.JINA_CLIP_2 if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: - weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] - if weight.shape[-1] == 4096: + weight_shape = sd_shape(sd, "encoder.block.23.layer.1.DenseReluDense.wi_1.weight") + if weight_shape[-1] == 4096: return TEModel.T5_XXL - elif weight.shape[-1] == 2048: + elif weight_shape[-1] == 2048: return TEModel.T5_XL if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd: return TEModel.T5_XXL_OLD if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd: - weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight'] - if weight.shape[0] == 384: + weight_shape = sd_shape(sd, 'encoder.block.0.layer.0.SelfAttention.k.weight') + if weight_shape[0] == 384: return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: @@ -1074,19 +1072,19 @@ def detect_te_model(sd): return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B if 'model.layers.0.self_attn.k_proj.bias' in sd: - weight = sd['model.layers.0.self_attn.k_proj.bias'] - if weight.shape[0] == 256: + weight_shape = sd_shape(sd, 'model.layers.0.self_attn.k_proj.bias') + if weight_shape[0] == 256: return TEModel.QWEN25_3B - if weight.shape[0] == 512: + if weight_shape[0] == 512: return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: - weight = sd['model.layers.0.post_attention_layernorm.weight'] + weight_shape = sd_shape(sd, 'model.layers.0.post_attention_layernorm.weight') if 'model.layers.0.self_attn.q_norm.weight' in sd: - if weight.shape[0] == 2560: + if weight_shape[0] == 2560: return TEModel.QWEN3_4B - elif weight.shape[0] == 2048: + elif weight_shape[0] == 2048: return TEModel.QWEN3_2B - if weight.shape[0] == 5120: + if weight_shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: return TEModel.MISTRAL3_24B else: @@ -1415,19 +1413,29 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c scaled_fp8_list.append(k[:-len("scaled_fp8")]) if len(scaled_fp8_list) > 0: - out_sd = {} - for k in sd: - skip = False + if comfy.utils.is_stream_state_dict(sd): + def _keep_key(k, prefixes=tuple(scaled_fp8_list)): + return not any(k.startswith(pref) for pref in prefixes) + out_sd = 
comfy.safetensors_stream.FilterViewStateDict(sd, _keep_key, mutate_base=False) + merged = out_sd for pref in scaled_fp8_list: - skip = skip or k.startswith(pref) - if not skip: - out_sd[k] = sd[k] + quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={}) + merged = comfy.safetensors_stream.MergedStateDict(merged, quant_sd) + sd = merged + else: + out_sd = {} + for k in sd: + skip = False + for pref in scaled_fp8_list: + skip = skip or k.startswith(pref) + if not skip: + out_sd[k] = sd[k] - for pref in scaled_fp8_list: - quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={}) - for k in quant_sd: - out_sd[k] = quant_sd[k] - sd = out_sd + for pref in scaled_fp8_list: + quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={}) + for k in quant_sd: + out_sd[k] = quant_sd[k] + sd = out_sd clip_target = model_config.clip_target(state_dict=sd) if clip_target is not None: @@ -1505,12 +1513,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) - new_sd = {} - for k in diffusers_keys: - if k in sd: - new_sd[diffusers_keys[k]] = sd.pop(k) - else: - logging.warning("{} {}".format(diffusers_keys[k], k)) + if comfy.utils.is_stream_state_dict(sd): + new_sd = comfy.safetensors_stream.MappedStateDict(sd, diffusers_keys) + else: + new_sd = {} + for k in diffusers_keys: + if k in sd: + new_sd[diffusers_keys[k]] = sd.pop(k) + else: + logging.warning("{} {}".format(diffusers_keys[k], k)) offload_device = model_management.unet_offload_device() unet_weight_dtype = list(model_config.supported_inference_dtypes) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index c512ca5d0..9fb8a8f5c 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return self(tokens) def load_sd(self, sd): - return self.transformer.load_state_dict(sd, strict=False) + return comfy.utils.load_state_dict(self.transformer, sd, strict=False) def parse_parentheses(string): result = [] @@ -430,8 +430,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No try: if embed_path.lower().endswith(".safetensors"): - import safetensors.torch - embed = safetensors.torch.load_file(embed_path, device="cpu") + embed = comfy.utils.load_torch_file(embed_path, safe_load=True) else: try: embed = torch.load(embed_path, weights_only=True, map_location="cpu") diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index ce36f1a84..7cf040588 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -56,9 +56,9 @@ class TAESD(nn.Module): self.vae_scale = torch.nn.Parameter(torch.tensor(1.0)) self.vae_shift = torch.nn.Parameter(torch.tensor(0.0)) if encoder_path is not None: - self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True)) + comfy.utils.load_state_dict(self.taesd_encoder, comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True) if decoder_path is not None: - self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True)) + comfy.utils.load_state_dict(self.taesd_decoder, comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True) @staticmethod def scale_latents(x): diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 130ebaeae..c998061b4 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -116,7 +116,7 @@ class LTXAVTEModel(torch.nn.Module): if len(sdo) 
== 0: sdo = sd - return self.load_state_dict(sdo, strict=False) + return comfy.utils.load_state_dict(self, sdo, strict=False) def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): diff --git a/comfy/utils.py b/comfy/utils.py index ffa98c9b1..7281dcd7e 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -26,10 +26,13 @@ import numpy as np from PIL import Image import logging import itertools +from types import SimpleNamespace from torch.nn.functional import interpolate from einops import rearrange from comfy.cli_args import args import json +from . import safetensors_stream +import comfy.disk_weights MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -61,15 +64,9 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): metadata = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: - with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: - sd = {} - for k in f.keys(): - tensor = f.get_tensor(k) - if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues - tensor = tensor.to(device=device, copy=True) - sd[k] = tensor - if return_metadata: - metadata = f.metadata() + sd = safetensors_stream.StreamStateDict.from_file(ckpt, device=device) + if return_metadata: + metadata = sd.metadata() except Exception as e: if len(e.args) > 0: message = e.args[0] @@ -110,16 +107,16 @@ def calculate_parameters(sd, prefix=""): params = 0 for k in sd.keys(): if k.startswith(prefix): - w = sd[k] - params += w.nelement() + meta = state_dict_meta(sd, k) + params += meta.numel return params def weight_dtype(sd, prefix=""): dtypes = {} for k in sd.keys(): if k.startswith(prefix): - w = sd[k] - dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel() + meta = state_dict_meta(sd, k) + dtypes[meta.dtype] = dtypes.get(meta.dtype, 0) + meta.numel if len(dtypes) == 0: return None @@ -133,6 +130,13 @@ def state_dict_key_replace(state_dict, keys_to_replace): return state_dict def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False): + if is_stream_state_dict(state_dict): + return safetensors_stream.RenameViewStateDict( + state_dict, + replace_prefix, + filter_keys=filter_keys, + mutate_base=not filter_keys, + ) if filter_keys: out = {} else: @@ -145,6 +149,77 @@ def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False): return out +def is_stream_state_dict(state_dict) -> bool: + return getattr(state_dict, "is_stream_state_dict", False) + + +def state_dict_meta(state_dict, key): + if hasattr(state_dict, "meta"): + return state_dict.meta(key) + w = state_dict[key] + numel = w.numel() + return SimpleNamespace( + dtype=w.dtype, + shape=tuple(w.shape), + numel=numel, + nbytes=numel * w.element_size(), + ) + + +def load_state_dict(model, state_dict, strict=False, assign=False): + if is_stream_state_dict(state_dict): + comfy.disk_weights.register_module_weights(model, state_dict) + comfy.disk_weights.attach_disk_weight_hooks(model) + missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign) + return missing, unexpected + return model.load_state_dict(state_dict, strict=strict) + + +def stream_load_state_dict(model, state_dict, strict=False, assign=False): + if is_stream_state_dict(state_dict) and hasattr(state_dict, "copy"): + state_dict = state_dict.copy() + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + metadata = getattr(state_dict, "_metadata", None) + + def load(module, local_state_dict, prefix=""): + 
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + if assign: + local_metadata["assign_to_params_buffers"] = assign + module._load_from_state_dict( + local_state_dict, + prefix, + local_metadata, + True, + missing_keys, + unexpected_keys, + error_msgs, + ) + for name, child in module._modules.items(): + if child is not None: + child_prefix = f"{prefix}{name}." + child_state_dict = safetensors_stream.FilterViewStateDict( + local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False + ) + load(child, child_state_dict, child_prefix) + incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys) + for hook in module._load_state_dict_post_hooks.values(): + out = hook(module, incompatible) + if out is not None: + raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.") + + load(model, state_dict) + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys))) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs))) + return missing_keys, unexpected_keys + + def transformers_convert(sd, prefix_from, prefix_to, number): keys_to_replace = { "{}positional_embedding": "{}embeddings.position_embedding.weight", @@ -1217,46 +1292,82 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): scaled_fp8_key = "{}scaled_fp8".format(model_prefix) if scaled_fp8_key in state_dict: - scaled_fp8_weight = state_dict[scaled_fp8_key] - scaled_fp8_dtype = scaled_fp8_weight.dtype + if is_stream_state_dict(state_dict): + scaled_meta = state_dict_meta(state_dict, scaled_fp8_key) + scaled_fp8_dtype = scaled_meta.dtype + scaled_fp8_weight_nelements = scaled_meta.numel + else: + scaled_fp8_weight = state_dict[scaled_fp8_key] + scaled_fp8_dtype = scaled_fp8_weight.dtype + scaled_fp8_weight_nelements = scaled_fp8_weight.nelement() if scaled_fp8_dtype == torch.float32: scaled_fp8_dtype = torch.float8_e4m3fn - if scaled_fp8_weight.nelement() == 2: + if scaled_fp8_weight_nelements == 2: full_precision_matrix_mult = True else: full_precision_matrix_mult = False - out_sd = {} layers = {} - for k in list(state_dict.keys()): - if k == scaled_fp8_key: - continue - if not k.startswith(model_prefix): - out_sd[k] = state_dict[k] - continue - k_out = k - w = state_dict.pop(k) - layer = None - if k_out.endswith(".scale_weight"): - layer = k_out[:-len(".scale_weight")] - k_out = "{}.weight_scale".format(layer) - - if layer is not None: - layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints - if full_precision_matrix_mult: - layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult - layers[layer] = layer_conf - - if k_out.endswith(".scale_input"): - layer = k_out[:-len(".scale_input")] - k_out = "{}.input_scale".format(layer) - if w.item() == 1.0: + if is_stream_state_dict(state_dict): + key_map = {} + for k in list(state_dict.keys()): + if k == scaled_fp8_key: continue + if not k.startswith(model_prefix): + key_map[k] = k + continue + k_out = k + layer = None + if k_out.endswith(".scale_weight"): + layer = k_out[:-len(".scale_weight")] + k_out = "{}.weight_scale".format(layer) - out_sd[k_out] = w + if layer is not None: + layer_conf = 
{"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints + if full_precision_matrix_mult: + layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult + layers[layer] = layer_conf - state_dict = out_sd + if k_out.endswith(".scale_input"): + layer = k_out[:-len(".scale_input")] + k_out = "{}.input_scale".format(layer) + scale_val = state_dict[k] + if scale_val.item() == 1.0: + continue + + key_map[k] = k_out + state_dict = safetensors_stream.MappedStateDict(state_dict, key_map) + else: + out_sd = {} + for k in list(state_dict.keys()): + if k == scaled_fp8_key: + continue + if not k.startswith(model_prefix): + out_sd[k] = state_dict[k] + continue + k_out = k + w = state_dict.pop(k) + layer = None + if k_out.endswith(".scale_weight"): + layer = k_out[:-len(".scale_weight")] + k_out = "{}.weight_scale".format(layer) + + if layer is not None: + layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints + if full_precision_matrix_mult: + layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult + layers[layer] = layer_conf + + if k_out.endswith(".scale_input"): + layer = k_out[:-len(".scale_input")] + k_out = "{}.input_scale".format(layer) + if w.item() == 1.0: + continue + + out_sd[k_out] = w + + state_dict = out_sd quant_metadata = {"layers": layers} else: quant_metadata = json.loads(metadata["_quantization_metadata"]) diff --git a/nodes.py b/nodes.py index 56b74ebe3..bd4650276 100644 --- a/nodes.py +++ b/nodes.py @@ -17,7 +17,6 @@ from PIL import Image, ImageOps, ImageSequence from PIL.PngImagePlugin import PngInfo import numpy as np -import safetensors.torch sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) @@ -525,7 +524,7 @@ class LoadLatent: def load(self, latent): latent_path = folder_paths.get_annotated_filepath(latent) - latent = safetensors.torch.load_file(latent_path, device="cpu") + latent = comfy.utils.load_torch_file(latent_path, safe_load=True) multiplier = 1.0 if "latent_format_version_0" not in latent: multiplier = 1.0 / 0.18215 diff --git a/requirements.txt b/requirements.txt index bc8346bcf..9d3e0aa53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,7 @@ transformers>=4.50.3 tokenizers>=0.13.3 sentencepiece safetensors>=0.4.2 +fastsafetensors @ https://github.com/foundation-model-stack/fastsafetensors/archive/refs/heads/main.zip aiohttp>=3.11.8 yarl>=1.18.0 pyyaml diff --git a/tests-unit/utils/safetensors_stream_test.py b/tests-unit/utils/safetensors_stream_test.py new file mode 100644 index 000000000..e9bb72245 --- /dev/null +++ b/tests-unit/utils/safetensors_stream_test.py @@ -0,0 +1,146 @@ +import os + +import pytest +import importlib +import importlib.util + +torch = pytest.importorskip("torch") + + +def _write_safetensors(tmp_path, tensors): + import safetensors.torch + path = os.path.join(tmp_path, "test.safetensors") + safetensors.torch.save_file(tensors, path) + return path + + +def test_stream_state_dict_meta_is_lazy(tmp_path, monkeypatch): + if torch is None: + pytest.skip("torch not installed") + import comfy.utils + path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)}) + sd = comfy.utils.load_torch_file(path, safe_load=True) + calls = [] + + original = sd._file.read_tensor + + def wrapped(meta, device, dtype, allow_gds, pin_if_cpu): + calls.append(meta) + return original(meta, device, dtype, allow_gds, pin_if_cpu) + + 
monkeypatch.setattr(sd._file, "read_tensor", wrapped) + meta = sd.meta("a") + assert meta.shape == (2, 3) + assert meta.dtype == torch.float32 + assert meta.numel == 6 + assert calls == [] + + +def test_stream_state_dict_getitem_loads_single_tensor(tmp_path, monkeypatch): + if torch is None: + pytest.skip("torch not installed") + import comfy.utils + path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)}) + sd = comfy.utils.load_torch_file(path, safe_load=True) + calls = [] + + original = sd._file.read_tensor + + def wrapped(meta, device, dtype, allow_gds, pin_if_cpu): + calls.append(meta) + return original(meta, device, dtype, allow_gds, pin_if_cpu) + + monkeypatch.setattr(sd._file, "read_tensor", wrapped) + _ = sd["a"] + assert len(calls) == 1 + assert calls[0].shape == (2, 3) + + +def test_stream_views_do_not_materialize(tmp_path, monkeypatch): + if torch is None: + pytest.skip("torch not installed") + import comfy.utils + path = _write_safetensors(tmp_path, {"prefix.a": torch.zeros((2, 3)), "other": torch.ones((4,))}) + sd = comfy.utils.load_torch_file(path, safe_load=True) + calls = [] + + original = sd._file.read_tensor + + def wrapped(meta, device, dtype, allow_gds, pin_if_cpu): + calls.append(meta) + return original(meta, device, dtype, allow_gds, pin_if_cpu) + + monkeypatch.setattr(sd._file, "read_tensor", wrapped) + view = comfy.utils.state_dict_prefix_replace(sd, {"prefix.": ""}, filter_keys=True) + _ = list(view.keys()) + assert calls == [] + + +def test_stream_load_rss_small(tmp_path): + if torch is None: + pytest.skip("torch not installed") + import comfy.utils + psutil = pytest.importorskip("psutil") + process = psutil.Process() + size_elems = 4_000_000 # ~16MB float32 + tensor = torch.zeros((size_elems,), dtype=torch.float32) + path = _write_safetensors(tmp_path, {"big": tensor}) + rss_before = process.memory_info().rss + sd = comfy.utils.load_torch_file(path, safe_load=True) + rss_after = process.memory_info().rss + expected_size = tensor.numel() * tensor.element_size() + assert (rss_after - rss_before) < expected_size + _ = sd.meta("big") + + +def test_gds_path_errors_without_support(tmp_path, monkeypatch): + if torch is None: + pytest.skip("torch not installed") + import comfy.utils + path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32)}) + sd = comfy.utils.load_torch_file(path, safe_load=True) + device = torch.device("cuda") + + if importlib.util.find_spec("fastsafetensors") is None: + fst = None + else: + fst = importlib.import_module("fastsafetensors") + + gds_available = False + if fst is not None and torch.cuda.is_available(): + gds_supported = fst.cpp.is_gds_supported(torch.cuda.current_device()) + gds_available = bool(fst.cpp.is_cufile_found()) and gds_supported == 1 + + if not gds_available: + with pytest.raises(RuntimeError, match="GPUDirect requested"): + sd.get_tensor("a", device=device, allow_gds=True) + else: + def fail_nogds(*args, **kwargs): + raise AssertionError("nogds path used during GDS request") + + monkeypatch.setattr(sd._file, "_read_tensor_nogds", fail_nogds) + t = sd.get_tensor("a", device=device, allow_gds=True) + assert t.device.type == "cuda" + + +def test_stream_load_without_disk_cache_keeps_cpu_weights(tmp_path): + if torch is None: + pytest.skip("torch not installed") + import comfy.utils + import comfy.disk_weights + + prev_cache = comfy.disk_weights.CACHE.max_bytes + prev_gds = comfy.disk_weights.ALLOW_GDS + prev_pin = 
comfy.disk_weights.PIN_IF_CPU
+    prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
+    comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
+
+    try:
+        path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
+        sd = comfy.utils.load_torch_file(path, safe_load=True)
+        model = torch.nn.Linear(4, 4, bias=True)
+        comfy.utils.load_state_dict(model, sd, strict=False)
+        # A cpu device type already implies the weights were materialized
+        # rather than left on the meta device.
+        assert model.weight.device.type == "cpu"
+    finally:
+        comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
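+
+
+# A hedged, illustrative extra check (not part of the original test set): it
+# exercises the lazy PrefixViewStateDict from comfy/safetensors_stream.py with
+# arbitrary fixture key names.
+def test_prefix_view_renames_keys_lazily(tmp_path):
+    import comfy.utils
+    import comfy.safetensors_stream
+    path = _write_safetensors(tmp_path, {"model.a": torch.zeros((2, 2)), "other.b": torch.ones((3,))})
+    sd = comfy.utils.load_torch_file(path, safe_load=True)
+    view = comfy.safetensors_stream.PrefixViewStateDict(sd, "model.", "")
+    assert list(view.keys()) == ["a"]
+    assert view.meta("a").shape == (2, 2)  # metadata only, no disk read
+    assert torch.equal(view["a"], torch.zeros((2, 2)))  # loads just this tensor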